Diffstat (limited to 'mlir')
-rw-r--r--  mlir/docs/Canonicalization.md | 2
-rw-r--r--  mlir/docs/Dialects/Shard.md | 6
-rw-r--r--  mlir/include/mlir-c/Rewrite.h | 2
-rw-r--r--  mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h | 8
-rw-r--r--  mlir/include/mlir/Conversion/Passes.td | 7
-rw-r--r--  mlir/include/mlir/Dialect/Affine/LoopUtils.h | 2
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 5
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 26
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 4
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 7
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td | 14
-rw-r--r--  mlir/include/mlir/Dialect/Shard/IR/ShardOps.td | 117
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 48
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc | 940
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td | 33
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h | 13
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td | 7
-rw-r--r--  mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 175
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 6
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 65
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 62
-rw-r--r--  mlir/include/mlir/TableGen/CodeGenHelpers.h | 2
-rw-r--r--  mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h | 71
-rw-r--r--  mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 9
-rw-r--r--  mlir/lib/Analysis/Presburger/Simplex.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/IRCore.cpp | 56
-rw-r--r--  mlir/lib/Bindings/Python/Rewrite.cpp | 2
-rw-r--r--  mlir/lib/CAPI/Transforms/Rewrite.cpp | 2
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 76
-rw-r--r--  mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 4
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 7
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 200
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 150
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp | 70
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 10
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 36
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp | 2
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 74
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformTypes.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 141
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 146
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 122
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3
-rw-r--r--  mlir/lib/IR/MLIRContext.cpp | 2
-rw-r--r--  mlir/lib/TableGen/CodeGenHelpers.cpp | 4
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 3
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 5
-rw-r--r--  mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 459
-rw-r--r--  mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp | 2
-rw-r--r--  mlir/python/CMakeLists.txt | 10
-rw-r--r--  mlir/python/mlir/dialects/OpenACCOps.td | 14
-rw-r--r--  mlir/python/mlir/dialects/gpu/__init__.py | 146
-rw-r--r--  mlir/python/mlir/dialects/openacc.py | 5
-rw-r--r--  mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 76
-rw-r--r--  mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 12
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 8
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir | 201
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir | 29
-rw-r--r--  mlir/test/Dialect/Affine/canonicalize.mlir | 130
-rw-r--r--  mlir/test/Dialect/LLVMIR/canonicalize.mlir | 11
-rw-r--r--  mlir/test/Dialect/LLVMIR/rocdl.mlir | 14
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir | 181
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir | 133
-rw-r--r--  mlir/test/Dialect/MemRef/canonicalize.mlir | 30
-rw-r--r--  mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir | 4
-rw-r--r--  mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir | 10
-rw-r--r--  mlir/test/Dialect/OpenACC/recipe-populate-private.mlir | 10
-rw-r--r--  mlir/test/Dialect/Tensor/one-shot-bufferize.mlir | 29
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-attach-target.mlir | 8
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir | 21
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir | 20
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/global.mlir | 4
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/if.mlir | 8
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/import.mlir | 8
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/local.mlir | 12
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir | 8
-rw-r--r--  mlir/test/Dialect/WasmSSA/custom_parser/table.mlir | 6
-rw-r--r--  mlir/test/Dialect/WasmSSA/extend-invalid.mlir | 4
-rw-r--r--  mlir/test/Dialect/WasmSSA/global-invalid.mlir | 12
-rw-r--r--  mlir/test/Dialect/WasmSSA/locals-invalid.mlir | 4
-rw-r--r--  mlir/test/Dialect/XeGPU/invalid.mlir | 36
-rw-r--r--  mlir/test/Dialect/XeGPU/ops.mlir | 60
-rw-r--r--  mlir/test/Target/LLVMIR/Import/function-attributes.ll | 6
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir.mlir | 11
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 12
-rw-r--r--  mlir/test/Target/LLVMIR/rocdl.mlir | 14
-rw-r--r--  mlir/test/Target/Wasm/abs.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/add_div.mlir | 40
-rw-r--r--  mlir/test/Target/Wasm/and.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/block.mlir | 16
-rw-r--r--  mlir/test/Target/Wasm/block_complete_type.mlir | 24
-rw-r--r--  mlir/test/Target/Wasm/block_value_type.mlir | 19
-rw-r--r--  mlir/test/Target/Wasm/branch_if.mlir | 29
-rw-r--r--  mlir/test/Target/Wasm/call.mlir | 17
-rw-r--r--  mlir/test/Target/Wasm/clz.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/comparison_ops.mlir | 269
-rw-r--r--  mlir/test/Target/Wasm/const.mlir | 8
-rw-r--r--  mlir/test/Target/Wasm/convert.mlir | 85
-rw-r--r--  mlir/test/Target/Wasm/copysign.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/ctz.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/demote.mlir | 15
-rw-r--r--  mlir/test/Target/Wasm/div.mlir | 20
-rw-r--r--  mlir/test/Target/Wasm/double_nested_loop.mlir | 63
-rw-r--r--  mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir | 53
-rw-r--r--  mlir/test/Target/Wasm/eq.mlir | 56
-rw-r--r--  mlir/test/Target/Wasm/eqz.mlir | 21
-rw-r--r--  mlir/test/Target/Wasm/extend.mlir | 69
-rw-r--r--  mlir/test/Target/Wasm/global.mlir | 16
-rw-r--r--  mlir/test/Target/Wasm/if.mlir | 112
-rw-r--r--  mlir/test/Target/Wasm/import.mlir | 12
-rw-r--r--  mlir/test/Target/Wasm/inputs/add_div.yaml.wasm | 50
-rw-r--r--  mlir/test/Target/Wasm/inputs/block.yaml.wasm | 22
-rw-r--r--  mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm | 23
-rw-r--r--  mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm | 18
-rw-r--r--  mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm | 18
-rw-r--r--  mlir/test/Target/Wasm/inputs/call.yaml.wasm | 26
-rw-r--r--  mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm | 88
-rw-r--r--  mlir/test/Target/Wasm/inputs/convert.yaml.wasm | 69
-rw-r--r--  mlir/test/Target/Wasm/inputs/demote.yaml.wasm | 18
-rw-r--r--  mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm | 19
-rw-r--r--  mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm | 21
-rw-r--r--  mlir/test/Target/Wasm/inputs/eq.yaml.wasm | 27
-rw-r--r--  mlir/test/Target/Wasm/inputs/eqz.yaml.wasm | 29
-rw-r--r--  mlir/test/Target/Wasm/inputs/extend.yaml.wasm | 40
-rw-r--r--  mlir/test/Target/Wasm/inputs/if.yaml.wasm | 25
-rw-r--r--  mlir/test/Target/Wasm/inputs/loop.yaml.wasm | 17
-rw-r--r--  mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm | 20
-rw-r--r--  mlir/test/Target/Wasm/inputs/ne.yaml.wasm | 27
-rw-r--r--  mlir/test/Target/Wasm/inputs/promote.yaml.wasm | 18
-rw-r--r--  mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm | 53
-rw-r--r--  mlir/test/Target/Wasm/inputs/rounding.yaml.wasm | 37
-rw-r--r--  mlir/test/Target/Wasm/inputs/wrap.yaml.wasm | 24
-rw-r--r--  mlir/test/Target/Wasm/invalid_block_type_index.yaml | 28
-rw-r--r--  mlir/test/Target/Wasm/local.mlir | 6
-rw-r--r--  mlir/test/Target/Wasm/loop.mlir | 17
-rw-r--r--  mlir/test/Target/Wasm/loop_with_inst.mlir | 33
-rw-r--r--  mlir/test/Target/Wasm/max.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/memory_min_eq_max.mlir | 2
-rw-r--r--  mlir/test/Target/Wasm/memory_min_max.mlir | 2
-rw-r--r--  mlir/test/Target/Wasm/memory_min_no_max.mlir | 2
-rw-r--r--  mlir/test/Target/Wasm/min.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/ne.mlir | 52
-rw-r--r--  mlir/test/Target/Wasm/neg.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/or.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/popcnt.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/promote.mlir | 14
-rw-r--r--  mlir/test/Target/Wasm/reinterpret.mlir | 46
-rw-r--r--  mlir/test/Target/Wasm/rem.mlir | 8
-rw-r--r--  mlir/test/Target/Wasm/rotl.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/rotr.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/rounding.mlir | 50
-rw-r--r--  mlir/test/Target/Wasm/shl.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/shr_s.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/shr_u.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/sqrt.mlir | 4
-rw-r--r--  mlir/test/Target/Wasm/sub.mlir | 8
-rw-r--r--  mlir/test/Target/Wasm/wrap.mlir | 15
-rw-r--r--  mlir/test/Target/Wasm/xor.mlir | 4
-rw-r--r--  mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp | 6
-rw-r--r--  mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp | 6
-rw-r--r--  mlir/test/mlir-pdll/CodeGen/CPP/general.pdll | 2
-rw-r--r--  mlir/test/python/dialects/gpu/dialect.py | 93
-rw-r--r--  mlir/test/python/dialects/openacc.py | 171
-rw-r--r--  mlir/test/python/ir/operation.py | 28
-rw-r--r--  mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 13
-rw-r--r--  mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2
-rw-r--r--  mlir/tools/mlir-tblgen/RewriterGen.cpp | 2
-rw-r--r--  mlir/unittests/Dialect/SparseTensor/MergerTest.cpp | 3
185 files changed, 5811 insertions(+), 1130 deletions(-)
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index 686e500..2622c08 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -55,7 +55,7 @@ Some important things to think about w.r.t. canonicalization patterns:
* It is always good to eliminate operations entirely when possible, e.g. by
folding known identities (like "x + 0 = x").
-* Pattens with expensive running time (i.e. have O(n) complexity) or
+* Patterns with expensive running time (i.e. have O(n) complexity) or
complicated cost models don't belong to canonicalization: since the
algorithm is executed iteratively until fixed-point we want patterns that
execute quickly (in particular their matching phase).
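
[Editor's note: as an illustrative aside, not part of this change, the kind of pattern this guideline favors is constant-time in both match and rewrite. A minimal sketch, assuming the usual `OpRewritePattern` API and the `arith` dialect:]

```c++
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Cheap canonicalization sketch: matching inspects only the op and its
// immediate operands, so the fixed-point driver can apply it repeatedly
// without blowing up compile time.
struct FoldAddZero : OpRewritePattern<arith::AddIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::AddIOp op,
                                PatternRewriter &rewriter) const override {
    // x + 0 -> x
    if (!matchPattern(op.getRhs(), m_Zero()))
      return failure();
    rewriter.replaceOp(op, op.getLhs());
    return success();
  }
};
```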
diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index eb6ff61..573e888 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -27,9 +27,9 @@ the tensor is sharded - not specified manually.
### Device Groups
-Each collective operation runs within a group of devices. You define groups
-using the `grid` and `grid_axes` attributes, which describe how to slice the
-full device grid into smaller groups.
+Collective operations run within groups of devices, which are defined
+using the `grid` and `grid_axes` attributes. These describe
+how the full device grid is sliced into smaller groups.
Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.
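
[Editor's note: a concrete instance of this rule: in a 2x3 grid with `grid_axes = [1]`, devices sharing the axis-0 coordinate form a group, so (0,0), (0,1), (0,2) are one group and (1,0), (1,1), (1,2) the other.]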
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84..fe42a20 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -352,7 +352,7 @@ typedef struct {
/// Create a rewrite pattern that matches the operation
/// with the given rootName, corresponding to mlir::OpRewritePattern.
-MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames);
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e79..60f1888 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>
@@ -19,8 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+/// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
+/// none of the chipset-dependent patterns are added.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
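
[Editor's note: a hedged call-site sketch for the new signature. `amdgpu::Chipset::parse` is the existing helper in the AMDGPU utils; `chipsetName`, `ctx`, and `typeConverter` are assumed locals:]

```c++
// Hypothetical caller: parse the chipset string, falling back to
// std::nullopt so the chipset-dependent patterns are simply skipped.
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipsetName);
RewritePatternSet patterns(ctx);
populateMathToROCDLConversionPatterns(
    typeConverter, patterns,
    succeeded(maybeChipset) ? std::optional<amdgpu::Chipset>(*maybeChipset)
                            : std::nullopt);
```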
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 25e9d34..9f76f5d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
+
+ The chipset option specifies the target AMDGPU architecture. If the chipset
+ is empty, none of the chipset-dependent patterns are added, and the pass
+ will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
@@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
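
[Editor's note: for illustration, with the standard pass-option syntax this is driven as, e.g., `mlir-opt --convert-math-to-rocdl="chipset=gfx942"`; leaving `chipset` empty (the default) skips the chipset-dependent patterns.]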
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 9b59af7..830c394 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops);
+[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops);
/// Get perfectly nested sequence of loops starting at root of loop nest
/// (the first op being another AffineFor, and the second op - a terminator).
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 9753dca..d0811a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
}];
+ let hasFolder = 1;
let hasVerifier = 1;
string llvmInstName = "ShuffleVector";
@@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<StrAttr>:$instrument_function_exit,
OptionalAttr<UnitAttr>:$no_inline,
OptionalAttr<UnitAttr>:$always_inline,
+ OptionalAttr<UnitAttr>:$inline_hint,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
OptionalAttr<UnitAttr>:$optimize_none,
@@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
/// Returns true if the `always_inline` attribute is set, false otherwise.
bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); }
+ /// Returns true if the `inline_hint` attribute is set, false otherwise.
+ bool isInlineHint() { return bool(getInlineHintAttr()); }
+
/// Returns true if the `optimize_none` attribute is set, false otherwise.
bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); }
}];
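
[Editor's note: a minimal sketch of the generated accessors for the new attribute, assuming `func` is an `LLVMFuncOp`:]

```c++
// Set the new unit attribute and query it via the added helper.
func.setInlineHintAttr(UnitAttr::get(func.getContext()));
assert(func.isInlineHint() && "inline_hint should now be set");
```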
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6925cec..68f31e6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0],
let assemblyFormat = "$count attr-dict";
}
+def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until ASYNCCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the specified counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
+def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until TENSORCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the specified counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
Arguments<(ins I16Attr:$priority)> {
let assemblyFormat = "$priority attr-dict";
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 8f87235..b8aa497 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
return StringLiteral("acc.routine_info");
}
+static constexpr StringLiteral getVarNameAttrName() {
+ return VarNameAttr::name;
+}
+
static constexpr StringLiteral getCombinedConstructsAttrName() {
return CombinedConstructsTypeAttr::name;
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index fecf81b..1eaa21b4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">;
// Define a resource for the OpenACC current device setting.
def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">;
+// Attribute for saving variable names - this can be attached to non-acc-dialect
+// operations in order to ensure the name is preserved.
+def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> {
+ let parameters = (ins StringRefParameter<"">:$name);
+ let assemblyFormat = "`<` $name `>`";
+}
+
// Used for data specification in data clauses (2.7.1).
// Either (or both) extent and upperbound must be specified.
def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 6736bc8..93e9e3d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -73,12 +73,18 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
InterfaceMethod<
/*description=*/[{
Generates allocation operations for the pointer-like type. It will create
- an allocate that produces memory space for an instance of the current type.
+ an allocate operation that produces memory space for an instance of the
+ current type.
The `varName` parameter is optional and can be used to provide a name
- for the allocated variable. If the current type is represented
- in a way that it does not capture the pointee type, `varType` must be
- passed in to provide the necessary type information.
+ for the allocated variable. When provided, it must be used by the
+ implementation; if the implementing dialect has no native way to
+ preserve it, the discardable `acc.var_name` attribute from the acc
+ dialect is used.
+
+ If the current type is represented in a way that does not capture
+ the pointee type, `varType` must be passed in to provide the necessary
+ type information.
The `originalVar` parameter is optional but enables support for dynamic
types (e.g., dynamic memrefs). When provided, implementations can extract
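
[Editor's note: a sketch of how a memref-based implementation of this interface method might honor `varName`; `builder`, `loc`, and the parameter names mirror the interface signature and are illustrative only:]

```c++
// Allocate storage for the pointee type and, since memref has no native
// notion of a variable name, record it via the discardable attribute.
auto alloca = builder.create<memref::AllocaOp>(loc, cast<MemRefType>(varType));
if (!varName.empty())
  alloca->setAttr(acc::getVarNameAttrName(),
                  acc::VarNameAttr::get(builder.getContext(), varName));
return alloca.getResult();
```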
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index b9d7163..5e68f75e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
]> {
let summary = "All-gather over a device grid.";
let description = [{
- Gathers along the `gather_axis` tensor axis.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and replicates the result across all devices
+ in the group.
Example:
```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device grid.";
let description = [{
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ Each device in a group receives a replicated copy of the reduction result.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+ let summary = "All-slice over a device grid.";
let description = [{
- Slice along the `slice_axis` tensor axis.
- This operation can be thought of as the inverse of all-gather.
- Technically, it is not required that all processes have the same input tensor.
- Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ Within each device group defined by `grid_axes`, slices the input tensor along
+ the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+ the input data is replicated along the `slice_axis`.
+ Each process simply crops its local data to the slice corresponding to its
+ in-group device index.
+ Note: `AllSliceOp` does not involve any communication between devices, and
+ devices within a group need not have replicated input data.
Example:
```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
```
Result:
```
- gather tensor
+ slice tensor
axis 1
------------>
+-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device grid.";
let description = [{
- Performs an all-to-all on tensor pieces split along `split_axis`.
- The resulting pieces are concatenated along `concat_axis` on ech device.
+ Each participant logically splits its input along `split_axis`,
+ then scatters the resulting pieces across the group defined by `grid_axes`.
+ After receiving data pieces from other participants' scatters,
+ it concatenates them along `concat_axis` to produce the final result.
Example:
```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
]> {
let summary = "Broadcast over a device grid.";
let description = [{
- Broadcast the tensor on `root` to all devices in each respective group.
- The operation broadcasts along grid axes `grid_axes`.
- The `root` device specifies the in-group multi-index that is broadcast to
- all other devices in the group.
+ Copies the input tensor on `root` to all devices in each group defined by
+ `grid_axes`. The `root` device is defined by its in-group multi-index.
+ The contents of input tensors on non-root devices are ignored.
Example:
```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
```
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
]> {
let summary = "Gather over a device grid.";
let description = [{
- Gathers on device `root` along the `gather_axis` tensor axis.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device for each device group.
- The result tensor on non-root devices is undefined.
- Using it will result in undefined behavior.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and returns the resulting tensor on each
+ `root` device. The result on all other (non-root) devices is undefined.
+ The `root` device is defined by its in-group multi-index.
Example:
```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Receive from a device within a device group.
+ Receive a tensor from device `source`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
+ The content of the input tensor is ignored.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
]> {
let summary = "Reduce over a device grid.";
let description = [{
- Reduces on device `root` within each device group.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device within its device group.
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ The reduction result will be returned on the `root` device of each group.
+ It is undefined on all other (non-root) devices.
+ The `root` device is defined by its in-group multi-index.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device grid.";
let description = [{
- After the reduction, the result is scattered within each device group.
- The tensor is split along `scatter_axis` and the pieces distributed
- across the device group.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes` using the specified reduction method. The reduction is performed
+ element-wise across the tensor pieces from all devices in the group.
+ After reduction, the reduction result is scattered (split and distributed)
+ across the device group along `scatter_axis`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
- : tensor<3x4xf32> -> tensor<1x4xf64>
+ : tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
Result:
```
+-------+
- | 6 8 | <- devices (0, 0)
+ | 5 6 | <- devices (0, 0)
+-------+
- | 10 12 | <- devices (0, 1)
+ | 7 8 | <- devices (0, 1)
+-------+
- | 22 24 | <- devices (1, 0)
+ | 13 14 | <- devices (1, 0)
+-------+
- | 26 28 | <- devices (1, 1)
+ | 15 16 | <- devices (1, 1)
+-------+
```
}];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
]> {
let summary = "Scatter over a device grid.";
let description = [{
- For each device group split the input tensor on the `root` device along
- axis `scatter_axis` and scatter the parts across the group devices.
+ For each device group defined by `grid_axes`, the input tensor on the `root`
+ device is split along axis `scatter_axis` and distributed across the group.
+ The content of the input on all other (non-root) devices is ignored.
+ The `root` device is defined by its in-group multi-index.
Example:
```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
(0, 1)
↓
+-------+-------+ | scatter tensor
- device (0, 0) -> | | | | axis 0
- | | | ↓
+ device (0, 0) -> | * * | * * | | axis 0
+ | * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Send from one device to another within a device group.
+ Send the input tensor to device `destination`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
]> {
let summary = "Shift over a device grid.";
let description = [{
- Within each device group shift along grid axis `shift_axis` by an offset
- `offset`.
- The result on devices that do not have a corresponding source is undefined.
- `shift_axis` must be one of `grid_axes`.
- If the `rotate` attribute is present,
- instead of a shift a rotation is done.
+ Within each device group defined by `grid_axes`, shifts input tensors along the
+ device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+ be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+ That is, the offset wraps around according to the group size along `shift_axis`.
+ Otherwise, the results on devices without a corresponding source are undefined.
Example:
```
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 10491f6..4ecf03c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+/// A thin wrapper around the SpecificationVersion enum to represent
+/// and provide utilities around the TOSA specification version.
+class TosaSpecificationVersion {
+public:
+ TosaSpecificationVersion(uint32_t major, uint32_t minor)
+ : majorVersion(major), minorVersion(minor) {}
+ TosaSpecificationVersion(SpecificationVersion version)
+ : TosaSpecificationVersion(fromVersionEnum(version)) {}
+
+ bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
+ return this->majorVersion == baseVersion.majorVersion &&
+ this->minorVersion >= baseVersion.minorVersion;
+ }
+
+ uint32_t getMajor() const { return majorVersion; }
+ uint32_t getMinor() const { return minorVersion; }
+
+private:
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+
+ static TosaSpecificationVersion
+ fromVersionEnum(SpecificationVersion version) {
+ switch (version) {
+ case SpecificationVersion::V_1_0:
+ return TosaSpecificationVersion(1, 0);
+ case SpecificationVersion::V_1_1_DRAFT:
+ return TosaSpecificationVersion(1, 1);
+ }
+ llvm_unreachable("Unknown TOSA version");
+ }
+};
+
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
+
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : level(level) {
+ : specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
- targetAttr.getExtensions()) {}
+ : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- // TODO implement the following utilities.
- // Version getSpecVersion() const;
+ SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ public:
}
private:
+ SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
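
[Editor's note: a small usage sketch of the compatibility check introduced above; values are chosen for illustration:]

```c++
// A 1.1 target accepts type combinations introduced in 1.0, but a 1.0
// target must reject combinations that require 1.1.
TosaSpecificationVersion target(SpecificationVersion::V_1_1_DRAFT); // 1.1
TosaSpecificationVersion required(SpecificationVersion::V_1_0);     // 1.0
bool ok = target.isBackwardsCompatibleWith(required);   // true: 1.1 >= 1.0
bool rev = required.isBackwardsCompatibleWith(target);  // false: 1.0 < 1.1
```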
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1f718ac..c1b5e78 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -2,441 +2,779 @@
// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git
profileComplianceMap = {
{"tosa.argmax",
- {{{Profile::pro_int}, {{i8T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp32T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.arithmetic_right_shift",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_and",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_or",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_xor",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_left_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_right_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.maximum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.minimum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.mul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pow",
- {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}},
{"tosa.abs",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_not",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
- {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
- {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clz",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.negate",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reciprocal",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater_equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_max",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_min",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_product",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_sum",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
{{{Profile::pro_int},
- {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
+ {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
{{{Profile::pro_int},
- {{i8T, i32T, i8T, i8T},
- {i16T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
+ {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Profile::pro_int},
- {{boolT, i8T},
- {boolT, i16T},
- {boolT, i32T},
- {i8T, boolT},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, boolT},
- {i16T, i8T},
- {i16T, i32T},
- {i32T, boolT},
- {i32T, i8T},
- {i32T, i16T}}},
- {{Profile::pro_fp},
- {{i8T, fp16T},
- {i8T, fp32T},
- {i16T, fp16T},
- {i16T, fp32T},
- {i32T, fp16T},
- {i32T, fp32T},
- {fp16T, i8T},
- {fp16T, i16T},
- {fp16T, i32T},
- {fp16T, fp32T},
- {fp32T, i8T},
- {fp32T, i16T},
- {fp32T, i32T},
- {fp32T, fp16T}}}}},
+ {{{boolT, i8T}, SpecificationVersion::V_1_0},
+ {{boolT, i16T}, SpecificationVersion::V_1_0},
+ {{boolT, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, boolT}, SpecificationVersion::V_1_0},
+ {{i16T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, boolT}, SpecificationVersion::V_1_0},
+ {{i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{i8T, fp32T}, SpecificationVersion::V_1_0},
+ {{i16T, fp16T}, SpecificationVersion::V_1_0},
+ {{i16T, fp32T}, SpecificationVersion::V_1_0},
+ {{i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{i32T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, i8T}, SpecificationVersion::V_1_0},
+ {{fp16T, i16T}, SpecificationVersion::V_1_0},
+ {{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i8T}, SpecificationVersion::V_1_0},
+ {{fp32T, i16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i8T, i8T, i16T, i16T},
- {i8T, i8T, i32T, i32T},
- {i16T, i16T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i16T, i16T, i32T, i32T},
- {i32T, i32T, i8T, i8T},
- {i32T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT}, {i8T}, {i16T}, {i32T}},
+ {{{boolT}, SpecificationVersion::V_1_0},
+ {{i8T}, SpecificationVersion::V_1_0},
+ {{i16T}, SpecificationVersion::V_1_0},
+ {{i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
extensionComplianceMap = {
{"tosa.argmax",
- {{{Extension::int16}, {{i16T, i32T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
- {{Extension::bf16}, {{bf16T, i32T}}}}},
+ {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
+ {"tosa.fft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
- {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
- {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::fp8e5m2},
- {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
- {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
+ {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}},
allOf},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rfft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
- {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
- {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.maximum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Extension::int16},
+ {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.negate",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_max",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i48T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Extension::bf16},
- {{i8T, bf16T},
- {i16T, bf16T},
- {i32T, bf16T},
- {bf16T, i8T},
- {bf16T, i16T},
- {bf16T, i32T},
- {bf16T, fp32T},
- {fp32T, bf16T}}},
+ {{{i8T, bf16T}, SpecificationVersion::V_1_0},
+ {{i16T, bf16T}, SpecificationVersion::V_1_0},
+ {{i32T, bf16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i8T}, SpecificationVersion::V_1_0},
+ {{bf16T, i16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i32T}, SpecificationVersion::V_1_0},
+ {{bf16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
+ {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
+ {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp16T},
- {fp8e4m3T, fp32T},
- {fp16T, fp8e4m3T},
- {fp32T, fp8e4m3T}}},
+ {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp16T},
- {fp8e5m2T, fp32T},
- {fp16T, fp8e5m2T},
- {fp32T, fp8e5m2T}}}}},
+ {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Extension::int16},
- {{i48T, i48T, i8T, i8T},
- {i48T, i48T, i16T, i16T},
- {i48T, i48T, i32T, i32T}}}}},
+ {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
- {{{Extension::int4}, {{i4T}}},
- {{Extension::int16}, {{i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
+ {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
- {{{Extension::int4}, {{i4T, i4T}}},
- {{Extension::int16}, {{i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
+
// End of auto-generated metadata
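In the regenerated tables, every pre-existing type combination carries `SpecificationVersion::V_1_0`; the only combinations gated on `SpecificationVersion::V_1_1_DRAFT` are the `tosa.matmul` fp8 entries with an fp32 accumulator and the mixed fp8e4m3/fp8e5m2 entries.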
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 38cb293..8376a4c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 1.5.
+// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
-def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-
-def Tosa_LevelAttr
- : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
-
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// TOSA Levels
+//===----------------------------------------------------------------------===//
+
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
+
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Specification versions
+//===----------------------------------------------------------------------===//
+
+def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
+def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
+
+def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
+ "SpecificationVersion", "TOSA specification version", "specification_version",
+ [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
+
+//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
+ "SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
- let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
+ "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}
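With the added parameter, a `tosa.target_env` attribute now prints its specification version first; following the assembly format above it should look roughly like `#tosa.target_env<specification_version = 1.0, level = 8k, profiles = [pro_int], extensions = [int16]>` (the profile and extension mnemonics here are assumed; only the `specification_version` field is introduced by this change).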
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 8f5c72b..7b946ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
+using VersionedTypeInfo =
+ std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
+
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
- SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+ SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
@@ -130,9 +133,8 @@ public:
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
- SmallVector<T> findMatchedProfile(Operation *op,
- SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition);
+ OpComplianceInfo<T>
+ findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ public:
private:
template <typename T>
- FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
- CheckCondition &condition);
+ FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
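With `findMatchedEntry` returning the whole `OpComplianceInfo`, callers see both the matched type sets and the versions that introduce them. A minimal sketch of version filtering, assuming the declarations above and that the enum order `V_1_0 < V_1_1_DRAFT` holds (the helper name `availableCombinations` is hypothetical, not part of this patch):

```cpp
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"

using namespace mlir::tosa;

// Keep only the type combinations of `info` that are available at
// `targetVersion`. Each operandTypeInfoSet element is a VersionedTypeInfo,
// i.e. a (type-combination list, introducing-version) pair.
static llvm::SmallVector<llvm::SmallVector<TypeInfo>>
availableCombinations(const OpComplianceInfo<Profile> &info,
                      SpecificationVersion targetVersion) {
  llvm::SmallVector<llvm::SmallVector<TypeInfo>> result;
  for (const VersionedTypeInfo &vti : info.operandTypeInfoSet)
    if (vti.second <= targetVersion) // introduced at or before the target
      result.push_back(vti.first);
  return result;
}
```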
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 6ae19d8..14b00b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
+ Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
+ /*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
+ "The specification version that TOSA operators should conform to.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
+ )}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index b80ee2c..e9425e8 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target";
}
-def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
+def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<
+ "block",
+ "Create a nesting level with a label at its exit."> {
+ let description = [{
+ Defines a Wasm block, creating a new nested scope.
+ A block contains a body region and an optional list of input values.
+ Control can enter the block and later branch out to the block target.
+ Example:
+
+ ```mlir
+
+ wasmssa.block {
+
+ // instructions
+
+ } > ^successor
+  ```
+  }];
+}
+
+def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<
+ "loop",
+ "Create a nesting level that define its entry as jump target."> {
+ let description = [{
+ Represents a Wasm loop construct. This defines a nesting level with
+ a label at the entry of the region.
-def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+ Example:
+
+ ```mlir
+
+ wasmssa.loop {
+
+ } > ^successor
+  ```
+  }];
+}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
@@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
::mlir::Block* getTarget();
}];
let description = [{
- Marks a return from the current block.
+    Escape from the current nesting level and return control flow to its
+    successor. Optionally, mark the arguments that should be transferred to
+    the successor block.
- Example:
+    This shouldn't be confused with branch operations that target the label
+    defined by the nesting level operation.
+
+    For instance, a `wasmssa.block_return` in a loop gives control back to the
+    successor of the loop, whereas a `branch` targeting the loop flows back to
+    the entry block of the loop.
+
+ Example:
```mlir
wasmssa.block_return
@@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
- Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type
in the function type.
+  By default, `wasmssa.func` has nested visibility. Functions exported by the
+  module are marked with the exported attribute, which gives them public
+  visibility.
+
Example:
```mlir
- // A simple function with no arguments that returns a float32
+ // Internal function with no arguments that returns a float32
wasmssa.func @my_f32_func() -> f32
+ // Exported function with no arguments that returns a float32
+ wasmssa.func exported @my_f32_func() -> f32
+
// A function that takes a local ref argument
wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32
```
@@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
WasmSSA_FuncTypeAttr: $functionType,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
- DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility);
+ UnitAttr: $exported);
let regions = (region AnyRegion: $body);
let extraClassDeclaration = [{
@@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
@@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
StrAttr: $importName,
WasmSSA_FuncTypeAttr: $type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
- OptionalAttr<DictArrayAttr>:$res_attrs,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
@@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
::llvm::ArrayRef<Type> getResultTypes() {
return getType().getResults();
}
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
OpBuilder<(ins "StringRef":$symbol,
@@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $exported);
let description = [{
WebAssembly global variable.
Body contains the initialization instructions for the variable value.
The body must contain only instructions considered `const` in a webassembly context,
such as `wasmssa.const` or `global.get`.
+    By default, `wasmssa.global` has nested visibility. Globals exported by the
+    module are marked with the exported attribute, which gives them public
+    visibility.
+
Example:
```mlir
- // Define a global_var, a mutable i32 global variable equal to 10.
- wasmssa.global @global_var i32 mutable nested : {
+ // Define module_global_var, an internal mutable i32 global variable equal to 10.
+ wasmssa.global @module_global_var i32 mutable : {
%[[VAL_0:.*]] = wasmssa.const 10 : i32
wasmssa.return %[[VAL_0]] : i32
}
+
+ // Define global_var, an exported constant i32 global variable equal to 42.
+    wasmssa.global exported @global_var i32 : {
+ %[[VAL_0:.*]] = wasmssa.const 42 : i32
+ wasmssa.return %[[VAL_0]] : i32
+ }
```
}];
let regions = (region AnyRegion: $initializer);
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "Type": $type,
- "bool": $isMutable)>
- ];
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
let hasCustomAssemblyFormat = 1;
}
@@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
StrAttr: $moduleName,
StrAttr: $importName,
WasmSSA_ValTypeAttr: $type,
- UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $isMutable);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "StringRef":$moduleName,
- "StringRef":$importName,
- "Type": $type,
- "bool": $isMutable)>
- ];
let hasCustomAssemblyFormat = 1;
}
@@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
Define a memory to be used by the program.
Multiple memories can be defined in the same module.
+    By default, `wasmssa.memory` has nested visibility. Memories exported by
+    the module are marked with the exported attribute, which gives them
+    public visibility.
+
Example:
```mlir
- // Define the `mem_0` memory with defined bounds of 0 -> 65536
+ // Define the `mem_0` (internal) memory with defined size bounds of [0:65536]
wasmssa.memory @mem_0 !wasmssa<limit[0:65536]>
+
+    // Define the `mem_1` exported memory with a minimum size of 512
+ wasmssa.memory exported @mem_1 !wasmssa<limit[512:]>
```
}];
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [
- OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::LimitType":$limit)>
- ];
+ UnitAttr: $exported);
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict";
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict";
}
def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
@@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]>
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_LimitTypeAttr: $limits);
let extraClassDeclaration = [{
- bool isDeclaration() const { return true; }
+ bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::LimitType":$limits)>];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
}
@@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
let summary= "WebAssembly table value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::TableType":$type)>];
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict";
+ UnitAttr: $exported);
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict";
}
def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
@@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_TableTypeAttr: $type);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::TableType":$type)>];
}
def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d..19a5231 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().contains(name);
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
return getAttrs().getAs<ArrayAttr>("stride");
}
+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061..426377f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
- AllElementTypesMatch<["mem_desc", "res"]>,
- AllRanksMatch<["mem_desc", "res"]>]> {
+ AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
- let results = (outs XeGPU_ValueType:$res);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
Arguments:
- `mem_desc`: the memory descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block load. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
ArrayRef<int64_t> getDataShape() {
- return getRes().getType().getShape();
+ auto resTy = getRes().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- AllElementTypesMatch<["mem_desc", "data"]>,
- AllRanksMatch<["mem_desc", "data"]>]> {
+ AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
- XeGPU_ValueType:$data,
+ AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- `mem_desc`: the memory descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block store. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}
ArrayRef<int64_t> getDataShape() {
- return getData().getType().getShape();
+      auto dataTy = getData().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(dataTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let hasVerifier = 1;
}
-def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
- [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
- let description = [{
- Creates a subview of a memory descriptor. The resulting memory descriptor can have
- a lower rank than the source; in this case, the result dimensions correspond to the
- higher-order dimensions of the source memory descriptor.
-
- Arguments:
- - `src` : a memory descriptor.
- - `offsets` : the coordinates within the matrix the subview will be created from.
-
- Results:
- - `res` : a memory descriptor with smaller size.
-
- }];
- let arguments = (ins XeGPU_MemDesc:$src,
- Variadic<Index>:$offsets,
- DenseI64ArrayAttr:$const_offsets);
- let results = (outs XeGPU_MemDesc:$res);
- let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
- attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
- let builders = [
- OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
- ];
-
- let extraClassDeclaration = [{
- mlir::Value getViewSource() { return getSrc(); }
-
- SmallVector<OpFoldResult> getMixedOffsets() {
- return getMixedValues(getConstOffsets(), getOffsets(), getContext());
- }
- }];
-
- let hasVerifier = 1;
-}
-
-
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2..b1196fb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
- return layout.getStrides();
+ return layout.getStrideAttr();
}
-
// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
@@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ ArrayAttr getBlockAttr() {
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ return layout.getBlockAttr();
+ }
+ Builder builder(getContext());
+ return builder.getI64ArrayAttr({});
+ }
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension.
+ bool isColMajor() {
+ auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
+ return getRank() == 2 && dim0.getInt() == 1;
+ }
+
+    // Get the blocking shape for a MemDescType, which is represented
+    // as an attribute in MemDescType. By default it is the shape
+    // of the MemDescType itself.
+ SmallVector<int64_t> getBlockShape() {
+ SmallVector<int64_t> size(getShape());
+ ArrayAttr blockAttr = getBlockAttr();
+ if (!blockAttr.empty()) {
+ size.clear();
+ for (auto attr : blockAttr.getValue()) {
+ size.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ }
+ return size;
+ }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the returned strides are
+    // blocked strides.
+    //
+    // The blocking is applied to the base matrix shape derived from the
+    // memory descriptor's stride information. If the matrix described by
+    // the memory descriptor is not contiguous, the base matrix is assumed
+    // to be contiguous and to follow the same memory layout.
+    //
+    // It first computes the original matrix shape from the stride info,
+    // then the number of blocks in each dimension of the original shape,
+    // then the outer block shape and strides, and finally combines the
+    // inner and outer block shapes and strides.
+    // E.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
+    // the memory layout tuple is ([2,32,16,8],[128,256,1,16]);
+    // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @strides=[32, 1]
+    // the memory layout tuple is ([32,2,8,16],[256,128,16,1]).
+ SmallVector<int64_t> getStrideShape();
+
+    /// Generates instructions to compute the linearized offset.
+    /// If the memory descriptor is blocked, the returned offset is computed
+    /// with respect to the blocked layout. The strides of the memory
+    /// descriptor are always taken into account, blocked or not.
+    Value getLinearOffsets(OpBuilder &builder, Location loc,
+                           ArrayRef<OpFoldResult> offsets);
+
}];
let hasCustomAssemblyFormat = true;
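The worked examples in the `getStrideShape` comment can be reproduced with a short standalone computation. This is an illustrative sketch of the documented layout math, simplified to the 2-D case; `blockedLayout` is a hypothetical helper, not the in-tree implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Layout = std::pair<std::vector<int64_t>, std::vector<int64_t>>;

// Combine the outer block-grid dims with the inner block dims, mirroring
// the layout-tuple examples documented for getStrideShape.
Layout blockedLayout(std::vector<int64_t> shape, std::vector<int64_t> strides,
                     std::vector<int64_t> block) {
  assert(shape.size() == 2 && strides.size() == 2 && block.size() == 2);
  bool colMajor = strides[0] == 1; // same heuristic as isColMajor()
  int64_t blockElems = block[0] * block[1];
  // Number of blocks along each dimension of the original shape.
  std::vector<int64_t> grid = {shape[0] / block[0], shape[1] / block[1]};
  // Elements are contiguous inside one block, in the base dim order.
  std::vector<int64_t> inner = colMajor
                                   ? std::vector<int64_t>{1, block[0]}
                                   : std::vector<int64_t>{block[1], 1};
  // Whole blocks are laid out contiguously, also in the base dim order.
  std::vector<int64_t> outer =
      colMajor ? std::vector<int64_t>{blockElems, grid[0] * blockElems}
               : std::vector<int64_t>{grid[1] * blockElems, blockElems};
  return {{grid[0], grid[1], block[0], block[1]},
          {outer[0], outer[1], inner[0], inner[1]}};
}

// blockedLayout({32, 256}, {1, 32}, {16, 8}) -> ([2,32,16,8], [128,256,1,16])
// blockedLayout({256, 32}, {32, 1}, {8, 16}) -> ([32,2,8,16], [256,128,16,1])
```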
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 252da21..997aef2 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -88,7 +88,7 @@ public:
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs);
+ void emitOpConstraints();
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index 21adde8..cd9ef5b 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -19,6 +19,14 @@ namespace mlir {
struct WasmBinaryEncoding {
/// Byte encodings for Wasm instructions.
struct OpCode {
+ // Control instructions.
+ static constexpr std::byte block{0x02};
+ static constexpr std::byte loop{0x03};
+ static constexpr std::byte ifOpCode{0x04};
+ static constexpr std::byte elseOpCode{0x05};
+ static constexpr std::byte branchIf{0x0D};
+ static constexpr std::byte call{0x10};
+
// Locals, globals, constants.
static constexpr std::byte localGet{0x20};
static constexpr std::byte localSet{0x21};
@@ -29,6 +37,42 @@ struct WasmBinaryEncoding {
static constexpr std::byte constFP32{0x43};
static constexpr std::byte constFP64{0x44};
+ // Comparisons.
+ static constexpr std::byte eqzI32{0x45};
+ static constexpr std::byte eqI32{0x46};
+ static constexpr std::byte neI32{0x47};
+ static constexpr std::byte ltSI32{0x48};
+ static constexpr std::byte ltUI32{0x49};
+ static constexpr std::byte gtSI32{0x4A};
+ static constexpr std::byte gtUI32{0x4B};
+ static constexpr std::byte leSI32{0x4C};
+ static constexpr std::byte leUI32{0x4D};
+ static constexpr std::byte geSI32{0x4E};
+ static constexpr std::byte geUI32{0x4F};
+ static constexpr std::byte eqzI64{0x50};
+ static constexpr std::byte eqI64{0x51};
+ static constexpr std::byte neI64{0x52};
+ static constexpr std::byte ltSI64{0x53};
+ static constexpr std::byte ltUI64{0x54};
+ static constexpr std::byte gtSI64{0x55};
+ static constexpr std::byte gtUI64{0x56};
+ static constexpr std::byte leSI64{0x57};
+ static constexpr std::byte leUI64{0x58};
+ static constexpr std::byte geSI64{0x59};
+ static constexpr std::byte geUI64{0x5A};
+ static constexpr std::byte eqF32{0x5B};
+ static constexpr std::byte neF32{0x5C};
+ static constexpr std::byte ltF32{0x5D};
+ static constexpr std::byte gtF32{0x5E};
+ static constexpr std::byte leF32{0x5F};
+ static constexpr std::byte geF32{0x60};
+ static constexpr std::byte eqF64{0x61};
+ static constexpr std::byte neF64{0x62};
+ static constexpr std::byte ltF64{0x63};
+ static constexpr std::byte gtF64{0x64};
+ static constexpr std::byte leF64{0x65};
+ static constexpr std::byte geF64{0x66};
+
// Numeric operations.
static constexpr std::byte clzI32{0x67};
static constexpr std::byte ctzI32{0x68};
@@ -93,6 +137,33 @@ struct WasmBinaryEncoding {
static constexpr std::byte maxF64{0xA5};
static constexpr std::byte copysignF64{0xA6};
static constexpr std::byte wrap{0xA7};
+
+ // Conversion operations
+ static constexpr std::byte extendS{0xAC};
+ static constexpr std::byte extendU{0xAD};
+ static constexpr std::byte convertSI32F32{0xB2};
+ static constexpr std::byte convertUI32F32{0xB3};
+ static constexpr std::byte convertSI64F32{0xB4};
+ static constexpr std::byte convertUI64F32{0xB5};
+
+ static constexpr std::byte demoteF64ToF32{0xB6};
+
+ static constexpr std::byte convertSI32F64{0xB7};
+ static constexpr std::byte convertUI32F64{0xB8};
+ static constexpr std::byte convertSI64F64{0xB9};
+ static constexpr std::byte convertUI64F64{0xBA};
+
+ static constexpr std::byte promoteF32ToF64{0xBB};
+ static constexpr std::byte reinterpretF32AsI32{0xBC};
+ static constexpr std::byte reinterpretF64AsI64{0xBD};
+ static constexpr std::byte reinterpretI32AsF32{0xBE};
+ static constexpr std::byte reinterpretI64AsF64{0xBF};
+
+ static constexpr std::byte extendI328S{0xC0};
+ static constexpr std::byte extendI3216S{0xC1};
+ static constexpr std::byte extendI648S{0xC2};
+ static constexpr std::byte extendI6416S{0xC3};
+ static constexpr std::byte extendI6432S{0xC4};
};
/// Byte encodings of types in Wasm binaries
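The added values follow the WebAssembly binary encoding, so a decoder can compare raw bytes against them directly. A minimal sketch (not the in-tree parser) dispatching on the new control-instruction opcodes:

```cpp
#include "mlir/Target/Wasm/WasmBinaryEncoding.h"

#include <cstddef>

// Map a raw opcode byte to a readable mnemonic for the newly added
// control instructions; anything else is left to the existing decoders.
static const char *classifyOpcode(std::byte opcode) {
  using OpCode = mlir::WasmBinaryEncoding::OpCode;
  if (opcode == OpCode::block)
    return "block";
  if (opcode == OpCode::loop)
    return "loop";
  if (opcode == OpCode::ifOpCode)
    return "if";
  if (opcode == OpCode::elseOpCode)
    return "else";
  if (opcode == OpCode::branchIf)
    return "br_if";
  if (opcode == OpCode::call)
    return "call";
  return "<other>";
}
```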
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 30ce1fb..6588b53 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther(
/// Checks if the SSA values associated with `cst`'s variables in range
/// [start, end) are unique.
-static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
- const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
+[[maybe_unused]] static bool
+areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start,
+ unsigned end) {
assert(start <= cst.getNumDimAndSymbolVars() &&
"Start position out of bounds");
@@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
}
/// Checks if the SSA values associated with `cst`'s variables are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst) {
return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
}
/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
/// are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
if (kind == VarKind::SetDim)
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index a1cbe29..547a4c2 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -34,7 +34,7 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
// Return a + scale*b;
-LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]]
static SmallVector<DynamicAPInt, 8>
scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale,
ArrayRef<DynamicAPInt> b) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b17106..06d0256 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2730,6 +2730,17 @@ public:
operation->get(), toMlirStringRef(name)));
}
+ static void
+ forEachAttr(MlirOperation op,
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+ intptr_t n = mlirOperationGetNumAttributes(op);
+ for (intptr_t i = 0; i < n; ++i) {
+ MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+ MlirStringRef name = mlirIdentifierStr(na.name);
+ fn(name, na.attribute);
+ }
+ }
+
static void bind(nb::module_ &m) {
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__contains__", &PyOpAttributeMap::dunderContains)
@@ -2737,7 +2748,50 @@ public:
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem)
+ .def("__iter__",
+ [](PyOpAttributeMap &self) {
+ nb::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nb::str(name.data, name.length));
+ });
+ return nb::iter(keys);
+ })
+ .def("keys",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ out.append(nb::str(name.data, name.length));
+ });
+ return out;
+ })
+ .def("values",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ })
+ .def("items", [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nb::make_tuple(
+ nb::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ });
}
private:
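
The practical effect of this hunk: op.attributes now behaves as a full Python mapping, so iteration, keys(), values(), and items() all work (values are down-cast to their concrete attribute classes via maybeDownCast), where previously the map only supported lookup, assignment, and deletion.
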
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5ddb3fb..0f0ed22 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -205,7 +205,7 @@ public:
nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
- MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ MlirRewritePattern pattern = mlirOpRewritePatternCreate(
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 46c329d..41ceb15 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -341,7 +341,7 @@ private:
} // namespace mlir
-MlirRewritePattern mlirOpRewritePattenCreate(
+MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames) {
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211..c03f3a5 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
- populateMathToROCDLConversionPatterns(converter, patterns);
+ populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
index 2771955a..8cc3fde 100644
--- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core
LINK_LIBS PUBLIC
+ MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3..a2dfc12 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+struct ClampFOpConversion final
+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only f16 and f32 types are supported by fmed3
+ Type opTy = op.getType();
+ Type resultType = getTypeConverter()->convertType(opTy);
+
+ if (auto vectorType = dyn_cast<VectorType>(opTy))
+ opTy = vectorType.getElementType();
+
+ if (!isa<Float16Type, Float32Type>(opTy))
+ return rewriter.notifyMatchFailure(
+ op, "fmed3 only supports f16 and f32 types");
+
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ typename math::ClampFOp::Adaptor adaptor(operands);
+ return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getValue(), adaptor.getMin(),
+ adaptor.getMax());
+ },
+ rewriter);
+
+ // Handle 1D vectors and scalars directly
+ rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
+ op.getMin(), op.getMax());
+ return success();
+ }
+};
+
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
+
+ if (chipset.has_value() && chipset->majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter);
+ } else {
+ LDBG() << "Chipset dependent patterns were not added";
+ }
}
-namespace {
-struct ConvertMathToROCDLPass
- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
- ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ using impl::ConvertMathToROCDLBase<
+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
-} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ FailureOr<amdgpu::Chipset> maybeChipset;
+ if (!chipset.empty()) {
+ maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset))
+ return signalPassFailure();
+ }
+ populateMathToROCDLConversionPatterns(
+ converter, patterns,
+ succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
+
ConversionTarget target(getContext());
- target.addLegalDialect<BuiltinDialect, func::FuncDialect,
- vector::VectorDialect, LLVM::LLVMDialect>();
+ target
+ .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
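
For downstream users of populateMathToROCDLConversionPatterns, a sketch of the new call shape; this assumes the usual MLIR namespaces, and the chip name is only an example:

  // With a gfx9+ chipset, math.clampf gains the ROCDL fmed3 lowering; with
  // std::nullopt, only the chipset-independent patterns are registered.
  static void addMathPatterns(const LLVMTypeConverter &converter,
                              RewritePatternSet &patterns, StringRef chip) {
    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chip);
    populateMathToROCDLConversionPatterns(
        converter, patterns,
        succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
  }
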
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index f0d8b78..610ce1f 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (auto vectorType = dyn_cast<VectorType>(operandType))
nanAttr = DenseElementsAttr::get(vectorType, nan);
- Value NanValue =
+ Value nanValue =
spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
Value lhs =
spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- NanValue, adaptor.getLhs());
+ nanValue, adaptor.getLhs());
Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b2580..dd9edc4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..fcbf66d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,7 +20,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -62,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
+ llvm_unreachable("Unknown XeGPU memory space");
}
// Get a flat vector type of the new element type with the same bit width.
@@ -185,6 +188,7 @@ class CreateNdDescToXeVMPattern
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -363,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Helper that computes
// offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+ Location loc, Value baseAddr, Value offset,
+ int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -390,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
      // The load result or store value type can be a vector or a scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -441,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
      // If an offset is provided, we add it to the base pointer.
      // The offset is in number of elements, so we need to multiply by the
      // element byte size.
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+ basePtrI64 =
+ addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -504,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto resTy = op.getMemDesc();
+
+    // Convert the memory descriptor type to the flattened result MemRefType
+    // (1-D, in SLM).
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ op.getSource(), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+    // The load result or store value type can be a vector or a scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
+
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
+
+    // Convert the base pointer from index to i32.
+ Value basePtrI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), linearOffset);
+ basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+ elemByteSize);
+
+    // Convert the base pointer (i32) to an LLVM pointer type.
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
+
+ if (op.getSubgroupBlockIoAttr()) {
+      // If the 'subgroup_block_io' attribute is set, this lowers to
+      // xevm.blockload / xevm.blockstore.
+
+ Type intElemTy = rewriter.getIntegerType(elemBitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+
+ if (valOrResVecTy.getNumElements() >= 1) {
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering for chunk load only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ }
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we
+ // need to handle this case separately.
+ auto scalarTy = valOrResVecTy.getElementType();
+ LLVM::LoadOp loadOp;
+ if (valOrResVecTy.getNumElements() == 1)
+ loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ else
+ loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -546,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+ elemByteSize);
}
}
// Default memory space is global.
@@ -784,6 +932,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -878,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+  // If the result type of the original op is a single-element vector and the
+  // lowered type is a scalar, this materialization cast creates the vector by
+  // broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
@@ -918,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
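
The address arithmetic shared by the gather/scatter and the new matrix paths reduces to one formula; below is a standalone scalar model of addOffsetToBaseAddr (the real patterns emit arith ops in i32 for SLM and i64 for global memory):

  #include <cstdint>

  // Offsets arrive in elements and are scaled to bytes in the same integer
  // width as the base address.
  static uint64_t addOffsetToBaseAddr(uint64_t baseAddr, uint64_t offsetElems,
                                      uint64_t elemByteSize) {
    return baseAddr + offsetElems * elemByteSize;
  }
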
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index f405d0c..c798adb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -757,13 +757,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
offset = numElements - 4l;
}
Type scaleSrcElemType = scaleSrcType.getElementType();
- auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
- scaleSrcElemType);
+ auto newSrcType =
+ VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
Value newScaleSrc =
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
auto extract = vector::ExtractStridedSliceOp::create(
- rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
- ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+ rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
+ ArrayRef{int64_t(1)});
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(opIdx, extract);
setOpsel(opIdx, opsel);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26..749e2ba 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
// Use "unused attribute" marker to silence clang-tidy warning stemming from
// the inability to see through "llvm::TypeSwitch".
template <>
-bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
- Region *src, Region *dest,
- const IRMapping &mapping) {
+[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
+ Region *dest,
+ const IRMapping &mapping) {
// If it's a valid dimension, we need to check that it remains so.
if (isValidDim(op.getResult(), src))
return remainsLegalAfterInline(
@@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
/// Simplify the map while exploiting information on the values in `operands`.
// Use "unused attribute" marker to silence warning stemming from the inability
// to see through the template expansion.
-static void LLVM_ATTRIBUTE_UNUSED
-simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
+ ArrayRef<Value> operands) {
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
SmallVector<AffineExpr> newResults;
newResults.reserve(map.getNumResults());
@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
return success(*map != initialMap);
}
+/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
+/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
+/// record a mapping from e to `newVal` + sum({e1, e2, ..., eK} - exprsToRemove)
+/// in `replacementsMap`. If no entries were added to `replacementsMap`,
+/// nothing was found.
+static void shortenAddChainsContainingAll(
+ AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+ AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+ auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+ if (!binOp)
+ return;
+ AffineExpr lhs = binOp.getLHS();
+ AffineExpr rhs = binOp.getRHS();
+ if (binOp.getKind() != AffineExprKind::Add) {
+ shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+ shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+ return;
+ }
+ SmallVector<AffineExpr> toPreserve;
+ llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
+ AffineExpr thisTerm = rhs;
+ AffineExpr nextTerm = lhs;
+
+ while (thisTerm) {
+ if (!ourTracker.erase(thisTerm)) {
+ toPreserve.push_back(thisTerm);
+ shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+ replacementsMap);
+ }
+ auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+ if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+ thisTerm = nextTerm;
+ nextTerm = AffineExpr();
+ } else {
+ thisTerm = nextBinOp.getRHS();
+ nextTerm = nextBinOp.getLHS();
+ }
+ }
+ if (!ourTracker.empty())
+ return;
+ // We reverse the terms to be preserved here in order to preserve
+ // associativity between them.
+ AffineExpr newExpr = newVal;
+ for (AffineExpr preserved : llvm::reverse(toPreserve))
+ newExpr = newExpr + preserved;
+ replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains the expression `x_1 + x_2 * C_2 + ... + x_n * C_N`
+/// (not necessarily in that order) where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whose inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `resultToReplace` is the result that was detected as the potential start of
+/// replacement chain - if it isn't the rightmost result of the delinearization,
+/// this method fails. (This is intended to ensure we don't have redundant scans
+/// over the same expression).
+///
+/// While this currently only handles delinearizations with a constant basis,
+/// that isn't a fundamental limitation.
+///
+/// This is a utility function for `replaceDimOrSym` below.
+static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
+ AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
+ SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
+ if (!delinOp.getDynamicBasis().empty())
+ return failure();
+ if (resultToReplace != delinOp.getMultiIndex().back())
+ return failure();
+
+ MLIRContext *ctx = delinOp.getContext();
+ SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
+ for (auto [pos, dim] : llvm::enumerate(dims)) {
+ auto asResult = dyn_cast_if_present<OpResult>(dim);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
+ }
+ for (auto [pos, sym] : llvm::enumerate(syms)) {
+ auto asResult = dyn_cast_if_present<OpResult>(sym);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
+ }
+ if (llvm::is_contained(resToExpr, AffineExpr()))
+ return failure();
+
+ bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
+ int64_t stride = 1;
+ llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
+ // This isn't zip_equal since sometimes the delinearize basis is missing a
+ // size for the first result.
+ for (auto [binding, size] : llvm::zip(
+ llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
+ expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
+ stride *= size;
+ }
+ if (resToExpr.size() != delinOp.getStaticBasis().size())
+ expectedExprs.insert(resToExpr[0] * stride);
+
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ AffineExpr delinInExpr = isDimReplacement
+ ? getAffineDimExpr(dims.size(), ctx)
+ : getAffineSymbolExpr(syms.size(), ctx);
+
+ for (AffineExpr e : map->getResults())
+ shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
+ if (replacements.empty())
+ return failure();
+
+ AffineMap origMap = *map;
+ if (isDimReplacement)
+ dims.push_back(delinOp.getLinearIndex());
+ else
+ syms.push_back(delinOp.getLinearIndex());
+ *map = origMap.replace(replacements, dims.size(), syms.size());
+
+ // Blank out dead dimensions and symbols
+ for (AffineExpr e : resToExpr) {
+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
+ unsigned pos = d.getPosition();
+ if (!map->isFunctionOfDim(pos))
+ dims[pos] = nullptr;
+ }
+ if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
+ unsigned pos = s.getPosition();
+ if (!map->isFunctionOfSymbol(pos))
+ syms[pos] = nullptr;
+ }
+ }
+ return success();
+}
+
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
syms);
}
+ if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
+ return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
+ syms);
+ }
+
auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
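
A worked instance of the inverse being matched, assuming a constant basis (3, 4): delinearizing x = 11 yields d0 = 11 floordiv 4 = 2 and d1 = 11 mod 4 = 3, and the inverse sum d0 * 4 + d1 = 2 * 4 + 3 = 11 recovers x. A map result such as d0 * 4 + d1 + s0 therefore collapses to x + s0, after which d0 and d1 are blanked out if nothing else uses them.
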
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index cd216ef..4743941 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation(
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] bool
mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
assert(!loops.empty() && "no loops provided");
@@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
return copyNestRoot;
}
-static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
-emitRemarkForBlock(Block &block) {
+[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) {
return block.getParentOp()->emitRemark();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71..bc17990 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,37 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Get all the ReturnOps in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Get the operands at the specified position for all returnOps.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+ return returnOp.getOperand(pos);
+ });
+}
+
+/// Check if all given values are the same buffer as the block argument (modulo
+/// cast ops).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
// Compute erased results.
- SmallVector<Value> newReturnValues;
- BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+ size_t numReturnOps = returnOps.size();
+ size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+ SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+ BitVector erasedResultIndices(numReturnValues);
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0; i < numReturnValues; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
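
As a sketch of the generalization: for a private function with two return sites, a result position is now dropped only when every site returns the same block argument (modulo memref.cast) at that position. If both sites return %arg0 as result #0 but differ on result #1, result #0 is erased and rewired at the call sites while each return op keeps its own operand for result #1; the old code skipped such multi-block functions entirely.
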
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7ca09d9..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}
+// Folding for the shufflevector op when v1 is a single-element 1-D vector and
+// the mask is a single zero. The fold result is v1 in this case.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+ // Check if operand 0 is a single element vector.
+ auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+ return {};
+ // Check if the mask is a single zero.
+ // Note: The mask is guaranteed to be non-empty.
+ if (getMask().size() != 1 || getMask()[0] != 0)
+ return {};
+ return getV1();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
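
The identity behind this fold, as a sketch: for v1 of type vector<1xf32>, mask [0] selects v1's only element, so the shuffle is just v1 and the op folds to its first operand. This pairs with the ConvertVectorToLLVM change above, where the broadcast lowering now goes through createOrFold so the degenerate one-element shuffle never materializes.
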
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 01a16ce..ac35eea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 55fe0d9..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -798,6 +798,26 @@ LogicalResult MmaOp::verify() {
" attribute");
}
+ // Validate layout combinations. According to the operation description, most
+ // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+ // can use other layout combinations.
+ bool isM8N8K4_F16 =
+ (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+ getMultiplicandAPtxType() == MMATypes::f16);
+
+ if (!isM8N8K4_F16) {
+ // For all other shapes/types, layoutA must be row and layoutB must be col
+ if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+ return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+ "layoutB = #nvvm.mma_layout<col> for shape <")
+ << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+ << "> with element types "
+ << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+ << stringifyEnum(*getMultiplicandBPtxType())
+ << ". Only m8n8k4 with f16 supports other layouts.";
+ }
+ }
+
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8f983f..6192d79 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3024,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type ChunkSizesType;
+ Type chunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(ChunkSizesType) ||
- parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
+ parser.parseType(chunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
result.operands)) {
return failure();
}
@@ -3399,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects(
}
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
- Type targetType, Type tile_sizes,
+ Type targetType, Type tileSizes,
Type) {
- printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
}
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 507597b..94947b7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,45 @@ public:
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned srcStaticCount = llvm::count_if(
+ llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+ SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+ // TODO: Using counting comparison instead of direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
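
A sketch of what the new folder catches: if %off is defined by arith.constant 0, then memref.reinterpret_cast %src to offset: [%off], sizes: [4], strides: [1] has a constifiable offset, so the pattern rebuilds the op with offset: [0] and inserts a memref.cast back to the original result type. The counting guard makes the pattern fire only when at least one more offset, size, or stride actually became static, which keeps canonicalization from looping.
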
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 49b7162..6f815ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -121,7 +121,7 @@ struct EmulateWideIntPass final
[&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
RewritePatternSet patterns(ctx);
- // Add common pattenrs to support contants, functions, etc.
+  // Add common patterns to support constants, functions, etc.
arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 642ced9..90cbbd8 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -40,6 +40,16 @@ static bool isScalarLikeType(Type type) {
return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
}
+/// Helper function to attach the `VarName` attribute to an operation
+/// if a variable name is provided.
+static void attachVarNameAttr(Operation *op, OpBuilder &builder,
+ StringRef varName) {
+ if (!varName.empty()) {
+ auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
+ op->setAttr(acc::getVarNameAttrName(), varNameAttr);
+ }
+}
+
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
MemRefType> {
@@ -83,7 +93,9 @@ struct MemRefPointerLikeModel
// then we can generate an alloca operation.
if (memrefTy.hasStaticShape()) {
needsFree = false; // alloca doesn't need deallocation
- return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
+ attachVarNameAttr(allocaOp, builder, varName);
+ return allocaOp.getResult();
}
// For dynamic memrefs, extract sizes from the original variable if
@@ -103,8 +115,10 @@ struct MemRefPointerLikeModel
// Static dimensions are handled automatically by AllocOp
}
needsFree = true; // alloc needs deallocation
- return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
- .getResult();
+ auto allocOp =
+ memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
+ attachVarNameAttr(allocOp, builder, varName);
+ return allocOp.getResult();
}
// TODO: Unranked not yet supported.
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 135c033..645cbff 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op,
}
template <typename It>
-bool isUnique(It begin, It end) {
+static bool isUnique(It begin, It end) {
if (begin == end) {
return true;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index a1711a6..069191c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) {
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.
-LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
-minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
+[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1,
+ llvm::SMLoc sm2) {
const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index f539502..684c088 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
#ifndef NDEBUG
-LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
- Location loc, Value memref) {
+[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc,
+ Value memref) {
memref = memref::CastOp::create(
builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 5aad671..1cba1bb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
- return TargetEnvAttr::get(context, Level::eightK,
+ return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index bcb880a..a0661e4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -61,8 +61,8 @@ public:
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
- const auto targetEnvAttr =
- TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ const auto targetEnvAttr = TargetEnvAttr::get(
+ ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 20f9333..f072e3e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
-FailureOr<SmallVector<T>>
-TosaProfileCompliance::getOperatorDefinition(Operation *op,
- CheckCondition &condition) {
+FailureOr<OpComplianceInfo<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
- return findMatchedProfile<T>(op, it->second, condition);
+ return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
- if (failed(maybeOpRequiredMode)) {
+ const auto maybeOpDefinition = getOperatorDefinition<T>(op);
+ if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
    // restriction. When the profile compliance information of an operation is
    // not found, confirm that the target has enabled the profile required by
    // the specification.
- int mode_count = 0;
+ int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
- mode_count += cands.size();
+ modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
- << (mode_count > 1 ? " any of " : " ") << "["
+ << (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
- const auto opRequiredMode = maybeOpRequiredMode.value();
+ const auto opDefinition = maybeOpDefinition.value();
+ const SmallVector<T> opRequiredMode = opDefinition.mode;
+ const CheckCondition condition = opDefinition.condition;
+
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
+ // Ensure the matched op compliance version does not exceed the target
+ // specification version.
+ const VersionedTypeInfo versionedTypeInfo =
+ opDefinition.operandTypeInfoSet[0];
+ const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
+ const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
+ if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
+ op->emitOpError() << "illegal: the target specification version ("
+ << stringifyVersion(targetVersion)
+ << ") is not backwards compatible with the op compliance "
+ "specification version ("
+ << stringifyVersion(complianceVersion) << ")\n";
+ return failure();
+ }
+
return success();
}
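
Concretely: a target at specification version 1.0 now rejects an op whose matched operand-type entry carries compliance version 1.1, while a 1.1 target still accepts 1.0 entries, since compatibility is checked from the target version back to the entry's version.
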
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
- const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
- const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
- (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ const bool hasEntry =
+ (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
- for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ for (const auto &versionedTypeInfos :
+ complianceInfos.operandTypeInfoSet) {
+ const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
-SmallVector<T> TosaProfileCompliance::findMatchedProfile(
- Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition) {
+OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
+ Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
- SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
- for (SmallVector<TypeInfo> expected : sets) {
+ SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
+ for (const auto &set : sets) {
+ SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
- bool is_found = true;
+ bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
- is_found = false;
+ isFound = false;
break;
}
}
- if (is_found == true) {
- condition = compInfo[i].condition;
- return compInfo[i].mode;
+ if (isFound == true) {
+ SmallVector<VersionedTypeInfo> typeInfoSet{set};
+ OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
+ compInfo[i].condition};
+ return info;
}
}
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 9a24c2b..a2cff6a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -21,10 +21,10 @@ using namespace mlir;
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a..c809c502 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
- // transpose pattens. Otherwise, these patterns are not applicable.
+ // transpose patterns. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
op.getPermutation()))
return failure();
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 89b62a2..a514ea9 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
@@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
opPrinter.printKeywordOrString("else ");
opPrinter.printRegion(elseRegion);
}
-
-ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
- std::string keyword;
- auto initLocation = opParser.getCurrentLocation();
- std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
- if (keyword == "nested" or keyword == "") {
- visibility = StringAttr::get(opParser.getContext(), "nested");
- return ParseResult::success();
- }
-
- if (keyword == "public" || keyword == "private") {
- visibility = StringAttr::get(opParser.getContext(), keyword);
- return ParseResult::success();
- }
- opParser.emitError(initLocation, "expecting symbol visibility");
- return ParseResult::failure();
-}
-
-void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
- Attribute visibility) {
- opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
-}
} // namespace
#define GET_OP_CLASSES
@@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() {
void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, FunctionType funcType) {
- FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto *ctx = parser.getContext();
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ bool exported{false};
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ exported = true;
+ }
+
auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
@@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return builder.getFunctionType(argTypesWithoutLocal, results);
};
-
- return function_interface_impl::parseFunctionOp(
+ auto funcParseRes = function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ if (exported)
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ return funcParseRes;
}
LogicalResult FuncOp::verifyBody() {
@@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() {
}
void FuncOp::print(OpAsmPrinter &p) {
+ /// If exported, print the keyword up front, and temporarily drop the
+ /// attribute so the generic function printer does not print it again.
+ auto exported = getExported();
+ if (exported) {
+ p << " exported";
+ removeExportedAttr();
+ }
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
+ if (exported)
+ setExported(true);
}
//===----------------------------------------------------------------------===//
@@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, StringRef moduleName,
StringRef importName, FunctionType type) {
FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, {}, {}, odsBuilder.getStringAttr("nested"));
+ type, {}, {});
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
-
-void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, Type type, bool isMutable) {
- GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
- odsBuilder.getStringAttr("nested"));
-}
-
// Custom formats
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr symbolName;
Type globalType;
auto *ctx = parser.getContext();
- ParseResult res = parser.parseSymbolName(
- symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ }
+ res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+ result.attributes);
res = parser.parseType(globalType);
result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
std::string mutableString;
res = parser.parseOptionalKeywordOrString(&mutableString);
if (res.succeeded() && mutableString == "mutable")
result.addAttribute("isMutable", UnitAttr::get(ctx));
- std::string visibilityString;
- res = parser.parseOptionalKeywordOrString(&visibilityString);
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, visibilityString));
+
res = parser.parseColon();
Region *globalInitRegion = result.addRegion();
res = parser.parseRegion(*globalInitRegion);
@@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
}
void GlobalOp::print(OpAsmPrinter &printer) {
+ if (getExported())
+ printer << " exported";
printer << " @" << getSymName().str() << " " << getType();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " :";
Region &body = getRegion();
if (!body.empty()) {
@@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// GlobalImportOp
//===----------------------------------------------------------------------===//
-void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, Type type, bool isMutable) {
- GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, isMutable, odsBuilder.getStringAttr("nested"));
-}
-
ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
auto *ctx = parser.getContext();
ParseResult res = parseImportOp(parser, result);
@@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
if (res.succeeded() && mutableOrSymVisString == "mutable") {
result.addAttribute("isMutable", UnitAttr::get(ctx));
- res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
}
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, mutableOrSymVisString));
res = parser.parseColon();
Type importedType;
@@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) {
<< "\" as @" << getSymName();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " : " << getType();
}
@@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() {
Block *LoopOp::getLabelTarget() { return &getBody().front(); }
//===----------------------------------------------------------------------===//
-// MemOp
-//===----------------------------------------------------------------------===//
-
-void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, LimitType limit) {
- MemOp::build(odsBuilder, odsState, symbol, limit,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// MemImportOp
-//===----------------------------------------------------------------------===//
-
-void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, LimitType limits) {
- MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- limits, odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
// ReinterpretOp
//===----------------------------------------------------------------------===//
@@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() {
//===----------------------------------------------------------------------===//
void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
-
-//===----------------------------------------------------------------------===//
-// TableOp
-//===----------------------------------------------------------------------===//
-
-void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, TableType type) {
- TableOp::build(odsBuilder, odsState, symbol, type,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// TableImportOp
-//===----------------------------------------------------------------------===//
-
-void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, TableType type) {
- TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, odsBuilder.getStringAttr("nested"));
-}
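Note: the WasmSSA changes above replace the printed `sym_visibility` string with an optional `exported` unit attribute parsed as a leading keyword. A minimal sketch of that optional-keyword pattern, assuming only the public OpAsmParser API (the helper name is illustrative, not part of the patch):

    // Sketch: parse an optional leading `exported` keyword into a UnitAttr.
    // parseOptionalKeywordOrString fails without emitting a diagnostic when
    // no keyword is present, so a bare symbol name still parses cleanly.
    static ParseResult parseOptionalExported(OpAsmParser &parser,
                                             OperationState &result) {
      std::string keyword;
      if (parser.parseOptionalKeywordOrString(&keyword).succeeded()) {
        if (keyword != "exported")
          return parser.emitError(parser.getNameLoc(),
                                  "expecting either `exported` or symbol name");
        result.addAttribute("exported", UnitAttr::get(parser.getContext()));
      }
      return success();
    }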
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d..1599ae9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// A helper utility to perform a binary operation on OpFoldResults. Attribute
+// operands are materialized as constant index ops, and the corresponding
+// arith op is generated on the resulting values.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to divide an OpFoldResult by an int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to take the remainder of an OpFoldResult by an int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to multiply an OpFoldResult by an int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to add two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape:
+// if the original offset is [y, x] and the block shape is [By, Bx],
+// the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Get strides as vector of integer for MemDesc.
+SmallVector<int64_t> MemDescType::getStrideShape() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+
+ ArrayAttr strideAttr = getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ SmallVector<int64_t> innerBlkShape = getBlockShape();
+
+ // Get the permutation from the fastest-changing dim (FCD) to the
+ // slowest-changing dim: perm[i] = the dim with the i-th smallest stride.
+ SmallVector<int, 4> perm =
+ llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+ assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+ // Compute the original matrix shape from the stride info, along with the
+ // number of blocks in each dimension. The shape of the highest dim can't
+ // be derived from the stride info, but that doesn't impact the stride
+ // computation for a blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> blkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ blkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+ outerBlkStride[perm[i]] * blkShapeOrig[perm[i]];
+ }
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ return blockedStrides;
+}
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+ SmallVector<int64_t> blockShape = getBlockShape();
+ SmallVector<int64_t> strides = getStrideShape();
+ SmallVector<OpFoldResult> blockedOffsets;
+
+ // A block shape equal to the matrix shape means no blocking.
+ if (llvm::equal(blockShape, matrixShape)) {
+ // remove the outer dims from strides
+ strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+ } else {
+ // Reuse the blocking helper: [y, x] with block shape [By, Bx] becomes
+ // the blocked offsets [y/By, x/Bx, y%By, x%Bx].
+ blockedOffsets = getBlockedOffsets(builder, loc, offsets, blockShape);
+ offsets = blockedOffsets;
+ }
+
+ // Start from zero, the matrix descriptor's base offset.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ return linearOffset;
+}
+
+// The arithmetic helper macros are no longer needed past this point.
+#undef div
+#undef rem
+#undef mul
+#undef add
} // namespace xegpu
} // namespace mlir
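To make the blocking arithmetic above concrete, here is a worked example with illustrative numbers (not taken from the patch), following getStrideShape and getLinearOffsets for a row-major 32x32 matrix with 8x16 blocks:

    // Worked example (illustrative values):
    //   matrix shape = [32, 32], row-major strides = [32, 1]
    //   block shape  = [8, 16], so each block holds 8 * 16 = 128 elements
    // Inner strides (within one block): [16, 1]
    // Blocks along dim 1: 32 / 16 = 2, so outer strides = [256, 128]
    // Blocked strides returned by getStrideShape: [256, 128, 16, 1]
    //
    // getLinearOffsets for offsets [y, x] = [19, 21]:
    //   blocked offsets = [19/8, 21/16, 19%8, 21%16] = [2, 1, 3, 5]
    //   linear offset   = 2*256 + 1*128 + 3*16 + 5*1 = 693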
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..abd12e2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
#define DEBUG_TYPE "xegpu"
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
static bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
@@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
+LogicalResult
+isValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io is only allowed when the "
+ "result is a 1D VectorType.";
+ return success();
+ }
+
+ if (mdescTy.getRank() != 2)
+ return emitError() << "mem_desc must be 2D.";
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+ if (dataShape.size() == 2) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io is only allowed when the "
+ "result is a 1D VectorType.";
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitError() << "data shape must not exceed mem_desc shape.";
+ } else {
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ // If the subgroup_block_io attribute is set, mdescTy must have a block
+ // attribute.
+ if (subgroup_block_io && blockShape.empty())
+ return emitError() << "mem_desc must have block attribute when "
+ "subgroup_block_io is set.";
+ // If the subgroup_block_io attribute is set, the mem_desc must be row
+ // major.
+ if (subgroup_block_io && mdescTy.isColMajor())
+ return emitError() << "mem_desc should be row major when "
+ "subgroup_block_io is set.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ // Call the generated builder with all parameters (including optional ones as
+ // nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
- VectorType resTy = getRes().getType();
- MemDescType mdescTy = getMemDesc().getType();
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
+ auto resTy = dyn_cast<VectorType>(getRes().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
- ArrayRef<int64_t> valueShape = resTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
- return success();
+ return isValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
- VectorType dataTy = getData().getType();
- MemDescType mdescTy = getMemDesc().getType();
-
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
-
- ArrayRef<int64_t> dataShape = dataTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
- Type resTy, Value src,
- llvm::ArrayRef<OpFoldResult> offsets) {
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<int64_t> staticOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
- MemDescType srcTy = getSrc().getType();
- MemDescType resTy = getRes().getType();
- ArrayRef<int64_t> srcShape = srcTy.getShape();
- ArrayRef<int64_t> resShape = resTy.getShape();
-
- if (srcTy.getRank() < resTy.getRank())
- return emitOpError("result rank must not exceed source rank.");
- if (llvm::any_of(
- llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed source shape.");
-
- if (srcTy.getStrides() != resTy.getStrides())
- return emitOpError("result must inherit the source strides.");
-
- return success();
+ auto dataTy = dyn_cast<VectorType>(getData().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
+ return isValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
-} // namespace xegpu
-} // namespace mlir
-
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
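The shared validator above takes a `function_ref<InFlightDiagnostic()>` callback instead of emitting on a fixed operation, so LoadMatrixOp and StoreMatrixOp can both route diagnostics through their own emitError. A minimal sketch of the pattern, with hypothetical names rather than the actual XeGPU helpers:

    // Sketch of the diagnostic-callback pattern (hypothetical names).
    static LogicalResult
    checkRank2(ShapedType ty, function_ref<InFlightDiagnostic()> emitError) {
      if (ty.getRank() != 2)
        return emitError() << "expected a 2D type, got rank " << ty.getRank();
      return success();
    }

    // Each verifier forwards its own diagnostic factory:
    //   return checkRank2(getType(), [&]() { return emitOpError(); });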
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0f..aafa1b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- VectorType valueTy = op.getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+ assert(valueTy && "the value type must be a vector type!");
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
return failure();
@@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
return failure();
Location loc = op.getLoc();
- VectorType valueTy = op.getData().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
+ assert(valueTy && "the value type must be a vector type!");
ArrayRef<int64_t> shape = valueTy.getShape();
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c28d2fc..31a967d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
return failure();
ArrayRef<int64_t> wgShape = op.getDataShape();
- VectorType valueTy = op.getRes().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
+ assert(valueTy && "the value type must be a vector type!");
Type elemTy = valueTy.getElementType();
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 89b81cf..5f63fe6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -1204,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
/// present in result expressions is less than `dimCount` and the highest index
/// of symbolic identifier present in result expressions is less than
/// `symbolCount`.
-LLVM_ATTRIBUTE_UNUSED static bool
+[[maybe_unused]] static bool
willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results) {
int64_t maxDimPosition = -1;
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index cb90ef8..d52d5e7 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
-void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<const Record *> opDefs) {
- NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+void StaticVerifierFunctionEmitter::emitOpConstraints() {
emitTypeConstraints();
emitAttrConstraints();
emitPropConstraints();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9603813..857e31b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{
StringLiteral("denormal-fp-math-f32"),
StringLiteral("fp-contract"),
StringLiteral("frame-pointer"),
+ StringLiteral("inlinehint"),
StringLiteral("instrument-function-entry"),
StringLiteral("instrument-function-exit"),
StringLiteral("memory"),
@@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setNoInline(true);
if (func->hasFnAttribute(llvm::Attribute::AlwaysInline))
funcOp.setAlwaysInline(true);
+ if (func->hasFnAttribute(llvm::Attribute::InlineHint))
+ funcOp.setInlineHint(true);
if (func->hasFnAttribute(llvm::Attribute::OptimizeNone))
funcOp.setOptimizeNone(true);
if (func->hasFnAttribute(llvm::Attribute::Convergent))
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 845a14f..147613f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1652,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->addFnAttr(llvm::Attribute::NoInline);
if (func.getAlwaysInlineAttr())
llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline);
+ if (func.getInlineHintAttr())
+ llvmFunc->addFnAttr(llvm::Attribute::InlineHint);
if (func.getOptimizeNoneAttr())
llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
if (func.getConvergentAttr())
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 0c3e87a..d9ad8fb 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2619,6 +2619,11 @@ LogicalResult ControlFlowStructurizer::structurize() {
// region. We cannot handle such cases given that once a value is sinked into
// the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
+ if (!block->use_empty())
+ return emitError(block->getParent()->getLoc(),
+ "failed control flow structurization: "
+ "block has uses outside of the "
+ "enclosing selection/loop construct");
for (Operation &op : *block)
if (!op.use_empty())
return op.emitOpError("failed control flow structurization: value has "
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 51c6077..366ba8f 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
@@ -138,6 +139,10 @@ using ImportDesc =
using parsed_inst_t = FailureOr<SmallVector<Value>>;
+struct EmptyBlockMarker {};
+using BlockTypeParseResult =
+ std::variant<EmptyBlockMarker, TypeIdxRecord, Type>;
+
struct WasmModuleSymbolTables {
SmallVector<FunctionSymbolRefContainer> funcSymbols;
SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -175,6 +180,9 @@ class ParserHead;
/// Wrapper around SmallVector to only allow access as push and pop on the
/// stack. Makes sure that there are no "free accesses" on the stack to preserve
/// its state.
+/// This class also keeps track of the Wasm labels defined by different ops,
+/// which can be targeted by control flow ops. Labels can be modeled as part
+/// of the value stack because Wasm control flow ops can only target enclosing
+/// labels.
class ValueStack {
private:
struct LabelLevel {
@@ -206,6 +214,16 @@ public:
/// if an error occurs.
LogicalResult pushResults(ValueRange results, Location *opLoc);
+ void addLabelLevel(LabelLevelOpInterface levelOp) {
+ labelLevel.push_back({values.size(), levelOp});
+ LDBG() << "Adding a new frame context to ValueStack";
+ }
+
+ void dropLabelLevel() {
+ assert(!labelLevel.empty() && "Trying to drop a frame from an empty context");
+ auto newSize = labelLevel.pop_back_val().stackIdx;
+ values.truncate(newSize);
+ }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// A simple dump function for debugging.
/// Writes output to llvm::dbgs().
@@ -214,6 +232,7 @@ public:
private:
SmallVector<Value> values;
+ SmallVector<LabelLevel> labelLevel;
};
using local_val_t = TypedValue<wasmssa::LocalRefType>;
@@ -248,6 +267,19 @@ private:
buildNumericOp(OpBuilder &builder,
std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
+ /// Construct a conversion operation of type \p opType that takes a value of
+ /// type \p inputType from the stack and produces a value of type
+ /// \p outputType.
+ ///
+ /// \p opType - The WASM dialect operation to build.
+ /// \p inputType - The operand type for the built instruction.
+ /// \p outputType - The result type for the built instruction.
+ ///
+ /// \returns The parsed instruction result, or failure.
+ template <typename opType, typename inputType, typename outputType,
+ typename... extraArgsT>
+ inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
+
/// This function generates a dispatch tree to associate an opcode with a
/// parser. Parsers are registered by specialising the
/// `parseSpecificInstruction` function for the op code to handle.
@@ -280,11 +312,105 @@ private:
}
}
+ /// RAII guard that creates a nesting level and drops it on destruction.
+ struct NestingContextGuard {
+ NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp)
+ : parser{parser} {
+ parser.addNestingContextLevel(levelOp);
+ }
+ NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} {
+ other.shouldDropOnDestruct = false;
+ }
+ NestingContextGuard(NestingContextGuard const &) = delete;
+ ~NestingContextGuard() {
+ if (shouldDropOnDestruct)
+ parser.dropNestingContextLevel();
+ }
+ ExpressionParser &parser;
+ bool shouldDropOnDestruct = true;
+ };
+
+ void addNestingContextLevel(LabelLevelOpInterface levelOp) {
+ valueStack.addLabelLevel(levelOp);
+ }
+
+ void dropNestingContextLevel() {
+ // Should always succeed as we are dropping the frame that was previously
+ // created.
+ valueStack.dropLabelLevel();
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ EmptyBlockMarker) {
+ return builder.getFunctionType({}, {});
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ TypeIdxRecord type) {
+ if (type.id >= symbols.moduleFuncTypes.size())
+ return emitError(*currentOpLoc,
+ "type index references nonexistent type (")
+ << type.id << "). Only " << symbols.moduleFuncTypes.size()
+ << " types are registered";
+ return symbols.moduleFuncTypes[type.id];
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ Type valType) {
+ return builder.getFunctionType({}, {valType});
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) {
+ return std::visit(
+ [this, &builder](auto value) { return getFuncTypeFor(builder, value); },
+ parseResult);
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder,
+ llvm::FailureOr<BlockTypeParseResult> parseResult) {
+ if (llvm::failed(parseResult))
+ return failure();
+ return getFuncTypeFor(builder, *parseResult);
+ }
+
+ llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder);
+
struct ParseResultWithInfo {
SmallVector<Value> opResults;
std::byte endingByte;
};
+ /// @param blockToFill: the block whose contents will be populated
+ /// @param resTypes: the types that this block is supposed to return
+ template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>>
+ llvm::FailureOr<std::byte>
+ parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes,
+ Location opLoc, LabelLevelOpInterface levelOp,
+ FilterT parseEndBytes = {}) {
+ OpBuilder::InsertionGuard guard{builder};
+ builder.setInsertionPointToStart(blockToFill);
+ LDBG() << "parsing a block of type "
+ << builder.getFunctionType(blockToFill->getArgumentTypes(),
+ resTypes);
+ auto nC = addNesting(levelOp);
+
+ if (failed(pushResults(blockToFill->getArguments())))
+ return failure();
+ auto bodyParsingRes = parse(builder, parseEndBytes);
+ if (failed(bodyParsingRes))
+ return failure();
+ auto returnOperands = popOperands(resTypes);
+ if (failed(returnOperands))
+ return failure();
+ builder.create<BlockReturnOp>(opLoc, *returnOperands);
+ LDBG() << "end of parsing of a block";
+ return bodyParsingRes->endingByte;
+ }
+
public:
template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
@@ -294,7 +420,11 @@ public:
parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
- FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ NestingContextGuard addNesting(LabelLevelOpInterface levelOp) {
+ return NestingContextGuard{*this, levelOp};
+ }
+
+ FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) {
return valueStack.popOperands(operandTypes, &currentOpLoc.value());
}
@@ -308,6 +438,12 @@ public:
template <typename OpToCreate>
parsed_inst_t parseSetOrTee(OpBuilder &);
+ /// Blocks and Loops have a similar format and differ only in how their exit
+ /// is handled which doesn´t matter at parsing time. Factorizes in one
+ /// function.
+ template <typename OpToCreate>
+ parsed_inst_t parseBlockLikeOp(OpBuilder &);
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
@@ -586,6 +722,29 @@ public:
return success();
}
+ llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) {
+ auto loc = getLocation();
+ auto blockIndicator = peek();
+ if (failed(blockIndicator))
+ return failure();
+ if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) {
+ offset += 1;
+ return {EmptyBlockMarker{}};
+ }
+ if (isValueOneOf(*blockIndicator, valueTypesEncodings))
+ return parseValueType(ctx);
+ // A block type index is a 32-bit unsigned integer encoded as a 33-bit
+ // signed value.
+ auto typeIdx = parseI64();
+ if (failed(typeIdx))
+ return failure();
+ if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max())
+ return emitError(loc, "type ID should be representable as an unsigned "
+ "32-bit integer, got ")
+ << *typeIdx;
+ return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}};
+ }
+
bool end() const { return curHead().empty(); }
ParserHead copy() const { return *this; }
@@ -701,17 +860,41 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
void ValueStack::dump() const {
llvm::dbgs() << "================= Wasm ValueStack =======================\n";
llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n';
llvm::dbgs() << "<Top>"
<< "\n";
// Stack is pushed to via push_back. Therefore the top of the stack is the
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
+ auto indexGetter = [this]() {
+ size_t idx = labelLevel.size();
+ return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> {
+ llvm::dbgs() << "IDX: " << idx << '\n';
+ if (idx == 0)
+ return std::nullopt;
+ auto frameId = idx - 1;
+ auto frameLimit = labelLevel[frameId].stackIdx;
+ idx -= 1;
+ return {{frameId, frameLimit}};
+ };
+ };
+ auto getNextFrameIndex = indexGetter();
+ auto nextFrameIdx = getNextFrameIndex();
size_t stackSize = size();
- for (size_t idx = 0; idx < stackSize; idx++) {
+ for (size_t idx = 0; idx < stackSize; ++idx) {
size_t actualIdx = stackSize - 1 - idx;
+ while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first
+ << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << " ";
values[actualIdx].dump();
}
+ while (nextFrameIdx) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << "<Bottom>"
<< "\n";
llvm::dbgs() << "=========================================================\n";
@@ -726,7 +909,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
return emitError(*opLoc,
"stack doesn't contain enough values. trying to get ")
<< operandTypes.size() << " operands on a stack containing only "
- << values.size() << " values.";
+ << values.size() << " values";
size_t stackIdxOffset = values.size() - operandTypes.size();
SmallVector<Value> res{};
res.reserve(operandTypes.size());
@@ -735,8 +918,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
Type stackType = operand.getType();
if (stackType != operandTypes[i])
return emitError(*opLoc, "invalid operand type on stack. expecting ")
- << operandTypes[i] << ", value on stack is of type " << stackType
- << ".";
+ << operandTypes[i] << ", value on stack is of type " << stackType;
LDBG() << " POP: " << operand;
res.push_back(operand);
}
@@ -792,6 +974,151 @@ ExpressionParser::parse(OpBuilder &builder,
}
}
+llvm::FailureOr<FunctionType>
+ExpressionParser::parseBlockFuncType(OpBuilder &builder) {
+ return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext()));
+}
+
+template <typename OpToCreate>
+parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ auto inputTypes = funcType->getInputs();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto blockOp =
+ builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor);
+ auto *blockBody = blockOp.createBlock();
+ if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
+ return failure();
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<BlockOp>(builder);
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<LoopOp>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ LDBG() << "Parsing an if instruction of type " << *funcType;
+ auto inputTypes = funcType->getInputs();
+ auto conditionValue = popOperands(builder.getI32Type());
+ if (failed(conditionValue))
+ return failure();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(),
+ *inputOps, successor);
+ auto *ifEntryBlock = ifOp.createIfBlock();
+ constexpr auto ifElseFilter =
+ ByteSequence<WasmBinaryEncoding::endByte,
+ WasmBinaryEncoding::OpCode::elseOpCode>{};
+ auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
+ ifOp, ifElseFilter);
+ if (failed(parseIfRes))
+ return failure();
+ if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
+ LDBG() << " else block is present.";
+ Block *elseEntryBlock = ifOp.createElseBlock();
+ auto parseElseRes =
+ parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
+ if (failed(parseElseRes))
+ return failure();
+ }
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) {
+ auto level = parser.parseLiteral<uint32_t>();
+ if (failed(level))
+ return failure();
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto sip = builder.saveInsertionPoint();
+ Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
+ auto condition = popOperands(builder.getI32Type());
+ if (failed(condition))
+ return failure();
+ builder.restoreInsertionPoint(sip);
+ auto targetOp =
+ LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level);
+ if (failed(targetOp))
+ return failure();
+ auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
+ auto branchArgs = popOperands(inputTypes);
+ if (failed(branchArgs))
+ return failure();
+ builder.create<BranchIfOp>(*currentOpLoc, condition->front(),
+ builder.getUI32IntegerAttr(*level), *branchArgs,
+ elseBlock);
+ builder.setInsertionPointToStart(elseBlock);
+ return {*branchArgs};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
+ OpBuilder &builder) {
+ auto loc = *currentOpLoc;
+ auto funcIdx = parser.parseLiteral<uint32_t>();
+ if (failed(funcIdx))
+ return failure();
+ if (*funcIdx >= symbols.funcSymbols.size())
+ return emitError(loc, "Invalid function index: ") << *funcIdx;
+ auto callee = symbols.funcSymbols[*funcIdx];
+ llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
+ llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
+ parsed_inst_t inOperands = popOperands(inTypes);
+ if (failed(inOperands))
+ return failure();
+ auto callOp =
+ builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands);
+ return {callOp.getResults()};
+}
+
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
@@ -834,7 +1161,7 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
if (valueStack.empty())
return emitError(
*currentOpLoc,
- "invalid stack access, trying to access a value on an empty stack.");
+ "invalid stack access, trying to access a value on an empty stack");
parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
if (failed(poppedOp))
@@ -1000,11 +1327,23 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
BUILD_NUMERIC_BINOP_FP(DivOp, div)
+BUILD_NUMERIC_BINOP_FP(GeOp, ge)
+BUILD_NUMERIC_BINOP_FP(GtOp, gt)
+BUILD_NUMERIC_BINOP_FP(LeOp, le)
+BUILD_NUMERIC_BINOP_FP(LtOp, lt)
BUILD_NUMERIC_BINOP_FP(MaxOp, max)
BUILD_NUMERIC_BINOP_FP(MinOp, min)
BUILD_NUMERIC_BINOP_INT(AndOp, and)
BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
+BUILD_NUMERIC_BINOP_INT(GeSIOp, geS)
+BUILD_NUMERIC_BINOP_INT(GeUIOp, geU)
+BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS)
+BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU)
+BUILD_NUMERIC_BINOP_INT(LeSIOp, leS)
+BUILD_NUMERIC_BINOP_INT(LeUIOp, leU)
+BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS)
+BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU)
BUILD_NUMERIC_BINOP_INT(OrOp, or)
BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
@@ -1015,7 +1354,9 @@ BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
BUILD_NUMERIC_BINOP_INTFP(AddOp, add)
+BUILD_NUMERIC_BINOP_INTFP(EqOp, eq)
BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
+BUILD_NUMERIC_BINOP_INTFP(NeOp, ne)
BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
@@ -1025,6 +1366,7 @@ BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
+BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz)
BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
// Don't need these anymore so let's undef them.
@@ -1036,6 +1378,105 @@ BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
#undef BUILD_NUMERIC_OP
#undef BUILD_NUMERIC_CAST_OP
+template <typename opType, typename inputType, typename outputType,
+ typename... extraArgsT>
+inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
+ extraArgsT... extraArgs) {
+ static_assert(std::is_arithmetic_v<inputType>,
+ "InputType should be an arithmetic type");
+ static_assert(std::is_arithmetic_v<outputType>,
+ "OutputType should be an arithmetic type");
+ auto inType = buildLiteralType<inputType>(builder);
+ auto outType = buildLiteralType<outputType>(builder);
+ auto operand = popOperands(inType);
+ if (failed(operand))
+ return failure();
+ auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(),
+ extraArgs...);
+ LDBG() << "Built operation: " << op;
+ return {{op.getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) {
+ return buildConvertOp<DemoteOp, double, float>(builder);
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>(
+ OpBuilder &builder) {
+ return buildConvertOp<WrapOp, int64_t, int32_t>(builder);
+}
+
+#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \
+ return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \
+ }
+
+#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \
+ BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \
+ BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp)
+
+BUILD_CONVERT_OP_FOR(float, 32)
+BUILD_CONVERT_OP_FOR(double, 64)
+
+#undef BUILD_CONVERT_OP_FOR
+
+BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op)
+BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op)
+
+#undef BUILD_CONVERSION_OP
+
+#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \
+ template <> \
+ parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \
+ OpBuilder & builder) { \
+ using inout_t = int##IT_WIDTH##_t; \
+ auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \
+ return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \
+ }
+
+BUILD_SLICE_EXTEND_PARSER(32, 8)
+BUILD_SLICE_EXTEND_PARSER(32, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 8)
+BUILD_SLICE_EXTEND_PARSER(64, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 32)
+
+#undef BUILD_SLICE_EXTEND_PARSER
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::promoteF32ToF64>(OpBuilder &builder) {
+ return buildConvertOp<PromoteOp, float, double>(builder);
+}
+
+#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \
+ } \
+ \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \
+ }
+
+BUILD_REINTERPRET_PARSER(32, float)
+BUILD_REINTERPRET_PARSER(64, double)
+
+#undef BUILD_REINTERPRET_PARSER
+
class WasmBinaryParser {
private:
struct SectionRegistry {
@@ -1153,7 +1594,7 @@ private:
if (tid.id >= symbols.moduleFuncTypes.size())
return emitError(loc, "invalid type id: ")
<< tid.id << ". Only " << symbols.moduleFuncTypes.size()
- << " type registration.";
+ << " type registrations";
FunctionType type = symbols.moduleFuncTypes[tid.id];
std::string symbol = symbols.getNewFuncSymbolName();
auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
@@ -1221,7 +1662,7 @@ public:
FileLineColLoc magicLoc = parser.getLocation();
FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
if (failed(magic) || magic->compare(wasmHeader)) {
- emitError(magicLoc, "source file does not contain valid Wasm header.");
+ emitError(magicLoc, "source file does not contain valid Wasm header");
return;
}
auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
@@ -1391,7 +1832,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
return failure();
Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
- SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ op->setAttr("exported", UnitAttr::get(op->getContext()));
StringAttr symName = SymbolTable::getSymbolName(op);
return SymbolTable{mOp}.rename(symName, *exportName);
}
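The block-type parsing added above leans on the Wasm encoding rule that a block type index is a non-negative 32-bit value stored as a 33-bit signed LEB128, which is why parseI64 followed by a range check suffices. A standalone sketch of that classification, assuming the caller has already LEB128-decoded the value into an int64_t (the helper name is illustrative):

    #include <cstdint>
    #include <limits>
    #include <optional>

    // Sketch: a decoded signed block-type value is a type index only if it
    // is non-negative and fits in 32 bits; negative one-byte values encode
    // value types or the empty block type instead.
    std::optional<uint32_t> asTypeIndex(int64_t blockType) {
      if (blockType < 0 || blockType > std::numeric_limits<uint32_t>::max())
        return std::nullopt;
      return static_cast<uint32_t>(blockType);
    }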
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index 9670285..3fda5a7 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
// Emit function to add the generated matchers to the pattern list.
os << "template <typename... ConfigsT>\n"
- "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+ "[[maybe_unused]] static void populateGeneratedPDLLPatterns("
"::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
for (const auto &name : patternNames)
os << " patterns.add<" << name
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9f5246d..ffa96ad 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -137,6 +137,16 @@ declare_mlir_dialect_python_bindings(
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/OpenACCOps.td
+ SOURCES
+ dialects/openacc.py
+ DIALECT_NAME acc
+ DEPENDS acc_common_td
+ )
+
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUOps.td
SOURCES_GLOB dialects/gpu/*.py
DIALECT_NAME gpu
diff --git a/mlir/python/mlir/dialects/OpenACCOps.td b/mlir/python/mlir/dialects/OpenACCOps.td
new file mode 100644
index 0000000..69a3002
--- /dev/null
+++ b/mlir/python/mlir/dialects/OpenACCOps.td
@@ -0,0 +1,14 @@
+//===-- OpenACCOps.td - Entry point for OpenACCOps bindings -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_OPENACC_OPS
+#define PYTHON_BINDINGS_OPENACC_OPS
+
+include "mlir/Dialect/OpenACC/OpenACCOps.td"
+
+#endif // PYTHON_BINDINGS_OPENACC_OPS
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa..b14ea68 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,5 +3,151 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_ops_gen import *
+from .._gpu_ops_gen import _Dialect
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+from typing import Callable, Optional, Sequence, Union
+
+try:
+ from ...ir import (
+ FunctionType,
+ TypeAttr,
+ StringAttr,
+ UnitAttr,
+ Block,
+ InsertionPoint,
+ Type,
+ DenseI32ArrayAttr,
+ )
+ from .._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GPUFuncOp(GPUFuncOp):
+ __doc__ = GPUFuncOp.__doc__
+
+ KERNEL_ATTR_NAME = "gpu.kernel"
+ KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size"
+ KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size"
+
+ FUNCTION_TYPE_ATTR_NAME = "function_type"
+ SYM_NAME_ATTR_NAME = "sym_name"
+ ARGUMENT_ATTR_NAME = "arg_attrs"
+ RESULT_ATTR_NAME = "res_attrs"
+
+ def __init__(
+ self,
+ function_type: Union[FunctionType, TypeAttr],
+ sym_name: Optional[Union[str, StringAttr]] = None,
+ kernel: Optional[bool] = None,
+ workgroup_attrib_attrs: Optional[Sequence[dict]] = None,
+ private_attrib_attrs: Optional[Sequence[dict]] = None,
+ known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
+ known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
+ loc=None,
+ ip=None,
+ body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
+ ):
+ """
+ Create a GPUFuncOp with the provided `function_type`, `sym_name`,
+ `kernel`, `workgroup_attrib_attrs`, `private_attrib_attrs`, `known_block_size`,
+ `known_grid_size`, and `body_builder`.
+ - `function_type` is a FunctionType or a TypeAttr.
+ - `sym_name` is a string or a StringAttr representing the function name.
+ - `kernel` is a boolean representing whether the function is a kernel.
+ - `workgroup_attrib_attrs` is an optional list of dictionaries.
+ - `private_attrib_attrs` is an optional list of dictionaries.
+ - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
+ - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
+ - `body_builder` is an optional callback. When provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ function_type = (
+ TypeAttr.get(function_type)
+ if not isinstance(function_type, TypeAttr)
+ else function_type
+ )
+ super().__init__(
+ function_type,
+ workgroup_attrib_attrs=workgroup_attrib_attrs,
+ private_attrib_attrs=private_attrib_attrs,
+ loc=loc,
+ ip=ip,
+ )
+
+ if isinstance(sym_name, str):
+ self.attributes[self.SYM_NAME_ATTR_NAME] = StringAttr.get(sym_name)
+ elif isinstance(sym_name, StringAttr):
+ self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name
+ else:
+ raise ValueError("sym_name must be a string or a StringAttr")
+
+ if kernel:
+ self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get()
+
+ if known_block_size is not None:
+ if isinstance(known_block_size, Sequence):
+ block_size = DenseI32ArrayAttr.get(known_block_size)
+ self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = block_size
+ elif isinstance(known_block_size, DenseI32ArrayAttr):
+ self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size
+ else:
+ raise ValueError(
+ "known_block_size must be a list of integers or a DenseI32ArrayAttr"
+ )
+
+ if known_grid_size is not None:
+ if isinstance(known_grid_size, Sequence):
+ grid_size = DenseI32ArrayAttr.get(known_grid_size)
+ self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = grid_size
+ elif isinstance(known_grid_size, DenseI32ArrayAttr):
+ self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size
+ else:
+ raise ValueError(
+ "known_grid_size must be a list of integers or a DenseI32ArrayAttr"
+ )
+
+ if body_builder is not None:
+ with InsertionPoint(self.add_entry_block()):
+ body_builder(self)
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME])
+
+ @property
+ def is_kernel(self) -> bool:
+ return self.KERNEL_ATTR_NAME in self.attributes
+
+ def add_entry_block(self) -> Block:
+ if len(self.body.blocks) > 0:
+ raise RuntimeError(f"Entry block already exists for {self.name.value}")
+
+ function_type = self.function_type.value
+ return self.body.blocks.append(
+ *function_type.inputs,
+ arg_locs=[self.location for _ in function_type.inputs],
+ )
+
+ @property
+ def entry_block(self) -> Block:
+ if len(self.body.blocks) == 0:
+ raise RuntimeError(
+ f"Entry block does not exist for {self.name.value}."
+ + " Do you need to call the add_entry_block() method on this GPUFuncOp?"
+ )
+ return self.body.blocks[0]
+
+ @property
+ def arguments(self) -> Sequence[Type]:
+ return self.function_type.value.inputs
diff --git a/mlir/python/mlir/dialects/openacc.py b/mlir/python/mlir/dialects/openacc.py
new file mode 100644
index 0000000..057f71a
--- /dev/null
+++ b/mlir/python/mlir/dialects/openacc.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._acc_ops_gen import *
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index dbff233..455f886 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefixes=CHECK,PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefixes=CHECK,POST9
module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -596,3 +597,76 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}
+
+// -----
+
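+// rocdl.fmed3 computes the median of its three operands, which equals
+// clamp(x, lo, hi) whenever lo <= hi.
+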
+// f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f16
+func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
+ %r = math.clampf %x to [%lo, %hi] : f16
+ return %r : f16
+ // POST9: rocdl.fmed3 {{.*}} : f16
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f16
+}
+
+// f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f32
+func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
+ %r = math.clampf %x to [%lo, %hi] : f32
+ return %r : f32
+ // POST9: rocdl.fmed3 {{.*}} : f32
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f32
+}
+
+// -----
+
+// Vector f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f16
+func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf16>
+ return %r : vector<2xf16>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf16>
+}
+
+// -----
+
+// Vector f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f32
+func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf32>
+ return %r : vector<2xf32>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf32>
+}
+
+// -----
+
+// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
+// CHECK-LABEL: func.func @clampf_vector_2d_f16
+func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
+ return %r : vector<2x2xf16>
+ // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2x2xf16>
+}
+
+// -----
+
+// bf16 clamp has no fmed3 lowering; it stays as math.clampf on both tested chipsets.
+// CHECK-LABEL: func.func @clampf_bf16
+func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
+ %r = math.clampf %x to [%lo, %hi] : bf16
+ return %r : bf16
+ // CHECK: math.clampf {{.*}} : bf16
+ // CHECK-NOT: rocdl.fmed3
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2d33888..d669a3b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
// -----
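+// Broadcasting a scalar to a single-element vector needs only an
+// insertelement; no shufflevector splat is generated.
+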
+func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
+ return %0 : vector<1xf32>
+}
+// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK-NOT: llvm.shufflevector
+// CHECK: return %[[T0]] : vector<1xf32>
+
+// -----
+
func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
return %0 : vector<[2]xf32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index e6f22f0..a9ab0be 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,17 +1,13 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-
-gpu.module @load_store_check {
+gpu.module @test_kernel {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
// Loads are checked in a separate test.
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
// CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+ %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
return %d : vector<8xf32>
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
new file mode 100644
index 0000000..d4cb493
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -0,0 +1,201 @@
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+
+ // e.g. for mem_desc<32x32xf32> (default row-major layout, strides = [32, 1])
+ // its memory layout tuple is (blocked shape = [1, 1, 32, 32], strides = [1024, 1024, 32, 1])
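+ // Illustrative offset math for [%c0, %tid_x]: linear = 0 * 32 + tid * 1;
+ // the byte offset is linear * 4 (f32), as the checks below spell out.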
+ //CHECK-LABEL: load_store_matrix_1
+ gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+ %tid_x = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return %1: f32
+ }
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is (blocked shape = [2, 4, 16, 16], strides = [256, 512, 1, 16])
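+ // Illustrative offset math for [13, tid]: each index splits into a
+ // (div 16, rem 16) block/in-block pair, and the four components are then
+ // combined with the four strides, as the checks below spell out.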
+ //CHECK-LABEL: load_store_matrix_2
+ gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c13:.*]] = arith.constant 13 : index
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
+
+
+ %tid_x = gpu.thread_id x
+ %c13 = arith.constant 13 : index
+ %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+ gpu.return %1: f16
+ }
+
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is (blocked shape = [2, 4, 16, 16], strides = [1024, 256, 16, 1])
+ //CHECK-LABEL: load_store_matrix_3
+ gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c19:.*]] = arith.constant 19 : index
+ %tid_x = gpu.thread_id x
+ %c19 = arith.constant 19: index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+ //CHECK: gpu.return %[[loaded]] : f16
+ gpu.return %1: f16
+ }
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is (blocked shape = [2, 4, 16, 16], strides = [256, 512, 1, 16])
+ //CHECK-LABEL: load_store_matrix_4
+ gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+
+ %tid_x = gpu.thread_id x
+ %c16 = arith.constant 16 : index
+ %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
+ gpu.return %1: vector<8xf16>
+ }
+
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is (blocked shape = [2, 4, 16, 16], strides = [1024, 256, 16, 1])
+ //CHECK-LABEL: load_store_matrix_5
+ gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[c48:.*]] = arith.constant 48 : index
+
+ %c16 = arith.constant 16 : index
+ %c48 = arith.constant 48 : index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+ //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK: %[[linearOffsetI32:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+ //CHECK: %[[c2:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI32]], %[[c2]] : i32
+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+ //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+
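+ // With subgroup_block_io, the access lowers to xevm.blockload /
+ // xevm.blockstore; the f16 payload is bitcast through vector<8xi16>.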
+ %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
+ //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+
+ xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+ gpu.return %1: vector<8xf16>
+ }
+
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0b150e9..9c552d8 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
- // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
- // CHECK: scf.yield %[[VAR8]] : f16
- // CHECK: } else {
- // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
- // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
+ // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
// CHECK: scf.yield %[[VAR7]] : f16
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
+ // CHECK: scf.yield %[[CST_0]] : f16
// CHECK: }
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
gpu.return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: @source_materialize_single_elem_vec
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
+gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
+ %1 = arith.constant dense<1>: vector<1xi1>
+ %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ // CHECK: %[[VAR_IF:.*]] = scf.if
+ // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
+ gpu.return
+}
+}
+
// -----
gpu.module @test {
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index e56079c..1169cd1 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
// -----
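+// affine.apply over delinearize_index results recombined with the matching
+// strides folds back to the original linear index, whatever the symbol order
+// in the map.
+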
+// CHECK-LABEL: func @delin_apply_cancel_exact
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
+// CHECK-NOT: memref.store
+// CHECK: return
+func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
+ %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+ %c:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
+ memref.store %t2, %arg1[%t2] : memref<?xindex>
+
+ %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
+ memref.store %t3, %arg1[%t3] : memref<?xindex>
+
+ %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
+ memref.store %t4, %arg1[%t4] : memref<?xindex>
+
+ %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
+ memref.store %t5, %arg1[%t5] : memref<?xindex>
+
+ %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
+ memref.store %t6, %arg1[%t6] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @delin_apply_cancel_exact_dim
+// CHECK: affine.for %[[arg1:.+]] = 0 to 256
+// CHECK: memref.store %[[arg1]]
+// CHECK: return
+func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
+ affine.for %arg1 = 0 to 256 {
+ %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
+ %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
+ memref.store %i, %arg0[%i] : memref<?xindex>
+ }
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_const_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_var_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
+// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
+// CHECK: return
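+// Only the s0 + s1 * 20 pair cancels; the extra use of s0 keeps the
+// delinearize_index op live.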
+func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
+// CHECK-LABEL: func @delin_apply_dont_cancel_partial
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
+// CHECK: return
+func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8accf6e..755e3a3 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
// -----
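+// A shufflevector of single-element vectors that selects lane 0 folds to its
+// first operand.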
+// CHECK-LABEL: fold_shufflevector
+// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32>
+llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> {
+ // CHECK-NOT: llvm.shufflevector
+ %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
+ // CHECK: llvm.return %[[ARG1]]
+ llvm.return %c : vector<1xf32>
+}
+
+// -----
+
// Check that LLVM constants participate in cross-dialect constant folding. The
// resulting constant is created in the arith dialect because the last folded
// operation belongs to it.
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 358bd33..242c04f 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1035,6 +1035,20 @@ llvm.func @rocdl.s.wait.expcnt() {
llvm.return
}
+llvm.func @rocdl.s.wait.asynccnt() {
+ // CHECK-LABEL: rocdl.s.wait.asynccnt
+ // CHECK: rocdl.s.wait.asynccnt 0
+ rocdl.s.wait.asynccnt 0
+ llvm.return
+}
+
+llvm.func @rocdl.s.wait.tensorcnt() {
+ // CHECK-LABEL: rocdl.s.wait.tensorcnt
+ // CHECK: rocdl.s.wait.tensorcnt 0
+ rocdl.s.wait.tensorcnt 0
+ llvm.return
+}
+
// -----
llvm.func @rocdl.readfirstlane(%src : f32) -> f32 {
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a..93a0336 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.dot
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_dot
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -20,6 +24,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matvec
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_matvec
func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
@@ -41,6 +49,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matmul
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_matmul
func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
@@ -138,6 +150,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_matmul
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
@@ -159,6 +175,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.contract
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: @matmul_as_contract
// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
@@ -220,6 +240,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.fill
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
@@ -259,70 +283,14 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_copy
-func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
- // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
- // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
- memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
- return
-}
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.pack
+///----------------------------------------------------------------------------------------
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
+// Note: see a similar test in:
+// * vectorization.mlir.
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_0d
-func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
- // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
- // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
- // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
- // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
- // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
- memref.copy %A, %B : memref<f32> to memref<f32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_complex
-// CHECK-NOT: vector<
-func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
- memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// Input identical as the test in vectorization.mlir. Output is different -
-// vector sizes are inferred (rather than user-specified) and hence _no_
-// masking was used.
-
-func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+func.func @pack_no_padding(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
return %pack : tensor<4x1x32x16x2xf32>
}
@@ -336,7 +304,7 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-LABEL: func.func @test_vectorize_pack(
+// CHECK-LABEL: func.func @pack_no_padding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32
@@ -349,13 +317,16 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+// Note: see a similar test in:
+// * vectorization.mlir.
+
+func.func @pack_with_padding(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
%pad = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
return %pack : tensor<32x4x1x16x2xf32>
}
-// CHECK-LABEL: func.func @test_vectorize_padded_pack(
+// CHECK-LABEL: func.func @pack_with_padding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
@@ -377,6 +348,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.map
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
@@ -403,6 +378,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.transpose
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>,
%arg1: memref<32x64x16xf32>) {
linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
@@ -424,6 +403,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.reduce
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>,
%arg1: memref<16x64xf32>) {
linalg.reduce ins(%arg0 : memref<16x32x64xf32>)
@@ -449,6 +432,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.generic
+///----------------------------------------------------------------------------------------
+
#matmul_trait = {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -1446,6 +1433,8 @@ module attributes {transform.with_named_sequence} {
// -----
+// TODO: Two Linalg ops in one test - either split or document why.
+
// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
// CHECK-LABEL: func @fused_broadcast_red_2d
@@ -1896,3 +1885,65 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for memref.copy
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @test_vectorize_copy
+func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+ memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_0d
+func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
+ // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
+ // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
+ memref.copy %A, %B : memref<f32> to memref<f32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+ memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 11bea8d..1304a90 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1307,14 +1307,17 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
/// Tests for linalg.pack
///----------------------------------------------------------------------------------------
-// Input identical as the test in vectorization-with-patterns.mlir. Output is
-// different - vector sizes are inferred (rather than user-specified) and hence
-// masking was used.
+// This packing requires no padding, so no out-of-bounds vector reads or writes are generated.
-// CHECK-LABEL: func @test_vectorize_pack
+// Note: see a similar test in:
+// * vectorization-with-patterns.mlir
+// The output is identical (the input vector sizes == the inferred vector
+// sizes, i.e. the tensor sizes).
+
+// CHECK-LABEL: func @pack_no_padding
// CHECK-SAME: %[[SRC:.*]]: tensor<32x8x16xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<4x1x32x16x2xf32>
-func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
%pack = linalg.pack %src outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
return %pack : tensor<4x1x32x16x2xf32>
}
@@ -1325,9 +1328,9 @@ func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x1
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
-// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<4x1x32x16x2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) {
@@ -1339,14 +1342,18 @@ module attributes {transform.with_named_sequence} {
// -----
-// Input identical as the test in vectorization-with-patterns.mlir. Output is
-// different - vector sizes are inferred (rather than user-specified) and hence
-// masking was used.
+// This packing does require padding, so out-of-bounds vector reads and
+// writes are generated.
+
+// Note: see a similar test in:
+// * vectorization-with-patterns.mlir.
+// The output is different (the input vector sizes != inferred vector sizes,
+// i.e. the tensor sizes).
-// CHECK-LABEL: func @test_vectorize_padded_pack
+// CHECK-LABEL: func @pack_with_padding
// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
-func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
%pad = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
return %pack : tensor<32x4x1x16x2xf32>
@@ -1364,9 +1371,9 @@ func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<3
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -1378,10 +1385,46 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_dynamic_pack
+// This packing does require padding, so out-of-bounds vector reads and
+// writes are generated.
+
+// Note: see a similar test in:
+// * vectorization-with-patterns.mlir.
+// The output is identical (in both cases the vector sizes are inferred).
+
+// CHECK-LABEL: func @pack_with_padding_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
+func.func @pack_with_padding_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pad = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+ return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]]
+// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @pack_with_dynamic_dims
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32>
-func.func @test_vectorize_dynamic_pack(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
%pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
return %pack : tensor<?x?x16x2xf32>
}
@@ -1418,64 +1461,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// -----
-
-// CHECK-LABEL: func @test_vectorize_pack_no_vector_sizes
-// CHECK-SAME: %[[SRC:.*]]: tensor<64x4xf32>,
-// CHECK-SAME: %[[DEST:.*]]: tensor<2x4x16x2xf32>
-func.func @test_vectorize_pack_no_vector_sizes(%src: tensor<64x4xf32>, %dest: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
- %pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %dest : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
- return %pack : tensor<2x4x16x2xf32>
-}
-// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]]
-// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<64x4xf32> to vector<4x16x2x2xf32>
-// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
-// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
-// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
-// CHECK: return %[[WRITE]] : tensor<2x4x16x2xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
-// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
-// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
-func.func @test_vectorize_padded_pack_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
- %pad = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
- return %pack : tensor<32x4x1x16x2xf32>
-}
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]]
-// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
-// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
-// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 : !transform.any_op
- transform.yield
- }
-}
-
-
///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c..7160b52 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
// -----
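+// Constant offset/size/stride operands fold into static reinterpret_cast
+// parameters; a memref.cast then restores the original dynamic result type.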
+// CHECK-LABEL: func @reinterpret_constant_fold
+// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
index 603ace8..3d4bec7 100644
--- a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
@@ -3,7 +3,7 @@
func.func @test_static_memref_alloc() {
%0 = memref.alloca() {test.ptr} : memref<10x20xf32>
// CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32>
- // CHECK: Generated: %{{.*}} = memref.alloca() : memref<10x20xf32>
+ // CHECK: Generated: %{{.*}} = memref.alloca() {acc.var_name = #acc.var_name<"test_alloc">} : memref<10x20xf32>
return
}
@@ -19,6 +19,6 @@ func.func @test_dynamic_memref_alloc() {
// CHECK: Generated: %[[DIM0:.*]] = memref.dim %[[ORIG]], %[[C0]] : memref<?x?xf32>
// CHECK: Generated: %[[C1:.*]] = arith.constant 1 : index
// CHECK: Generated: %[[DIM1:.*]] = memref.dim %[[ORIG]], %[[C1]] : memref<?x?xf32>
- // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+ // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"test_alloc">} : memref<?x?xf32>
return
}
diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir
index 35355c6..8846c9e 100644
--- a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir
+++ b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir
@@ -2,7 +2,7 @@
// CHECK: acc.firstprivate.recipe @firstprivate_scalar : memref<f32> init {
// CHECK: ^bb0(%{{.*}}: memref<f32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<f32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
// CHECK: acc.yield %[[ALLOC]] : memref<f32>
// CHECK: } copy {
// CHECK: ^bb0(%[[SRC:.*]]: memref<f32>, %[[DST:.*]]: memref<f32>):
@@ -20,7 +20,7 @@ func.func @test_scalar() {
// CHECK: acc.firstprivate.recipe @firstprivate_static_2d : memref<10x20xf32> init {
// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<10x20xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32>
// CHECK: } copy {
// CHECK: ^bb0(%[[SRC:.*]]: memref<10x20xf32>, %[[DST:.*]]: memref<10x20xf32>):
@@ -42,7 +42,7 @@ func.func @test_static_2d() {
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32>
-// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32>
// CHECK: } copy {
// CHECK: ^bb0(%[[SRC:.*]]: memref<?x?xf32>, %[[DST:.*]]: memref<?x?xf32>):
@@ -65,7 +65,7 @@ func.func @test_dynamic_2d(%arg0: index, %arg1: index) {
// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>):
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32>
-// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) : memref<10x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32>
// CHECK: } copy {
// CHECK: ^bb0(%[[SRC:.*]]: memref<10x?xf32>, %[[DST:.*]]: memref<10x?xf32>):
@@ -86,7 +86,7 @@ func.func @test_mixed_dims(%arg0: index) {
// CHECK: acc.firstprivate.recipe @firstprivate_scalar_int : memref<i32> init {
// CHECK: ^bb0(%{{.*}}: memref<i32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<i32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32>
// CHECK: acc.yield %[[ALLOC]] : memref<i32>
// CHECK: } copy {
// CHECK: ^bb0(%[[SRC:.*]]: memref<i32>, %[[DST:.*]]: memref<i32>):
diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir
index 8403ee8..3d5a918 100644
--- a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir
+++ b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir
@@ -2,7 +2,7 @@
// CHECK: acc.private.recipe @private_scalar : memref<f32> init {
// CHECK: ^bb0(%{{.*}}: memref<f32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<f32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
// CHECK: acc.yield %[[ALLOC]] : memref<f32>
// CHECK: }
// CHECK-NOT: destroy
@@ -16,7 +16,7 @@ func.func @test_scalar() {
// CHECK: acc.private.recipe @private_static_2d : memref<10x20xf32> init {
// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<10x20xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32>
// CHECK: }
// CHECK-NOT: destroy
@@ -34,7 +34,7 @@ func.func @test_static_2d() {
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32>
-// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32>
// CHECK: } destroy {
// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>):
@@ -53,7 +53,7 @@ func.func @test_dynamic_2d(%arg0: index, %arg1: index) {
// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>):
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32>
-// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) : memref<10x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32>
// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32>
// CHECK: } destroy {
// CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>):
@@ -70,7 +70,7 @@ func.func @test_mixed_dims(%arg0: index) {
// CHECK: acc.private.recipe @private_scalar_int : memref<i32> init {
// CHECK: ^bb0(%{{.*}}: memref<i32>):
-// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<i32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32>
// CHECK: acc.yield %[[ALLOC]] : memref<i32>
// CHECK: }
// CHECK-NOT: destroy
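
Both recipe-populate tests exercise the same region structure: `init` materializes the private copy and yields it, firstprivate recipes additionally carry a `copy` region, and recipes that heap-allocate get a `destroy` region (hence the `CHECK-NOT: destroy` on the stack-allocated cases). A minimal sketch of a private recipe in that shape, symbol name illustrative:

    acc.private.recipe @private_scalar_sketch : memref<f32> init {
    ^bb0(%arg0: memref<f32>):
      // Stack allocation tagged with the recovered variable name.
      %0 = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
      acc.yield %0 : memref<f32>
    }
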
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72be..f66cf7a 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[A]] : index
+// CHECK: ^bb2:
+// CHECK: return %[[B]] : index
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ cf.cond_br %cond,^a, ^b
+^a:
+ return %casted, %a : tensor<10xf32>, index
+^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
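
What this new test pins down: the callee's `tensor<10xf32>` result bufferizes to a buffer equivalent to its own argument, so the equivalent buffer result is dropped from the bufferized signature even though the callee returns from two distinct blocks; only the `index` result survives, and the caller substitutes its own operand for the dropped value. Conceptually the call site ends up as (sketch, names illustrative):

    // The memref result is gone; %t itself stands in for it at all uses.
    %v = call @mult_return_callee(%t, %cond, %a, %b)
        : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
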
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
index d6c886c..a0c59c0 100644
--- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -1,12 +1,14 @@
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
// -----
-// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
-// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
-// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
// CHECK-LABEL: test_simple
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
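
The target-env attribute gains a `specification_version` entry, defaulting to "1.0" when the pass option is omitted; the tests above exercise "1.0" and "1.1.draft". A usage sketch (invocation in a comment, output shape per the CHECK lines):

    // mlir-opt in.mlir -tosa-attach-target="specification_version=1.1.draft"
    // produces a module of this shape:
    module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>} {
    }
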
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
new file mode 100644
index 0000000..51089df
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
new file mode 100644
index 0000000..8164509
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
index b9b3420..a25abbd 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
module {
- wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+ wasmssa.import_global "from_js" from "env" as @global_0 : i32
wasmssa.global @global_1 i32 : {
%0 = wasmssa.const 10 : i32
@@ -21,7 +21,7 @@ module {
}
}
-// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32
// CHECK-LABEL: wasmssa.global @global_1 i32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
index 01068cb..cee3c69 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK-LABEL: wasmssa.func nested @func_0(
+// CHECK-LABEL: wasmssa.func @func_0(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
// CHECK: wasmssa.if %[[VAL_0]] : {
@@ -12,7 +12,7 @@
// CHECK: }> ^bb1
// CHECK: ^bb1(%[[VAL_3:.*]]: f32):
// CHECK: wasmssa.return %[[VAL_3]] : f32
-wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
%cond = wasmssa.local_get %arg0 : ref to i32
wasmssa.if %cond : {
%c0 = wasmssa.const 0.5 : f32
@@ -25,7 +25,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
wasmssa.return %retVal : f32
}
-// CHECK-LABEL: wasmssa.func nested @func_1(
+// CHECK-LABEL: wasmssa.func @func_1(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
@@ -38,7 +38,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
// CHECK: ^bb1:
// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32
// CHECK: wasmssa.return %[[VAL_4]] : i32
-wasmssa.func nested @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 {
%cond = wasmssa.local_get %arg0 : ref to i32
%var = wasmssa.local of type i32
%zero = wasmssa.const 0
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
index 3cc0548..dc23229 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
@@ -5,13 +5,13 @@ module {
wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
- wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
- wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+ wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+ wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
}
// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
-// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
-// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
index 3f6423f..f613ebf 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
module {
- wasmssa.func nested @func_0() -> f32 {
+ wasmssa.func @func_0() -> f32 {
%0 = wasmssa.local of type f32
%1 = wasmssa.local of type f32
%2 = wasmssa.const 8.000000e+00 : f32
@@ -9,7 +9,7 @@ module {
%4 = wasmssa.add %2 %3 : f32
wasmssa.return %4 : f32
}
- wasmssa.func nested @func_1() -> i32 {
+ wasmssa.func @func_1() -> i32 {
%0 = wasmssa.local of type i32
%1 = wasmssa.local of type i32
%2 = wasmssa.const 8 : i32
@@ -17,13 +17,13 @@ module {
%4 = wasmssa.add %2 %3 : i32
wasmssa.return %4 : i32
}
- wasmssa.func nested @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 {
+ wasmssa.func @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 {
%0 = wasmssa.const 3 : i32
wasmssa.return %0 : i32
}
}
-// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
@@ -31,7 +31,7 @@ module {
// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : f32
// CHECK: wasmssa.return %[[VAL_4]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
@@ -39,7 +39,7 @@ module {
// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32
// CHECK: wasmssa.return %[[VAL_4]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-LABEL: wasmssa.func @func_2(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
index 47551db..ca6ebe0 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK: wasmssa.memory @mem0 public !wasmssa<limit[0: 65536]>
-wasmssa.memory @mem0 public !wasmssa<limit[0:65536]>
-
-// CHECK: wasmssa.memory @mem1 nested !wasmssa<limit[512:]>
+// CHECK: wasmssa.memory @mem1 !wasmssa<limit[512:]>
wasmssa.memory @mem1 !wasmssa<limit[512:]>
+
+// CHECK: wasmssa.memory exported @mem2 !wasmssa<limit[0: 65536]>
+wasmssa.memory exported @mem2 !wasmssa<limit[0:65536]>
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
index 5a874f4..ea630de 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK: wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0: 65536]>
-wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0:65536]>
+// CHECK: wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0: 65536]>
+wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0:65536]>
-// CHECK: wasmssa.table @tab1 nested !wasmssa<tabletype !wasmssa.funcref [348:]>
+// CHECK: wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]>
wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]>
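
The WasmSSA parser tests above all track one syntax change: the printed visibility keywords (`nested`, `public`) are gone from `func`, `global`, `import_global`, `memory`, and `table`, and exported symbols instead carry a leading `exported` keyword (imports that keep a `sym_visibility` attribute are unchanged). The new surface syntax, with illustrative symbol names:

    wasmssa.memory exported @mem !wasmssa<limit[0: 65536]>
    wasmssa.table @tab !wasmssa<tabletype !wasmssa.funcref [348:]>
    wasmssa.import_global "g" from "env" as @g0 mutable : i32
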
diff --git a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
index 8d78280..7687e5f 100644
--- a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-wasmssa.func nested @extend_low_64() -> i32 {
+wasmssa.func @extend_low_64() -> i32 {
%0 = wasmssa.const 10 : i32
// expected-error@+1 {{extend op can only take 8, 16 or 32 bits. Got 64}}
%1 = wasmssa.extend 64 low bits from %0: i32
@@ -10,7 +10,7 @@ wasmssa.func nested @extend_low_64() -> i32 {
// -----
-wasmssa.func nested @extend_too_much() -> i32 {
+wasmssa.func @extend_too_much() -> i32 {
%0 = wasmssa.const 10 : i32
// expected-error@+1 {{trying to extend the 32 low bits from a 'i32' value is illegal}}
%1 = wasmssa.extend 32 low bits from %0: i32
diff --git a/mlir/test/Dialect/WasmSSA/global-invalid.mlir b/mlir/test/Dialect/WasmSSA/global-invalid.mlir
index b9cafd8..c5bc606 100644
--- a/mlir/test/Dialect/WasmSSA/global-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/global-invalid.mlir
@@ -13,7 +13,7 @@ module {
// -----
module {
- wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32
+ wasmssa.import_global "glob" from "my_module" as @global_0 mutable : i32
wasmssa.global @global_1 i32 : {
// expected-error@+1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}}
%0 = wasmssa.global_get @global_0 : i32
@@ -30,3 +30,13 @@ module {
wasmssa.return %0 : i32
}
}
+
+// -----
+
+module {
+ // expected-error@+1 {{expecting either `exported` or symbol name. got exproted}}
+ wasmssa.global exproted @global_1 i32 : {
+ %0 = wasmssa.const 17 : i32
+ wasmssa.return %0 : i32
+ }
+}
diff --git a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
index 35c590b..eaad80e 100644
--- a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
+wasmssa.func @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
%0 = wasmssa.const 3 : i64
// expected-error@+1 {{input type and result type of local.set do not match}}
wasmssa.local_set %arg0 : ref to i32 to %0 : i64
@@ -9,7 +9,7 @@ wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
// -----
-wasmssa.func nested @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 {
%0 = wasmssa.const 3 : i64
// expected-error@+1 {{input type and output type of local.tee do not match}}
%1 = wasmssa.local_tee %arg0 : ref to i32 to %0 : i64
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 228ef69d..ebbe3ce 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
// -----
func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result shape must not exceed mem_desc shape}}
+ // expected-error@+1 {{data shape must not exceed mem_desc shape}}
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
return
}
@@ -871,6 +871,14 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
}
// -----
+func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+ return
+}
+
+
+// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
@@ -892,30 +900,16 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
}
// -----
-func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result shape must not exceed source shape}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
- // expected-error@+1 {{result must inherit the source strides}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
// -----
-func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result rank must not exceed source rank}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index bb37902..0a10f68 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -825,53 +825,73 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
-// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
gpu.return
}
-// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
gpu.return
}
+// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>)
+gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ gpu.return
+}
-// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
gpu.return
}
-// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
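
The rewritten XeGPU tests drop `xegpu.mem_desc_subview` entirely; SIMT-style accesses now go straight through `load_matrix`/`store_matrix` with 1D vector payloads, and the `subgroup_block_io` unit attribute (legal only with 1D vector data, per the invalid tests) selects blocked subgroup I/O. The two SIMT forms side by side, operands illustrative:

    %v = xegpu.load_matrix %md[8, 16] <{subgroup_block_io}>
        : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
    xegpu.store_matrix %v, %md[8, 16] <{subgroup_block_io}>
        : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
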
diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
index cc3d799..00d09ba 100644
--- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
@@ -393,6 +393,12 @@ declare void @alwaysinline_attribute() alwaysinline
// -----
+; CHECK-LABEL: @inlinehint_attribute
+; CHECK-SAME: attributes {inline_hint}
+declare void @inlinehint_attribute() inlinehint
+
+// -----
+
; CHECK-LABEL: @optnone_attribute
; CHECK-SAME: attributes {no_inline, optimize_none}
declare void @optnone_attribute() noinline optnone
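
The import maps LLVM's `inlinehint` function attribute onto a unit attribute on `llvm.func`, mirroring the export direction tested in llvmir.mlir below. The imported form, as a sketch:

    llvm.func @inlinehint_attribute() attributes {inline_hint}
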
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 69814f2..cc243c8 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2555,6 +2555,17 @@ llvm.func @always_inline() attributes { always_inline } {
// -----
+// CHECK-LABEL: @inline_hint
+// CHECK-SAME: #[[ATTRS:[0-9]+]]
+llvm.func @inline_hint() attributes { inline_hint } {
+ llvm.return
+}
+
+// CHECK: #[[ATTRS]]
+// CHECK-SAME: inlinehint
+
+// -----
+
// CHECK-LABEL: @optimize_none
// CHECK-SAME: #[[ATTRS:[0-9]+]]
llvm.func @optimize_none() attributes { no_inline, optimize_none } {
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index b2fe2f5..6cccfe4 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -568,6 +568,18 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
llvm.return
}
+// -----
+
+// Test that invalid row/col layouts for matrices A and B are rejected
+llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
+ // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
// -----
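
For every shape other than m8n8k4 with f16 operands, the verifier pins the operand layouts to row-major A and column-major B; the rejected op above should be accepted with that layout pair. A sketch with only the layouts changed (all other attributes as in the test):

    %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
      {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
       multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
       intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
       shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
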
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index fdd2c91..6536fac 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -276,6 +276,20 @@ llvm.func @rocdl.s.wait.expcnt() {
llvm.return
}
+llvm.func @rocdl.s.wait.asynccnt() {
+ // CHECK-LABEL: rocdl.s.wait.asynccnt
+ // CHECK-NEXT: call void @llvm.amdgcn.s.wait.asynccnt(i16 0)
+ rocdl.s.wait.asynccnt 0
+ llvm.return
+}
+
+llvm.func @rocdl.s.wait.tensorcnt() {
+ // CHECK-LABEL: rocdl.s.wait.tensorcnt
+ // CHECK-NEXT: call void @llvm.amdgcn.s.wait.tensorcnt(i16 0)
+ rocdl.s.wait.tensorcnt 0
+ llvm.return
+}
+
llvm.func @rocdl.setprio() {
// CHECK: call void @llvm.amdgcn.s.setprio(i16 0)
rocdl.s.setprio 0
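
Both new wait ops take an immediate count and lower one-to-one to the matching `llvm.amdgcn.s.wait.*cnt` intrinsic with an `i16` argument, following the pattern of the surrounding `s.wait` ops; presumably, in the usual wait-counter style, execution stalls until the outstanding async/tensor operations drop to at most the given count. The ops in isolation:

    rocdl.s.wait.asynccnt 0
    rocdl.s.wait.tensorcnt 0
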
diff --git a/mlir/test/Target/Wasm/abs.mlir b/mlir/test/Target/Wasm/abs.mlir
index 9c45ba7..fe3602a 100644
--- a/mlir/test/Target/Wasm/abs.mlir
+++ b/mlir/test/Target/Wasm/abs.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @abs_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @abs_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @abs_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @abs_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/add_div.mlir b/mlir/test/Target/Wasm/add_div.mlir
new file mode 100644
index 0000000..8a87c60
--- /dev/null
+++ b/mlir/test/Target/Wasm/add_div.mlir
@@ -0,0 +1,40 @@
+// RUN: yaml2obj %S/inputs/add_div.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+ (module $test.wasm
+ (type (;0;) (func (param i32) (result i32)))
+ (type (;1;) (func (param i32 i32) (result i32)))
+ (import "env" "twoTimes" (func $twoTimes (type 0)))
+ (func $add (type 1) (param i32 i32) (result i32)
+ local.get 0
+ call $twoTimes
+ local.get 1
+ call $twoTimes
+ i32.add
+ i32.const 2
+ i32.div_s)
+ (memory (;0;) 2)
+ (global $__stack_pointer (mut i32) (i32.const 66560))
+ (export "memory" (memory 0))
+ (export "add" (func $add)))
+*/
+
+// CHECK-LABEL: wasmssa.import_func "twoTimes" from "env" as @func_0 {type = (i32) -> i32}
+
+// CHECK-LABEL: wasmssa.func exported @add(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>,
+// CHECK-SAME: %[[ARG1:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.call @func_0(%[[VAL_0]]) : (i32) -> i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[ARG1]] : ref to i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.call @func_0(%[[VAL_2]]) : (i32) -> i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_3]] : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.div_si %[[VAL_4]] %[[VAL_5]] : i32
+// CHECK: wasmssa.return %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: wasmssa.memory exported @memory !wasmssa<limit[2:]>
+
+// CHECK-LABEL: wasmssa.global @global_0 i32 mutable : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 66560 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/and.mlir b/mlir/test/Target/Wasm/and.mlir
index 4c0fea0..323d41a 100644
--- a/mlir/test/Target/Wasm/and.mlir
+++ b/mlir/test/Target/Wasm/and.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @and_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @and_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.and %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @and_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @and_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.and %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/block.mlir b/mlir/test/Target/Wasm/block.mlir
new file mode 100644
index 0000000..c85fc1e
--- /dev/null
+++ b/mlir/test/Target/Wasm/block.mlir
@@ -0,0 +1,16 @@
+// RUN: yaml2obj %S/inputs/block.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(func(export "i_am_a_block")
+(block $i_am_a_block)
+)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i_am_a_block() {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/block_complete_type.mlir b/mlir/test/Target/Wasm/block_complete_type.mlir
new file mode 100644
index 0000000..67df198
--- /dev/null
+++ b/mlir/test/Target/Wasm/block_complete_type.mlir
@@ -0,0 +1,24 @@
+// RUN: yaml2obj %S/inputs/block_complete_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type (;0;) (func (param i32) (result i32)))
+ (type (;1;) (func (result i32)))
+ (func (;0;) (type 1) (result i32)
+ i32.const 14
+ block (param i32) (result i32) ;; label = @1
+ i32.const 1
+ i32.add
+ end))
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 14 : i32
+// CHECK: wasmssa.block(%[[VAL_0]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32):
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.block_return %[[VAL_3]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_4:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_4]] : i32
diff --git a/mlir/test/Target/Wasm/block_value_type.mlir b/mlir/test/Target/Wasm/block_value_type.mlir
new file mode 100644
index 0000000..fa30f08
--- /dev/null
+++ b/mlir/test/Target/Wasm/block_value_type.mlir
@@ -0,0 +1,19 @@
+// RUN: yaml2obj %S/inputs/block_value_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type (;0;) (func (result i32)))
+ (func (;0;) (type 0) (result i32)
+ block (result i32) ;; label = @1
+ i32.const 17
+ end))
+*/
+
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: wasmssa.block : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i32
+// CHECK: wasmssa.block_return %[[VAL_0]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_1:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/branch_if.mlir b/mlir/test/Target/Wasm/branch_if.mlir
new file mode 100644
index 0000000..c91ff37
--- /dev/null
+++ b/mlir/test/Target/Wasm/branch_if.mlir
@@ -0,0 +1,29 @@
+// RUN: yaml2obj %S/inputs/branch_if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type $produce_i32 (func (result i32)))
+ (func (type $produce_i32)
+ (block $my_block (type $produce_i32)
+ i32.const 1
+ i32.const 2
+ br_if $my_block
+ i32.const 1
+ i32.add
+ )
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: wasmssa.block : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
+// CHECK: wasmssa.branch_if %[[VAL_1]] to level 0 with args(%[[VAL_0]] : i32) else ^bb1
+// CHECK: ^bb1:
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_2]] : i32
+// CHECK: wasmssa.block_return %[[VAL_3]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_4:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_4]] : i32
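
As the checks show, `br_if` on the enclosing block imports as `wasmssa.branch_if`, where `level 0` names the innermost target block: when the condition holds, the listed args become that block's results; otherwise control falls through to the `else` successor. The core of the lowering, as a sketch with illustrative names:

    // Conditional early exit from the innermost block, forwarding %v as its result.
    wasmssa.branch_if %cond to level 0 with args(%v : i32) else ^fallthrough
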
diff --git a/mlir/test/Target/Wasm/call.mlir b/mlir/test/Target/Wasm/call.mlir
new file mode 100644
index 0000000..c0169aa
--- /dev/null
+++ b/mlir/test/Target/Wasm/call.mlir
@@ -0,0 +1,17 @@
+// RUN: yaml2obj %S/inputs/call.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(func $forty_two (result i32)
+i32.const 42)
+(func(export "forty_two")(result i32)
+call $forty_two))
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @forty_two() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.call @func_0 : () -> i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/clz.mlir b/mlir/test/Target/Wasm/clz.mlir
index 3e6641d..858c09d 100644
--- a/mlir/test/Target/Wasm/clz.mlir
+++ b/mlir/test/Target/Wasm/clz.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @clz_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @clz_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @clz_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @clz_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/comparison_ops.mlir b/mlir/test/Target/Wasm/comparison_ops.mlir
new file mode 100644
index 0000000..91e3a6a
--- /dev/null
+++ b/mlir/test/Target/Wasm/comparison_ops.mlir
@@ -0,0 +1,269 @@
+// RUN: yaml2obj %S/inputs/comparison_ops.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $lt_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.lt_s
+ )
+ (func $le_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.le_s
+ )
+ (func $lt_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.lt_u
+ )
+ (func $le_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.le_u
+ )
+ (func $gt_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.gt_s
+ )
+ (func $gt_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.gt_u
+ )
+ (func $ge_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ge_s
+ )
+ (func $ge_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ge_u
+ )
+ (func $lt_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.lt_s
+ )
+ (func $le_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.le_s
+ )
+ (func $lt_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.lt_u
+ )
+ (func $le_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.le_u
+ )
+ (func $gt_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.gt_s
+ )
+ (func $gt_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.gt_u
+ )
+ (func $ge_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.ge_s
+ )
+ (func $ge_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.ge_u
+ )
+ (func $lt_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.lt
+ )
+ (func $le_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.le
+ )
+ (func $gt_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.gt
+ )
+ (func $ge_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.ge
+ )
+ (func $lt_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.lt
+ )
+ (func $le_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.le
+ )
+ (func $gt_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.gt
+ )
+ (func $ge_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.ge
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_5() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_6() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_7() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_8() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_9() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_10() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_11() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_12() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_13() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_14() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_15() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_16() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_17() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_18() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_19() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_20() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_21() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_22() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_23() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
diff --git a/mlir/test/Target/Wasm/const.mlir b/mlir/test/Target/Wasm/const.mlir
index aa9e76f..adb792a 100644
--- a/mlir/test/Target/Wasm/const.mlir
+++ b/mlir/test/Target/Wasm/const.mlir
@@ -16,22 +16,22 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i64
// CHECK: wasmssa.return %[[VAL_0]] : i64
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_2() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 4.000000e+00 : f32
// CHECK: wasmssa.return %[[VAL_0]] : f32
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK-LABEL: wasmssa.func @func_3() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 9.000000e+00 : f64
// CHECK: wasmssa.return %[[VAL_0]] : f64
// CHECK: }
diff --git a/mlir/test/Target/Wasm/convert.mlir b/mlir/test/Target/Wasm/convert.mlir
new file mode 100644
index 0000000..ddc29a7
--- /dev/null
+++ b/mlir/test/Target/Wasm/convert.mlir
@@ -0,0 +1,85 @@
+// RUN: yaml2obj %S/inputs/convert.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "convert_i32_u_to_f32") (result f32)
+ i32.const 10
+ f32.convert_i32_u
+ )
+
+ (func (export "convert_i32_s_to_f32") (result f32)
+ i32.const 42
+ f32.convert_i32_s
+ )
+
+ (func (export "convert_i64_u_to_f32") (result f32)
+ i64.const 17
+ f32.convert_i64_u
+ )
+
+ (func (export "convert_i64s_to_f32") (result f32)
+ i64.const 10
+ f32.convert_i64_s
+ )
+
+ (func (export "convert_i32_u_to_f64") (result f64)
+ i32.const 10
+ f64.convert_i32_u
+ )
+
+ (func (export "convert_i32_s_to_f64") (result f64)
+ i32.const 42
+ f64.convert_i32_s
+ )
+
+ (func (export "convert_i64_u_to_f64") (result f64)
+ i64.const 17
+ f64.convert_i64_u
+ )
+
+ (func (export "convert_i64s_to_f64") (result f64)
+ i64.const 10
+ f64.convert_i64_s
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/copysign.mlir b/mlir/test/Target/Wasm/copysign.mlir
index 33d7a56..90c5b11 100644
--- a/mlir/test/Target/Wasm/copysign.mlir
+++ b/mlir/test/Target/Wasm/copysign.mlir
@@ -16,14 +16,14 @@
)
*/
-// CHECK-LABEL: wasmssa.func @copysign_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @copysign_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @copysign_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @copysign_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/ctz.mlir b/mlir/test/Target/Wasm/ctz.mlir
index 6c0806f..9e7cc5e 100644
--- a/mlir/test/Target/Wasm/ctz.mlir
+++ b/mlir/test/Target/Wasm/ctz.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @ctz_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @ctz_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @ctz_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @ctz_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/demote.mlir b/mlir/test/Target/Wasm/demote.mlir
new file mode 100644
index 0000000..3d2bc05
--- /dev/null
+++ b/mlir/test/Target/Wasm/demote.mlir
@@ -0,0 +1,15 @@
+// RUN: yaml2obj %S/inputs/demote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func $main (result f32)
+ f64.const 2.24
+ f32.demote_f64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 2.240000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.demote %[[VAL_0]] : f64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
diff --git a/mlir/test/Target/Wasm/div.mlir b/mlir/test/Target/Wasm/div.mlir
index c91f780..4967d96 100644
--- a/mlir/test/Target/Wasm/div.mlir
+++ b/mlir/test/Target/Wasm/div.mlir
@@ -66,61 +66,61 @@
)
*/
-// CHECK-LABEL: wasmssa.func @div_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_u_i32_zero() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i32_zero() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_s_i32_zero() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i32_zero() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_u_i64_zero() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i64_zero() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_s_i64_zero() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i64_zero() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @div_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @div_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @div_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/double_nested_loop.mlir b/mlir/test/Target/Wasm/double_nested_loop.mlir
new file mode 100644
index 0000000..8b3e499
--- /dev/null
+++ b/mlir/test/Target/Wasm/double_nested_loop.mlir
@@ -0,0 +1,63 @@
+// RUN: yaml2obj %S/inputs/double_nested_loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/*
+(module
+ (func
+    ;; create two local variables; locals start zero-initialized
+ (local $i i32)
+ (local $j i32)
+
+ (loop $my_loop
+
+ ;; add one to $i
+ local.get $i
+ i32.const 1
+ i32.add
+ local.set $i
+ (loop $my_second_loop (result i32)
+ i32.const 1
+ local.get $j
+ i32.const 12
+ i32.add
+ local.tee $j
+ local.get $i
+ i32.gt_s
+ br_if $my_second_loop
+ )
+ i32.const 10
+ i32.lt_s
+ br_if $my_loop
+ )
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_4]] : i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32
+// CHECK: %[[VAL_7:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_8:.*]] = wasmssa.add %[[VAL_6]] %[[VAL_7]] : i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.local_tee %[[VAL_1]] : ref to i32 to %[[VAL_8]] : i32
+// CHECK: %[[VAL_10:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_11:.*]] = wasmssa.gt_si %[[VAL_9]] %[[VAL_10]] : i32 -> i32
+// CHECK: wasmssa.branch_if %[[VAL_11]] to level 0 else ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return %[[VAL_5]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_12:.*]]: i32):
+// CHECK: %[[VAL_13:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_14:.*]] = wasmssa.lt_si %[[VAL_12]] %[[VAL_13]] : i32 -> i32
+// CHECK: wasmssa.branch_if %[[VAL_14]] to level 0 else ^bb2
+// CHECK: ^bb2:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir
new file mode 100644
index 0000000..5c98f1a
--- /dev/null
+++ b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir
@@ -0,0 +1,53 @@
+// RUN: yaml2obj %S/inputs/empty_blocks_list_and_stack.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func (param $num i32)
+ (block $b1
+ (block $b2
+ (block $b3
+ )
+ )
+ )
+ )
+
+ (func (param $num i32)
+ (block $b1)
+ (block $b2)
+ (block $b3)
+ )
+)
+
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
+
+// CHECK-LABEL: wasmssa.func @func_1(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb2
+// CHECK: ^bb2:
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb3
+// CHECK: ^bb3:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/eq.mlir b/mlir/test/Target/Wasm/eq.mlir
new file mode 100644
index 0000000..ba3ae2f
--- /dev/null
+++ b/mlir/test/Target/Wasm/eq.mlir
@@ -0,0 +1,56 @@
+// RUN: yaml2obj %S/inputs/eq.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $eq_i32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.eq
+ )
+
+ (func $eq_i64 (result i32)
+ i64.const 20
+ i64.const 5
+ i64.eq
+ )
+
+ (func $eq_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.eq
+ )
+
+ (func $eq_f64 (result i32)
+ f64.const 17
+ f64.const 0
+ f64.eq
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/eqz.mlir b/mlir/test/Target/Wasm/eqz.mlir
new file mode 100644
index 0000000..55cf94a
--- /dev/null
+++ b/mlir/test/Target/Wasm/eqz.mlir
@@ -0,0 +1,21 @@
+// RUN: yaml2obj %S/inputs/eqz.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func (export "eqz_i32") (result i32)
+ i32.const 13
+ i32.eqz)
+
+ (func (export "eqz_i64") (result i32)
+ i64.const 13
+ i64.eqz)
+)
+*/
+// CHECK-LABEL: wasmssa.func exported @eqz_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @eqz_i64() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/extend.mlir b/mlir/test/Target/Wasm/extend.mlir
new file mode 100644
index 0000000..5d4446a
--- /dev/null
+++ b/mlir/test/Target/Wasm/extend.mlir
@@ -0,0 +1,69 @@
+// RUN: yaml2obj %S/inputs/extend.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func $i32_s (result i64)
+ i32.const 10
+ i64.extend_i32_s
+ )
+ (func $i32_u (result i64)
+ i32.const 10
+ i64.extend_i32_u
+ )
+ (func $extend8_32 (result i32)
+ i32.const 10
+ i32.extend8_s
+ )
+ (func $extend16_32 (result i32)
+ i32.const 10
+ i32.extend16_s
+ )
+ (func $extend8_64 (result i64)
+ i64.const 10
+ i64.extend8_s
+ )
+ (func $extend16_64 (result i64)
+ i64.const 10
+ i64.extend16_s
+ )
+ (func $extend32_64 (result i64)
+ i64.const 10
+ i64.extend32_s
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_s %[[VAL_0]] to i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_u %[[VAL_0]] to i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_5() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_6() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 32 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/global.mlir b/mlir/test/Target/Wasm/global.mlir
index e72fe69..1e4fe44 100644
--- a/mlir/test/Target/Wasm/global.mlir
+++ b/mlir/test/Target/Wasm/global.mlir
@@ -29,9 +29,9 @@ i32.add
)
*/
-// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.global_get @global_0 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.global_get @global_1 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_1]] : i32
@@ -41,26 +41,26 @@ i32.add
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_5]] : i32
// CHECK: wasmssa.return %[[VAL_6]] : i32
-// CHECK-LABEL: wasmssa.global @global_1 i32 nested : {
+// CHECK-LABEL: wasmssa.global @global_1 i32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_2 i32 mutable nested : {
+// CHECK-LABEL: wasmssa.global @global_2 i32 mutable : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_3 i32 mutable nested : {
+// CHECK-LABEL: wasmssa.global @global_3 i32 mutable : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_4 i64 nested : {
+// CHECK-LABEL: wasmssa.global @global_4 i64 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 11 : i64
// CHECK: wasmssa.return %[[VAL_0]] : i64
-// CHECK-LABEL: wasmssa.global @global_5 f32 nested : {
+// CHECK-LABEL: wasmssa.global @global_5 f32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.200000e+01 : f32
// CHECK: wasmssa.return %[[VAL_0]] : f32
-// CHECK-LABEL: wasmssa.global @global_6 f64 nested : {
+// CHECK-LABEL: wasmssa.global @global_6 f64 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.300000e+01 : f64
// CHECK: wasmssa.return %[[VAL_0]] : f64
diff --git a/mlir/test/Target/Wasm/if.mlir b/mlir/test/Target/Wasm/if.mlir
new file mode 100644
index 0000000..2d7bfbe
--- /dev/null
+++ b/mlir/test/Target/Wasm/if.mlir
@@ -0,0 +1,112 @@
+// RUN: yaml2obj %S/inputs/if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(type $intMapper (func (param $input i32) (result i32)))
+(func $if_else (type $intMapper)
+ local.get 0
+ i32.const 1
+ i32.and
+ if $isOdd (result i32)
+ local.get 0
+ i32.const 3
+ i32.mul
+ i32.const 1
+ i32.add
+ else
+ local.get 0
+ i32.const 1
+ i32.shr_u
+ end
+)
+
+(func $if_only (type $intMapper)
+ local.get 0
+ local.get 0
+ i32.const 1
+ i32.and
+ if $isOdd (type $intMapper)
+ i32.const 1
+ i32.add
+ end
+)
+
+(func $if_if (type $intMapper)
+ local.get 0
+ i32.ctz
+ if $isEven (result i32)
+ i32.const 2
+ local.get 0
+ i32.const 1
+ i32.shr_u
+ i32.ctz
+ if $isMultipleOfFour (type $intMapper)
+ i32.const 2
+ i32.add
+ end
+ else
+ i32.const 1
+ end
+)
+)
+*/
+// CHECK-LABEL: wasmssa.func @func_0(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.and %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.if %[[VAL_2]] : {
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 3 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.mul %[[VAL_3]] %[[VAL_4]] : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_7:.*]] = wasmssa.add %[[VAL_5]] %[[VAL_6]] : i32
+// CHECK: wasmssa.block_return %[[VAL_7]] : i32
+// CHECK: } "else "{
+// CHECK: %[[VAL_8:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_10:.*]] = wasmssa.shr_u %[[VAL_8]] by %[[VAL_9]] bits : i32
+// CHECK: wasmssa.block_return %[[VAL_10]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_11:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_11]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.and %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.if %[[VAL_3]](%[[VAL_0]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_4:.*]]: i32):
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_4]] %[[VAL_5]] : i32
+// CHECK: wasmssa.block_return %[[VAL_6]] : i32
+// CHECK: } > ^bb1
+// CHECK: ^bb1(%[[VAL_7:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_7]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32
+// CHECK: wasmssa.if %[[VAL_1]] : {
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.shr_u %[[VAL_3]] by %[[VAL_4]] bits : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.ctz %[[VAL_5]] : i32
+// CHECK: wasmssa.if %[[VAL_6]](%[[VAL_2]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_7:.*]]: i32):
+// CHECK: %[[VAL_8:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.add %[[VAL_7]] %[[VAL_8]] : i32
+// CHECK: wasmssa.block_return %[[VAL_9]] : i32
+// CHECK: } > ^bb1
+// CHECK: ^bb1(%[[VAL_10:.*]]: i32):
+// CHECK: wasmssa.block_return %[[VAL_10]] : i32
+// CHECK: } "else "{
+// CHECK: %[[VAL_11:.*]] = wasmssa.const 1 : i32
+// CHECK: wasmssa.block_return %[[VAL_11]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_12:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_12]] : i32
diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir
index 541dcf3..dcdfa52 100644
--- a/mlir/test/Target/Wasm/import.mlir
+++ b/mlir/test/Target/Wasm/import.mlir
@@ -11,9 +11,9 @@
)
*/
-// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
-// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
-// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
-// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
-// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
-// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {type = (i32) -> ()}
+// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {type = (i32) -> ()}
+// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>}
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
diff --git a/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm
new file mode 100644
index 0000000..865c315
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm
@@ -0,0 +1,50 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes:
+ - I32
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: IMPORT
+ Imports:
+ - Module: env
+ Field: twoTimes
+ Kind: FUNCTION
+ SigIndex: 0
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: MEMORY
+ Memories:
+ - Minimum: 0x2
+ - Type: GLOBAL
+ Globals:
+ - Index: 0
+ Type: I32
+ Mutable: true
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 66560
+ - Type: EXPORT
+ Exports:
+ - Name: memory
+ Kind: MEMORY
+ Index: 0
+ - Name: add
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 1
+ Locals: []
+ Body: 20001000200110006A41026D0B
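+      # Informal decode of the body bytes above (hand-checked against the
+      # Wasm opcode table, for reviewer reference):
+      #   20 00  local.get 0    10 00  call 0 (imported env.twoTimes)
+      #   20 01  local.get 1    10 00  call 0
+      #   6A     i32.add        41 02  i32.const 2
+      #   6D     i32.div_s      0B     end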
+...
diff --git a/mlir/test/Target/Wasm/inputs/block.yaml.wasm b/mlir/test/Target/Wasm/inputs/block.yaml.wasm
new file mode 100644
index 0000000..dd5118a
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block.yaml.wasm
@@ -0,0 +1,22 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i_am_a_block
+ Kind: FUNCTION
+ Index: 0
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 02400B0B
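+      # Informal decode of the body bytes above: 02 40 opens a block with
+      # the empty blocktype (0x40); the two 0B bytes end the block and the
+      # function.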
+...
diff --git a/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm
new file mode 100644
index 0000000..7a125bf
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm
@@ -0,0 +1,23 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410E020041016A0B0B
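+      # Informal decode of the body bytes above: 41 0E is i32.const 14;
+      # 02 00 opens a block whose blocktype is type index 0 ((i32) -> i32);
+      # 41 01 6A is i32.const 1 then i32.add; 0B 0B end the block and the
+      # function.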
+...
diff --git a/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm
new file mode 100644
index 0000000..4ba291d
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 027F41110B0B
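+      # Informal decode of the body bytes above: 02 7F opens a block with
+      # the shorthand i32-result blocktype (0x7F); 41 11 is i32.const 17;
+      # 0B 0B end the block and the function.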
+...
diff --git a/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm
new file mode 100644
index 0000000..40536ed
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 027F410141020D0041016A0B0B
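+      # Informal decode of the body bytes above: 02 7F opens a block
+      # returning i32; 41 01 and 41 02 push constants 1 and 2; 0D 00 is
+      # br_if 0 (consumes the 2); 41 01 6A adds 1 on the fallthrough path;
+      # 0B 0B end the block and the function.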
+...
diff --git a/mlir/test/Target/Wasm/inputs/call.yaml.wasm b/mlir/test/Target/Wasm/inputs/call.yaml.wasm
new file mode 100644
index 0000000..535a623
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/call.yaml.wasm
@@ -0,0 +1,26 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: forty_two
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 412A0B
+ - Index: 1
+ Locals: []
+ Body: 10000B
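+      # Informal decode, for reference: func 0's body 41 2A 0B is
+      # i32.const 42 then end; func 1's body 10 00 0B is call 0 then end.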
+...
diff --git a/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm
new file mode 100644
index 0000000..cde9ee1
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm
@@ -0,0 +1,88 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132480B
+ - Index: 1
+ Locals: []
+ Body: 410C41324C0B
+ - Index: 2
+ Locals: []
+ Body: 410C4132490B
+ - Index: 3
+ Locals: []
+ Body: 410C41324D0B
+ - Index: 4
+ Locals: []
+ Body: 410C41324A0B
+ - Index: 5
+ Locals: []
+ Body: 410C41324B0B
+ - Index: 6
+ Locals: []
+ Body: 410C41324E0B
+ - Index: 7
+ Locals: []
+ Body: 410C41324F0B
+ - Index: 8
+ Locals: []
+ Body: 420C4232530B
+ - Index: 9
+ Locals: []
+ Body: 420C4232570B
+ - Index: 10
+ Locals: []
+ Body: 420C4232540B
+ - Index: 11
+ Locals: []
+ Body: 420C4232580B
+ - Index: 12
+ Locals: []
+ Body: 420C4232550B
+ - Index: 13
+ Locals: []
+ Body: 420C4232560B
+ - Index: 14
+ Locals: []
+ Body: 420C4232590B
+ - Index: 15
+ Locals: []
+ Body: 420C42325A0B
+ - Index: 16
+ Locals: []
+ Body: 430000A04043000060415D0B
+ - Index: 17
+ Locals: []
+ Body: 430000A04043000060415F0B
+ - Index: 18
+ Locals: []
+ Body: 430000A04043000060415E0B
+ - Index: 19
+ Locals: []
+ Body: 430000A0404300006041600B
+ - Index: 20
+ Locals: []
+ Body: 440000000000001440440000000000002C40630B
+ - Index: 21
+ Locals: []
+ Body: 440000000000001440440000000000002C40650B
+ - Index: 22
+ Locals: []
+ Body: 440000000000001440440000000000002C40640B
+ - Index: 23
+ Locals: []
+ Body: 440000000000001440440000000000002C40660B
+...
diff --git a/mlir/test/Target/Wasm/inputs/convert.yaml.wasm b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm
new file mode 100644
index 0000000..c346a75
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm
@@ -0,0 +1,69 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0, 1, 1, 1, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: convert_i32_u_to_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: convert_i32_s_to_f32
+ Kind: FUNCTION
+ Index: 1
+ - Name: convert_i64_u_to_f32
+ Kind: FUNCTION
+ Index: 2
+ - Name: convert_i64s_to_f32
+ Kind: FUNCTION
+ Index: 3
+ - Name: convert_i32_u_to_f64
+ Kind: FUNCTION
+ Index: 4
+ - Name: convert_i32_s_to_f64
+ Kind: FUNCTION
+ Index: 5
+ - Name: convert_i64_u_to_f64
+ Kind: FUNCTION
+ Index: 6
+ - Name: convert_i64s_to_f64
+ Kind: FUNCTION
+ Index: 7
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410AB30B
+ - Index: 1
+ Locals: []
+ Body: 412AB20B
+ - Index: 2
+ Locals: []
+ Body: 4211B50B
+ - Index: 3
+ Locals: []
+ Body: 420AB40B
+ - Index: 4
+ Locals: []
+ Body: 410AB80B
+ - Index: 5
+ Locals: []
+ Body: 412AB70B
+ - Index: 6
+ Locals: []
+ Body: 4211BA0B
+ - Index: 7
+ Locals: []
+ Body: 420AB90B
+...
diff --git a/mlir/test/Target/Wasm/inputs/demote.yaml.wasm b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm
new file mode 100644
index 0000000..3997045
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 44EC51B81E85EB0140B60B
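+      # Informal decode of the body bytes above: 44 plus eight little-endian
+      # bytes is f64.const 2.24; B6 is f32.demote_f64; 0B is end.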
+...
diff --git a/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm
new file mode 100644
index 0000000..41a2944
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm
@@ -0,0 +1,19 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals:
+ - Type: I32
+ Count: 2
+ Body: 0340200041016A2100037F41012001410C6A220120004A0D000B410A480D000B0B
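+      # Informal decode of the body bytes above, for reviewer reference:
+      #   03 40            loop (empty blocktype)
+      #   20 00 41 01 6A   local.get 0; i32.const 1; i32.add
+      #   21 00            local.set 0
+      #   03 7F            loop (result i32)
+      #   41 01            i32.const 1 (the inner loop's result)
+      #   20 01 41 0C 6A   local.get 1; i32.const 12; i32.add
+      #   22 01 20 00 4A   local.tee 1; local.get 0; i32.gt_s
+      #   0D 00 0B         br_if 0; end (inner loop)
+      #   41 0A 48 0D 00   i32.const 10; i32.lt_s; br_if 0
+      #   0B 0B            end (outer loop); end (function)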
+...
diff --git a/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm
new file mode 100644
index 0000000..3171409
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm
@@ -0,0 +1,21 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 0240024002400B0B0B0B
+ - Index: 1
+ Locals: []
+ Body: 02400B02400B02400B0B
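+      # Informal decode, for reference: func 0 is three nested empty blocks
+      # (02 40 three times, then 0B three times) plus the function's closing
+      # 0B; func 1 is three sequential empty blocks (02 40 0B three times)
+      # plus the closing 0B.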
+...
diff --git a/mlir/test/Target/Wasm/inputs/eq.yaml.wasm b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm
new file mode 100644
index 0000000..1998369
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm
@@ -0,0 +1,27 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132460B
+ - Index: 1
+ Locals: []
+ Body: 42144205510B
+ - Index: 2
+ Locals: []
+ Body: 430000A04043000060415B0B
+ - Index: 3
+ Locals: []
+ Body: 440000000000003140440000000000000000610B
+...
diff --git a/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm
new file mode 100644
index 0000000..894ac50
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm
@@ -0,0 +1,29 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: eqz_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: eqz_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410D450B
+ - Index: 1
+ Locals: []
+ Body: 420D500B
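+      # Informal decode, for reference: func 0's body 41 0D 45 0B is
+      # i32.const 13; i32.eqz; end, and func 1's body 42 0D 50 0B is
+      # i64.const 13; i64.eqz; end.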
+...
diff --git a/mlir/test/Target/Wasm/inputs/extend.yaml.wasm b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm
new file mode 100644
index 0000000..7e872ba
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm
@@ -0,0 +1,40 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 1, 1, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410AAC0B
+ - Index: 1
+ Locals: []
+ Body: 410AAD0B
+ - Index: 2
+ Locals: []
+ Body: 410AC00B
+ - Index: 3
+ Locals: []
+ Body: 410AC10B
+ - Index: 4
+ Locals: []
+ Body: 420AC20B
+ - Index: 5
+ Locals: []
+ Body: 420AC30B
+ - Index: 6
+ Locals: []
+ Body: 420AC40B
+...
diff --git a/mlir/test/Target/Wasm/inputs/if.yaml.wasm b/mlir/test/Target/Wasm/inputs/if.yaml.wasm
new file mode 100644
index 0000000..ccc38f6
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/if.yaml.wasm
@@ -0,0 +1,25 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 2000410171047F200041036C41016A0520004101760B0B
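+      # Informal decode of func 0's body, for reference: 20 00 41 01 71 is
+      # local.get 0; i32.const 1; i32.and; 04 7F opens an if with i32
+      # result; the then arm 20 00 41 03 6C 41 01 6A computes x*3 + 1; 05 is
+      # else; the else arm 20 00 41 01 76 computes x >> 1; 0B 0B end the if
+      # and the function.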
+ - Index: 1
+ Locals: []
+ Body: 20002000410171040041016A0B0B
+ - Index: 2
+ Locals: []
+ Body: 200068047F4102200041017668040041026A0B0541010B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm
new file mode 100644
index 0000000..9d33894
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm
@@ -0,0 +1,17 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 03400B0B
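+      # Informal decode of the body bytes above: 03 40 opens a loop with
+      # the empty blocktype; the two 0B bytes end the loop and the function.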
+...
diff --git a/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm
new file mode 100644
index 0000000..4b8cc54
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm
@@ -0,0 +1,20 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals:
+ - Type: I32
+ Count: 1
+ Body: 037F200041016A21002000410A480B0B
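+      # Informal decode of the body bytes above: 03 7F opens a loop
+      # returning i32; 20 00 41 01 6A 21 00 increments local 0;
+      # 20 00 41 0A 48 compares it against 10 with i32.lt_s; 0B 0B end the
+      # loop and the function.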
+...
diff --git a/mlir/test/Target/Wasm/inputs/ne.yaml.wasm b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm
new file mode 100644
index 0000000..0167519
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm
@@ -0,0 +1,27 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132470B
+ - Index: 1
+ Locals: []
+ Body: 42144205520B
+ - Index: 2
+ Locals: []
+ Body: 430000A04043000060415C0B
+ - Index: 3
+ Locals: []
+ Body: 440000000000003140440000000000000000620B
+...
diff --git a/mlir/test/Target/Wasm/inputs/promote.yaml.wasm b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm
new file mode 100644
index 0000000..d38603e
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002841BB0B
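+      # Informal decode of the body bytes above: 43 plus four little-endian
+      # bytes is f32.const 10.5; BB is f64.promote_f32; 0B is end.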
+...
diff --git a/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm
new file mode 100644
index 0000000..c01c1b1
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm
@@ -0,0 +1,53 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 2
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 3
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 2, 3 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i32.reinterpret_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: i64.reinterpret_f64
+ Kind: FUNCTION
+ Index: 1
+ - Name: f32.reinterpret_i32
+ Kind: FUNCTION
+ Index: 2
+ - Name: f64.reinterpret_i64
+ Kind: FUNCTION
+ Index: 3
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 43000080BFBC0B
+ - Index: 1
+ Locals: []
+ Body: 44000000000000F0BFBD0B
+ - Index: 2
+ Locals: []
+ Body: 417FBE0B
+ - Index: 3
+ Locals: []
+ Body: 427FBF0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm
new file mode 100644
index 0000000..c6e8bf6
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm
@@ -0,0 +1,37 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 0, 1, 0, 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4433333333333328C09B0B
+ - Index: 1
+ Locals: []
+ Body: 43A01ACF3F8D0B
+ - Index: 2
+ Locals: []
+ Body: 4433333333333328C09C0B
+ - Index: 3
+ Locals: []
+ Body: 43A01ACF3F8E0B
+ - Index: 4
+ Locals: []
+ Body: 4433333333333328C09D0B
+ - Index: 5
+ Locals: []
+ Body: 43A01ACF3F8F0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm
new file mode 100644
index 0000000..51c0b02
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm
@@ -0,0 +1,24 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I64
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i64_wrap
+ Kind: FUNCTION
+ Index: 0
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 2000A70B
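+      # Informal decode of the body bytes above: 20 00 is local.get 0;
+      # A7 is i32.wrap_i64; 0B is end.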
+...
diff --git a/mlir/test/Target/Wasm/invalid_block_type_index.yaml b/mlir/test/Target/Wasm/invalid_block_type_index.yaml
new file mode 100644
index 0000000..5b83e2e
--- /dev/null
+++ b/mlir/test/Target/Wasm/invalid_block_type_index.yaml
@@ -0,0 +1,28 @@
+
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+
+# CHECK: type index references nonexistent type (2)
+
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410E020241016A0B0B
+# -----------------------------^^ Invalid type ID
diff --git a/mlir/test/Target/Wasm/local.mlir b/mlir/test/Target/Wasm/local.mlir
index 32f5900..9844f9c 100644
--- a/mlir/test/Target/Wasm/local.mlir
+++ b/mlir/test/Target/Wasm/local.mlir
@@ -29,7 +29,7 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
@@ -40,7 +40,7 @@
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : f32
// CHECK: wasmssa.return %[[VAL_6]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
@@ -51,7 +51,7 @@
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : i32
// CHECK: wasmssa.return %[[VAL_6]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-LABEL: wasmssa.func @func_2(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
// CHECK: wasmssa.local_set %[[ARG0]] : ref to i32 to %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/loop.mlir b/mlir/test/Target/Wasm/loop.mlir
new file mode 100644
index 0000000..29ad502
--- /dev/null
+++ b/mlir/test/Target/Wasm/loop.mlir
@@ -0,0 +1,17 @@
+// RUN: yaml2obj %S/inputs/loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* IR generated from:
+(module
+ (func
+ (loop $my_loop
+ )
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() {
+// CHECK: wasmssa.loop : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/loop_with_inst.mlir b/mlir/test/Target/Wasm/loop_with_inst.mlir
new file mode 100644
index 0000000..311d007
--- /dev/null
+++ b/mlir/test/Target/Wasm/loop_with_inst.mlir
@@ -0,0 +1,33 @@
+// RUN: yaml2obj %S/inputs/loop_with_inst.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Code used to create this test:
+
+(module
+ (func (result i32)
+ (local $i i32)
+ (loop $my_loop (result i32)
+ local.get $i
+ i32.const 1
+ i32.add
+ local.set $i
+ local.get $i
+ i32.const 10
+ i32.lt_s
+ )
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_3]] : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.lt_si %[[VAL_4]] %[[VAL_5]] : i32 -> i32
+// CHECK: wasmssa.block_return %[[VAL_6]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_7:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_7]] : i32
diff --git a/mlir/test/Target/Wasm/max.mlir b/mlir/test/Target/Wasm/max.mlir
index 4ef2042..9160bde 100644
--- a/mlir/test/Target/Wasm/max.mlir
+++ b/mlir/test/Target/Wasm/max.mlir
@@ -16,14 +16,14 @@
)
*/
-// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
index 2ba5ab5..ea8f719 100644
--- a/mlir/test/Target/Wasm/memory_min_eq_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 0))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 0]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 0]>
diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir
index ebf6418..88782ec 100644
--- a/mlir/test/Target/Wasm/memory_min_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 65536))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 65536]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 65536]>
diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir
index 8d88786..c10c5cc 100644
--- a/mlir/test/Target/Wasm/memory_min_no_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir
@@ -4,4 +4,4 @@
(module (memory 1))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[1:]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[1:]>
diff --git a/mlir/test/Target/Wasm/min.mlir b/mlir/test/Target/Wasm/min.mlir
index 1058c7d..2372bcc 100644
--- a/mlir/test/Target/Wasm/min.mlir
+++ b/mlir/test/Target/Wasm/min.mlir
@@ -16,13 +16,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/ne.mlir b/mlir/test/Target/Wasm/ne.mlir
new file mode 100644
index 0000000..331df75
--- /dev/null
+++ b/mlir/test/Target/Wasm/ne.mlir
@@ -0,0 +1,52 @@
+// RUN: yaml2obj %S/inputs/ne.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $ne_i32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ne
+ )
+
+ (func $ne_i64 (result i32)
+ i64.const 20
+ i64.const 5
+ i64.ne
+ )
+
+ (func $ne_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.ne
+ )
+
+ (func $ne_f64 (result i32)
+ f64.const 17
+ f64.const 0
+ f64.ne
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
diff --git a/mlir/test/Target/Wasm/neg.mlir b/mlir/test/Target/Wasm/neg.mlir
index 5811ab50..dae8ee5 100644
--- a/mlir/test/Target/Wasm/neg.mlir
+++ b/mlir/test/Target/Wasm/neg.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @neg_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @neg_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @neg_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @neg_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/or.mlir b/mlir/test/Target/Wasm/or.mlir
index 521f2ba..be0b3d7 100644
--- a/mlir/test/Target/Wasm/or.mlir
+++ b/mlir/test/Target/Wasm/or.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @or_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @or_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.or %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @or_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @or_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.or %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/popcnt.mlir b/mlir/test/Target/Wasm/popcnt.mlir
index 235333a..bfaa8eb 100644
--- a/mlir/test/Target/Wasm/popcnt.mlir
+++ b/mlir/test/Target/Wasm/popcnt.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @popcnt_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @popcnt_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @popcnt_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @popcnt_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/promote.mlir b/mlir/test/Target/Wasm/promote.mlir
new file mode 100644
index 0000000..44c31b6
--- /dev/null
+++ b/mlir/test/Target/Wasm/promote.mlir
@@ -0,0 +1,14 @@
+// RUN: yaml2obj %S/inputs/promote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func $main (result f64)
+ f32.const 10.5
+ f64.promote_f32
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.050000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.promote %[[VAL_0]] : f32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/reinterpret.mlir b/mlir/test/Target/Wasm/reinterpret.mlir
new file mode 100644
index 0000000..574d13f
--- /dev/null
+++ b/mlir/test/Target/Wasm/reinterpret.mlir
@@ -0,0 +1,46 @@
+// RUN: yaml2obj %S/inputs/reinterpret.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/*
+Test generated from:
+(module
+ (func (export "i32.reinterpret_f32") (result i32)
+ f32.const -1
+ i32.reinterpret_f32
+ )
+
+ (func (export "i64.reinterpret_f64") (result i64)
+ f64.const -1
+ i64.reinterpret_f64
+ )
+
+ (func (export "f32.reinterpret_i32") (result f32)
+ i32.const -1
+ f32.reinterpret_i32
+ )
+
+ (func (export "f64.reinterpret_i64") (result f64)
+ i64.const -1
+ f64.reinterpret_i64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i32.reinterpret_f32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f32 as i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @i64.reinterpret_f64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f64 as i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func exported @f32.reinterpret_i32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i32 as f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @f64.reinterpret_i64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i64 as f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/rem.mlir b/mlir/test/Target/Wasm/rem.mlir
index b19b8d9..16c9c78 100644
--- a/mlir/test/Target/Wasm/rem.mlir
+++ b/mlir/test/Target/Wasm/rem.mlir
@@ -24,28 +24,28 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rem_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rem_u_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rem_ui %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rem_u_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rem_ui %0 %1 : i64
// CHECK: wasmssa.return %2 : i64
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rem_s_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rem_si %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rem_s_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rem_si %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/rotl.mlir b/mlir/test/Target/Wasm/rotl.mlir
index ec573554..4c2e5af 100644
--- a/mlir/test/Target/Wasm/rotl.mlir
+++ b/mlir/test/Target/Wasm/rotl.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rotl_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rotl_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @rotl_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rotl_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/rotr.mlir b/mlir/test/Target/Wasm/rotr.mlir
index 5618b43..ec403d0 100644
--- a/mlir/test/Target/Wasm/rotr.mlir
+++ b/mlir/test/Target/Wasm/rotr.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rotr_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rotr_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @rotr_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rotr_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/rounding.mlir b/mlir/test/Target/Wasm/rounding.mlir
new file mode 100644
index 0000000..947637e
--- /dev/null
+++ b/mlir/test/Target/Wasm/rounding.mlir
@@ -0,0 +1,50 @@
+// RUN: yaml2obj %S/inputs/rounding.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $ceil_f64 (result f64)
+ f64.const -12.1
+ f64.ceil
+ )
+ (func $ceil_f32 (result f32)
+ f32.const 1.618
+ f32.ceil
+ )
+ (func $floor_f64 (result f64)
+ f64.const -12.1
+ f64.floor
+ )
+ (func $floor_f32 (result f32)
+ f32.const 1.618
+ f32.floor
+ )
+  (func $trunc_f64 (result f64)
+    f64.const -12.1
+    f64.trunc
+  )
+  (func $trunc_f32 (result f32)
+    f32.const 1.618
+    f32.trunc
+  )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_1() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_3() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_5() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
diff --git a/mlir/test/Target/Wasm/shl.mlir b/mlir/test/Target/Wasm/shl.mlir
index f2bdd57..1363112 100644
--- a/mlir/test/Target/Wasm/shl.mlir
+++ b/mlir/test/Target/Wasm/shl.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shl_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shl_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shl %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shl_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shl_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shl %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/shr_s.mlir b/mlir/test/Target/Wasm/shr_s.mlir
index 247d9be..da1a38f 100644
--- a/mlir/test/Target/Wasm/shr_s.mlir
+++ b/mlir/test/Target/Wasm/shr_s.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shr_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shr_s_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shr_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shr_s_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/shr_u.mlir b/mlir/test/Target/Wasm/shr_u.mlir
index 9a79eed..2991c2a 100644
--- a/mlir/test/Target/Wasm/shr_u.mlir
+++ b/mlir/test/Target/Wasm/shr_u.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shr_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shr_u_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shr_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shr_u_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/sqrt.mlir b/mlir/test/Target/Wasm/sqrt.mlir
index 77444ad..6b968d6 100644
--- a/mlir/test/Target/Wasm/sqrt.mlir
+++ b/mlir/test/Target/Wasm/sqrt.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @sqrt_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @sqrt_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @sqrt_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @sqrt_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/sub.mlir b/mlir/test/Target/Wasm/sub.mlir
index b9c6caf..5b242f4 100644
--- a/mlir/test/Target/Wasm/sub.mlir
+++ b/mlir/test/Target/Wasm/sub.mlir
@@ -27,25 +27,25 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_2() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK-LABEL: wasmssa.func @func_3() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/wrap.mlir b/mlir/test/Target/Wasm/wrap.mlir
new file mode 100644
index 0000000..1266758
--- /dev/null
+++ b/mlir/test/Target/Wasm/wrap.mlir
@@ -0,0 +1,15 @@
+// RUN: yaml2obj %S/inputs/wrap.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func (export "i64_wrap") (param $in i64) (result i32)
+ local.get $in
+ i32.wrap_i64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i64_wrap(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i64>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.wrap %[[VAL_0]] : i64 to i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/xor.mlir b/mlir/test/Target/Wasm/xor.mlir
index 94691de..56407db 100644
--- a/mlir/test/Target/Wasm/xor.mlir
+++ b/mlir/test/Target/Wasm/xor.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @xor_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @xor_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.xor %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @xor_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @xor_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.xor %0 %1 : i64
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 727c84c..8c5c8e8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -276,10 +276,8 @@ void TestLinalgTransforms::runOnOperation() {
Operation *consumer = opOperand->getOwner();
// If we have a pack/unpack consumer and a producer that has multiple
// uses, do not apply the folding patterns.
- if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
- isa<TilingInterface>(producer) && !producer->hasOneUse())
- return false;
- return true;
+ return !(isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
+ isa<TilingInterface>(producer) && !producer->hasOneUse());
};
applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 97fc699..496f18b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -938,10 +938,10 @@ public:
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 4e869e5..4be30d8 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -28,7 +28,7 @@
// CHECK: operation "test.op3"
// CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
-// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
+// CHECK{LITERAL}: [[maybe_unused]] static void populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...);
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f3..66c4018 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
+import mlir.ir as ir
import mlir.dialects.gpu as gpu
import mlir.dialects.gpu.passes
from mlir.passmanager import *
@@ -64,3 +65,95 @@ def testObjectAttr():
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
print(o)
assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+ assert gpu.GPUFuncOp.__doc__ is not None
+ module = Module.create()
+ with InsertionPoint(module.body):
+ gpu_module_name = StringAttr.get("gpu_module")
+ gpumodule = gpu.GPUModuleOp(gpu_module_name)
+ block = gpumodule.bodyRegion.blocks.append()
+
+ def builder(func: gpu.GPUFuncOp) -> None:
+ gpu.GlobalIdOp(gpu.Dimension.x)
+ gpu.ReturnOp([])
+
+ with InsertionPoint(block):
+ name = StringAttr.get("kernel0")
+ func_type = ir.FunctionType.get(inputs=[], results=[])
+ type_attr = TypeAttr.get(func_type)
+ func = gpu.GPUFuncOp(type_attr, name)
+ func.attributes["sym_name"] = name
+ func.attributes["gpu.kernel"] = UnitAttr.get()
+
+ try:
+ func.entry_block
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert (
+ str(e)
+ == "Entry block does not exist for kernel0. Do you need to call the add_entry_block() method on this GPUFuncOp?"
+ )
+
+ block = func.add_entry_block()
+ with InsertionPoint(block):
+ builder(func)
+
+ try:
+ func.add_entry_block()
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert str(e) == "Entry block already exists for kernel0"
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="kernel1",
+ kernel=True,
+ body_builder=builder,
+ known_block_size=[1, 2, 3],
+ known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]),
+ )
+
+ assert func.name.value == "kernel1"
+ assert func.function_type.value == func_type
+    assert func.arg_attrs is None
+    assert func.res_attrs is None
+ assert func.arguments == []
+ assert func.entry_block == func.body.blocks[0]
+ assert func.is_kernel
+ assert func.known_block_size == DenseI32ArrayAttr.get(
+ [1, 2, 3]
+ ), func.known_block_size
+ assert func.known_grid_size == DenseI32ArrayAttr.get(
+ [4, 5, 6]
+ ), func.known_grid_size
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="non_kernel_func",
+ body_builder=builder,
+ )
+ assert not func.is_kernel
+ assert func.known_block_size is None
+ assert func.known_grid_size is None
+
+ print(module)
+
+ # CHECK: gpu.module @gpu_module
+ # CHECK: gpu.func @kernel0() kernel {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @kernel1() kernel attributes
+ # CHECK-SAME: known_block_size = array<i32: 1, 2, 3>
+ # CHECK-SAME: known_grid_size = array<i32: 4, 5, 6>
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @non_kernel_func() {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
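A minimal sketch distilled from the test above, showing just the convenience-constructor path it exercises. It assumes the same imports and an active Context/Location as in the test; "kernel2" is an illustrative name, not part of the change:

    # Build a gpu.module containing one kernel via the keyword constructor.
    module = Module.create()
    with InsertionPoint(module.body):
        gpumodule = gpu.GPUModuleOp(StringAttr.get("gpu_module"))
        block = gpumodule.bodyRegion.blocks.append()
        with InsertionPoint(block):
            kernel = gpu.GPUFuncOp(
                ir.FunctionType.get(inputs=[], results=[]),
                sym_name="kernel2",                  # illustrative name
                kernel=True,
                body_builder=lambda f: gpu.ReturnOp([]),
                known_block_size=[1, 2, 3],          # plain-list form also accepted
            )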
diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py
new file mode 100644
index 0000000..8f2142a
--- /dev/null
+++ b/mlir/test/python/dialects/openacc.py
@@ -0,0 +1,170 @@
+# RUN: %PYTHON %s | FileCheck %s
+from mlir.ir import (
+ Context,
+ FunctionType,
+ Location,
+ Module,
+ InsertionPoint,
+ IntegerType,
+ IndexType,
+ MemRefType,
+ F32Type,
+ Block,
+ ArrayAttr,
+ Attribute,
+ UnitAttr,
+ StringAttr,
+ DenseI32ArrayAttr,
+ ShapedType,
+)
+from mlir.dialects import openacc, func, arith, memref
+from mlir.extras import types
+
+
+def run(f):
+ print("\n// TEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
+
+@run
+def testParallelMemcpy():
+ module = Module.create()
+
+ dynamic = ShapedType.get_dynamic_size()
+ memref_f32_1d_any = MemRefType.get([dynamic], types.f32())
+
+ with InsertionPoint(module.body):
+ function_type = FunctionType.get(
+ [memref_f32_1d_any, memref_f32_1d_any, types.i64()], []
+ )
+ f = func.FuncOp(
+ type=function_type,
+ name="memcpy_idiom",
+ )
+ f.attributes["sym_visibility"] = StringAttr.get("public")
+
+ with InsertionPoint(f.add_entry_block()):
+ c1024 = arith.ConstantOp(types.i32(), 1024)
+ c128 = arith.ConstantOp(types.i32(), 128)
+
+ arg0, arg1, arg2 = f.arguments
+
+ copied = openacc.copyin(
+ acc_var=arg0.type,
+ var=arg0,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ created = openacc.create_(
+ acc_var=arg1.type,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+
+ parallel_op = openacc.ParallelOp(
+ asyncOperands=[],
+ waitOperands=[],
+ numGangs=[c1024],
+ numWorkers=[],
+ vectorLength=[c128],
+ reductionOperands=[],
+ privateOperands=[],
+ firstprivateOperands=[],
+ dataClauseOperands=[],
+ )
+
+ # Set required device_type and segment attributes to satisfy verifier
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ parallel_op.numGangsDeviceType = acc_device_none
+ parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1])
+ parallel_op.vectorLengthDeviceType = acc_device_none
+
+ parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[])
+
+ with InsertionPoint(parallel_block):
+ c0 = arith.ConstantOp(types.i64(), 0)
+ c1 = arith.ConstantOp(types.i64(), 1)
+
+ loop_op = openacc.LoopOp(
+ results_=[],
+ lowerbound=[c0],
+ upperbound=[f.arguments[2]],
+ step=[c1],
+ gangOperands=[],
+ workerNumOperands=[],
+ vectorOperands=[],
+ tileOperands=[],
+ cacheOperands=[],
+ privateOperands=[],
+ reductionOperands=[],
+ firstprivateOperands=[],
+ )
+
+ # Set loop attributes: gang and independent on device_type<none>
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ loop_op.gang = acc_device_none
+ loop_op.independent = acc_device_none
+
+ loop_block = Block.create_at_start(
+ parent=loop_op.region, arg_types=[types.i64()]
+ )
+
+ with InsertionPoint(loop_block):
+ idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
+ val = memref.load(memref=copied, indices=[idx])
+ memref.store(value=val, memref=created, indices=[idx])
+ openacc.YieldOp([])
+
+ openacc.YieldOp([])
+
+ deleted = openacc.delete(
+ acc_var=copied,
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ copied = openacc.copyout(
+ acc_var=created,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ func.ReturnOp([])
+
+ print(module)
+
+ # CHECK: TEST: testParallelMemcpy
+ # CHECK-LABEL: func.func public @memcpy_idiom(
+ # CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) {
+ # CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32
+ # CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32
+ # CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) {
+ # CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64
+ # CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64
+ # CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) {
+ # CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index
+ # CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: acc.yield
+ # CHECK: } attributes {independent = [#acc.device_type<none>]}
+ # CHECK: acc.yield
+ # CHECK: }
+ # CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>)
+ # CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>)
+ # CHECK: return
+ # CHECK: }
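The test above builds the decomposed OpenACC data-clause idiom: each structured entry op is paired with an exit op (acc.copyin with acc.delete, acc.create with acc.copyout). A minimal sketch of that pairing, assuming the same imports and an insertion point inside a function whose argument `src` is a memref-typed value:

    # Entry: map the host memref into device memory.
    src_acc = openacc.copyin(
        acc_var=src.type, var=src, var_type=types.f32(),
        bounds=[], async_operands=[], implicit=False, structured=True,
    )
    # ... compute constructs that read src_acc go here ...
    # Exit: release the mapping established by the copyin above.
    openacc.delete(
        acc_var=src_acc, bounds=[], async_operands=[],
        implicit=False, structured=True,
    )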
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index cb4cfc8c..1d4ede1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -569,12 +569,30 @@ def testOperationAttributes():
# CHECK: Attribute value b'text'
print(f"Attribute value {sattr.value_bytes}")
+ # Python dict-style iteration
# We don't know in which order the attributes are stored.
- # CHECK-DAG: NamedAttribute(dependent="text")
- # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
- # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
- for attr in op.attributes:
- print(str(attr))
+ # CHECK-DAG: dependent
+ # CHECK-DAG: other.attribute
+ # CHECK-DAG: some.attribute
+ for name in op.attributes:
+ print(name)
+
+ # Basic dict-like introspection
+ # CHECK: True
+ print("some.attribute" in op.attributes)
+ # CHECK: False
+ print("missing" in op.attributes)
+ # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute']
+ print("Keys:", sorted(op.attributes.keys()))
+ # CHECK: Values count 3
+ print("Values count", len(op.attributes.values()))
+ # CHECK: Items count 3
+ print("Items count", len(op.attributes.items()))
+
+    # dict() conversion test
+ d = {k: v.value for k, v in dict(op.attributes).items()}
+ # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
+ print("Dict mapping", d)
# Check that exceptions are raised as expected.
try:
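Taken together, the new assertions give Operation.attributes the standard mapping surface. A minimal usage sketch, assuming `op` is any existing mlir.ir.Operation:

    for name in op.attributes:                # iteration yields attribute names
        attr = op.attributes[name]            # subscripting returns the Attribute
    if "some.attribute" in op.attributes:     # membership tests use the name
        print(op.attributes["some.attribute"])
    names = sorted(op.attributes.keys())      # keys()/values()/items(), dict-style
    snapshot = dict(op.attributes)            # materialize a plain Python dict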
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 96af14d..11a2db4 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -416,7 +416,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv(
- "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
+ "[[maybe_unused]] static {0} convert{1}ToLLVM({2}::{1} value) {{\n",
llvmClass, cppClassName, cppNamespace);
os << " switch (value) {\n";
@@ -444,7 +444,7 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
- os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
+ os << formatv("[[maybe_unused]] static int64_t "
"convert{0}ToLLVM({1}::{0} value) {{\n",
cppClassName, cppNamespace);
os << " switch (value) {\n";
@@ -474,7 +474,7 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumInfo.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
- os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
+ os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM({2} "
"value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
@@ -509,10 +509,9 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumInfo.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
- os << formatv(
- "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
- "value) {{\n",
- cppNamespace, cppClassName);
+ os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM(int64_t "
+ "value) {{\n",
+ cppNamespace, cppClassName);
os << " switch (value) {\n";
for (const auto &enumerant : enumInfo.getAllCases()) {
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index daae3c7..3718648 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4896,7 +4896,7 @@ static void emitOpClassDefs(const RecordKeeper &records,
constraintPrefix);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
staticVerifierEmitter.collectOpConstraints(defs);
- staticVerifierEmitter.emitOpConstraints(defs);
+ staticVerifierEmitter.emitOpConstraints();
// Emit the classes.
emitOpClasses(records, defs, os, staticVerifierEmitter,
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 40bc1a9..c3034bb8 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -2120,7 +2120,7 @@ static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
}
// Emit function to add the generated matchers to the pattern list.
- os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
+ os << "[[maybe_unused]] void populateWithGenerated("
"::mlir::RewritePatternSet &patterns) {\n";
for (const auto &name : rewriterNames) {
os << " patterns.add<" << name << ">(patterns.getContext());\n";
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index fe26fc1..2a58305 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -113,8 +113,7 @@ static Match tensorMatch(TensorId tid) { return Match(tid); }
static Match synZeroMatch() { return Match(); }
#define IMPL_BINOP_PATTERN(OP, KIND) \
- LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
- const Match &e1) { \
+ [[maybe_unused]] static Match OP##Match(const Match &e0, const Match &e1) { \
return Match(KIND, e0, e1); \
}
FOREVERY_BINOP(IMPL_BINOP_PATTERN)