aboutsummaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (renamed from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir)27
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir46
-rw-r--r--mlir/test/Conversion/GPUToLLVMSPV/printf.mlir16
-rw-r--r--mlir/test/Dialect/AMDGPU/invalid.mlir46
-rw-r--r--mlir/test/Dialect/AMDGPU/ops.mlir15
-rw-r--r--mlir/test/Dialect/Affine/unroll.mlir6
-rw-r--r--mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir14
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir142
-rw-r--r--mlir/test/Dialect/Tosa/availability.mlir90
-rw-r--r--mlir/test/Dialect/Tosa/canonicalize.mlir8
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir9
-rw-r--r--mlir/test/Dialect/Tosa/invalid_extension.mlir46
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir40
-rw-r--r--mlir/test/Dialect/Tosa/ops.mlir70
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir23
-rw-r--r--mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir99
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir66
-rw-r--r--mlir/test/Dialect/Tosa/verifier.mlir136
-rw-r--r--mlir/test/Dialect/Vector/vector-mem-transforms.mlir184
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-blocking.mlir70
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir30
-rw-r--r--mlir/test/Transforms/scf-loop-unroll.mlir2
-rw-r--r--mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp20
23 files changed, 1012 insertions, 193 deletions
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3..d1301d0 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
%arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
func.return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78..b8973233 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
%arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -9,60 +9,60 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
%arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
%arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
- amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+ amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
- amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+ amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
- amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+ amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
func.return
}
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
new file mode 100644
index 0000000..74017e8
--- /dev/null
+++ b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-gpu-to-llvm-spv | FileCheck %s
+
+gpu.module @test_module {
+ // CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello: %d\0A\00") {addr_space = 2 : i32}
+ // CHECK: llvm.func spir_funccc @_Z6printfPU3AS2Kcz(!llvm.ptr<2>, ...) -> i32
+ // CHECK-LABEL: llvm.func spir_funccc @test_printf
+ // CHECK: (%[[ARG0:.*]]: i32)
+ gpu.func @test_printf(%arg0: i32) {
+ // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<2>
+ // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<2>) -> !llvm.ptr<2>, !llvm.array<11 x i8>
+ // CHECK-NEXT: %{{.*}} = llvm.call spir_funccc @_Z6printfPU3AS2Kcz(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32) -> i32
+ gpu.printf "Hello: %d\n", %arg0 : i32
+ gpu.return
+ }
+}
+
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index a8256b1..6a2518a 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -120,9 +120,49 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
// -----
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
- %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+func.func @wmma_f16_i32(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op expected int sources with int destination}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_i16_f32(%arg0 : vector<16xi8>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op expected float sources with float destination}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' expected 3 dimensions in MNK dimension list}}
+ %0 = amdgpu.wmma 16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
+ %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
+ %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
func.return %0 : vector<8xi32>
}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index f9c6899..a185eb6 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -565,13 +565,20 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
func.return %0 : vector<32xf32>
}
-// CHECK-LABEL: func @wmma
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
- // CHECK: amdgpu.wmma
- %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+// CHECK-LABEL: func @wmma_f16_16x16x16_f16
+func.func @wmma_f16_16x16x16_f16(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma 16x16x16
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
func.return %0 : vector<8xf16>
}
+// CHECK-LABEL: func @wmma_i32_16x16x32_i4
+func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // CHECK: amdgpu.wmma 16x16x32
+ %0 = amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg1 : vector<16xi4>, vector<16xi4>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
// CHECK-LABEL: func @swizzle_bitmode
func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
// CHECK: amdgpu.swizzle_bitmode
diff --git a/mlir/test/Dialect/Affine/unroll.mlir b/mlir/test/Dialect/Affine/unroll.mlir
index 574e9f4..efdceed 100644
--- a/mlir/test/Dialect/Affine/unroll.mlir
+++ b/mlir/test/Dialect/Affine/unroll.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true}))" | FileCheck %s --check-prefix UNROLL-FULL
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-full=true unroll-full-threshold=2}))" | FileCheck %s --check-prefix SHORT
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=-1}))" | FileCheck %s --check-prefix UNROLL-FULL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=-1 unroll-full-threshold=2}))" | FileCheck %s --check-prefix SHORT
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=4}))" | FileCheck %s --check-prefix UNROLL-BY-4
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=1}))" | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(affine-loop-unroll{unroll-factor=5 cleanup-unroll=true}))" | FileCheck %s --check-prefix UNROLL-CLEANUP-LOOP
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(gpu.module(gpu.func(affine-loop-unroll{unroll-full=true})))" | FileCheck %s --check-prefix GPU-UNROLL-FULL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(gpu.module(gpu.func(affine-loop-unroll{unroll-factor=-1})))" | FileCheck %s --check-prefix GPU-UNROLL-FULL
// UNROLL-FULL-DAG: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
// UNROLL-FULL-DAG: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 2)>
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index 8bd7ae8..ac1f22b 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -63,6 +63,20 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
// -----
+// CHECK-LABEL: func @memref_expand(
+// CHECK-SAME: %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index
+// CHECK: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: return %[[sz]], %[[c4]]
+func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
+ %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
+ return %1, %2 : index, index
+}
+
+// -----
+
// CHECK-LABEL: func @memref_get_global(
// CHECK: %[[c4:.*]] = arith.constant 4 : index
// CHECK: return %[[c4]]
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec636..084c3fc 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1604,6 +1604,148 @@ func.func @func_execute_region_inline_multi_yield() {
// -----
+// Test case with single scf.yield op inside execute_region and its operand is defined outside the execute_region op.
+// Make scf.execute_region not to return anything.
+
+// CHECK: scf.execute_region no_inline {
+// CHECK: func.call @foo() : () -> ()
+// CHECK: scf.yield
+// CHECK: }
+
+module {
+func.func private @foo()->()
+func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> memref<1x60xui8> no_inline {
+ func.call @foo():()->()
+ scf.yield %alloc: memref<1x60xui8>
+ }
+ return %1 : memref<1x60xui8>
+}
+}
+
+// -----
+
+// Test case with scf.yield op inside execute_region with multiple operands.
+// One of operands is defined outside the execute_region op.
+// Remove just this operand from the op results.
+
+// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+// CHECK: func.call @foo() : () -> ()
+// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8>
+// CHECK: }
+module {
+func.func private @foo()->()
+func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+}
+}
+
+// -----
+
+// Test case with multiple scf.yield ops inside execute_region with same operands and those operands are defined outside the execute_region op..
+// Make scf.execute_region not to return anything.
+// scf.yield must remain, cause scf.execute_region can't be empty.
+
+// CHECK: scf.execute_region no_inline {
+// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1
+// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: scf.yield
+// CHECK: ^bb2:
+// CHECK: scf.yield
+// CHECK: }
+
+module {
+ func.func private @foo()->()
+ func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+}
+
+// -----
+
+// Test case with multiple scf.yield ops with at least one different operand, then no change.
+
+// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+// CHECK: ^bb1:
+// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: ^bb2:
+// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: }
+
+module {
+ func.func private @foo()->()
+ func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+}
+
+// -----
+
+// Test case with multiple scf.yield ops each has different operand.
+// In this case scf.execute_region isn't changed.
+
+// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline {
+// CHECK: ^bb1:
+// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
+// CHECK: ^bb2:
+// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
+// CHECK: }
+
+module {
+func.func private @foo()->()
+func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc : memref<1x60xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc_1 : memref<1x60xui8>
+ }
+ return %1 : memref<1x60xui8>
+}
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
func.func @canonicalize_parallel_insert_slice_indices(
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 600c4c7..a05f423 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -9,7 +9,7 @@
// CHECK-LABEL: argmax
func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32>
return %0 : tensor<14xi32>
}
@@ -144,7 +144,7 @@ func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK-LABEL: add
func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -153,7 +153,7 @@ func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> t
// CHECK-LABEL: arithmetic_right_shift
func.func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -162,7 +162,7 @@ func.func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor
// CHECK-LABEL: bitwise_and
func.func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.bitwise_and %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -171,7 +171,7 @@ func.func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi3
// CHECK-LABEL: bitwise_or
func.func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.bitwise_or %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -180,7 +180,7 @@ func.func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>
// CHECK-LABEL: bitwise_xor
func.func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.bitwise_xor %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -189,7 +189,7 @@ func.func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi3
// CHECK-LABEL: intdiv
func.func @test_intdiv(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.intdiv %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -207,7 +207,7 @@ func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>
// CHECK-LABEL: logical_left_shift
func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -216,7 +216,7 @@ func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x
// CHECK-LABEL: logical_right_shift
func.func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.logical_right_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -243,7 +243,7 @@ func.func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>)
// CHECK-LABEL: maximum
func.func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.maximum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -252,7 +252,7 @@ func.func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> t
// CHECK-LABEL: minimum
func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.minimum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -262,7 +262,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -280,7 +280,7 @@ func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> t
// CHECK-LABEL: sub
func.func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -298,7 +298,7 @@ func.func @test_table(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform<i1
// CHECK-LABEL: abs
func.func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.abs %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -307,7 +307,7 @@ func.func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK-LABEL: bitwise_not
func.func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.bitwise_not %arg0 : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
return %0 : tensor<13x21x1xi32>
}
@@ -325,7 +325,7 @@ func.func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK-LABEL: clz
func.func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int] ]
- // CHECK: extensions: [ ]
+ // CHECK: extensions: [ [int64] ]
%0 = tosa.clz %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -379,7 +379,7 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
// CHECK-LABEL: negate
func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.negate %arg0, %input_zp, %output_zp : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -417,7 +417,7 @@ func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK-LABEL: select
func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -426,7 +426,7 @@ func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg
// CHECK-LABEL: equal
func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.equal %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}
@@ -435,7 +435,7 @@ func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) ->
// CHECK-LABEL: greater
func.func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.greater %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}
@@ -444,7 +444,7 @@ func.func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>)
// CHECK-LABEL: greater_equal
func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.greater_equal %arg0, %arg1 : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}
@@ -471,7 +471,7 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> {
// CHECK-LABEL: reduce_max
func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
return %0 : tensor<1x21x3xf32>
}
@@ -480,7 +480,7 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
// CHECK-LABEL: reduce_min
func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
return %0 : tensor<1x21x3xf32>
}
@@ -498,7 +498,7 @@ func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
// CHECK-LABEL: reduce_sum
func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: extensions: [ [bf16, int64] ]
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
return %0 : tensor<1x21x3xf32>
}
@@ -507,7 +507,7 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
// CHECK-LABEL: concat
func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int16, int64] ]
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
return %0 : tensor<26x21x3xf32>
}
@@ -518,7 +518,7 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.pad %arg0, %padding, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -528,7 +528,7 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
%1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
return %0 : tensor<1x819xf32>
}
@@ -537,7 +537,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
// CHECK-LABEL: reverse
func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -548,7 +548,7 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
%0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
return %2 : tensor<4x11x1xf32>
}
@@ -558,7 +558,7 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
%cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
return %0 : tensor<39x21x6xf32>
}
@@ -567,7 +567,7 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
// CHECK-LABEL: transpose
func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>}: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %1 : tensor<3x13x21xf32>
}
@@ -576,7 +576,7 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
// CHECK-LABEL: gather
func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32>
return %0 : tensor<13x26x3xf32>
}
@@ -585,7 +585,7 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// CHECK-LABEL: scatter
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
return %0 : tensor<13x28x3xf32>
}
@@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
// CHECK-LABEL: cast
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -626,7 +626,7 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
// CHECK-LABEL: test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@@ -635,7 +635,7 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
// CHECK-LABEL: identity
func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
+ // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.identity %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> {
return %cst : !tosa.shape<4>
}
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // CHECK: profiles: [ [pro_fp] ]
+ // CHECK: extensions: [ [bf16, mxfp] ]
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 7574afa..5a40f3f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -18,11 +18,11 @@ func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
// -----
-// CHECK-LABEL: @pad_wh_avg_pool2d_fold
-func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
- // CHECK-NOT: tosa.pad
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold
+func.func @pad_wh_avg_pool2d_nofold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+ // CHECK: tosa.pad
// CHECK: tosa.avg_pool2d
- // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
%pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 2a3985c..119991ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -729,15 +729,6 @@ func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1
// -----
-// CHECK-LABEL: test_unsupported_int64_data_type
-func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
- // expected-error@+1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
- %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
- return %0 : tensor<1x13x13xi64>
-}
-
-// -----
-
// CHECK-LABEL: test_mismatch_in_out_data_type_clamp
func.func @test_mismatch_in_out_data_type_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf16> {
// expected-error@+1 {{'tosa.clamp' op requires the same element type for all operands and results}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index e5c9402..68a9578 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -538,3 +538,49 @@ func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %ou
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
+
+// -----
+
+func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
+ // expected-error@+1 {{'tosa.argmax' op illegal: requires [int64] but not enabled in target}}
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
+ return %0 : tensor<1x13x13xi64>
+}
+
+// -----
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+ // expected-error@+1 {{'tosa.const' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+ return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+ // expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+ return %0 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8cc357e..a708764 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1622,3 +1622,43 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<67108864x32xf6E2M3FN>, %arg1: tensor<67108864x1xf8E8M0FNU>) -> tensor<67108864x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>) -> tensor<67108864x32xf32>
+ return %0 : tensor<67108864x32xf32>
+}
+
+// -----
+
+func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32>
+ return %0 : tensor<1x2x3x4x5x6x7x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<67108864x32xf32>) -> (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<67108864x32xf32>) -> (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 868b7b7..865f712 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1226,3 +1226,73 @@ func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
return %0 : tensor<13x29x3xf8E4M3FN>
}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked
+func.func @test_matmul_t_block_scaled_unranked(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2
+func.func @test_matmul_t_block_scaled_fp6e3m2(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1
+func.func @test_matmul_t_block_scaled_fp4e2m1(%arg0: tensor<4x8x32xf4E2M1FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf4E2M1FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf4E2M1FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf4E2M1FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast
+func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<?x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<?x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_from_block_scaled_unranked
+func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 7ff8065..7de7b85 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
@@ -325,3 +325,24 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
+
+// -----
+func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 80f06f1..54556a0 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1574,3 +1574,102 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_static
+func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<1x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<4x8x16xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_a_data
+func.func @test_matmul_t_block_scaled_unranked_a_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<4x8x16xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_b_data_and_scale
+func.func @test_matmul_t_block_scaled_unranked_b_data_and_scale(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<4x8x?xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_unranked_all
+func.func @test_matmul_t_block_scaled_unranked_all(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<?x?x?xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_data
+func.func @test_matmul_t_block_scaled_broadcast_b_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<1x4x32xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<?x?x4xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<1x4x32xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_scale
+func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32> {
+ // CHECK: -> tensor<?x?x4xf32>
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale
+func.func @test_cast_from_block_scaled_unranked_input_scale(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+ // CHECK: -> tensor<4x32xf32>
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales
+func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+ // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 8164509..f3d8dab 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
@@ -18,3 +18,67 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_argmax_int64
+func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
+ return %0 : tensor<1x13x13xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_i64
+func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
+ %0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+ return %0 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_fp6e3m2
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+ %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+ return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_f4e2m1
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+ return %0 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
+func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16
+func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> {
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16>
+ return %0 : tensor<4x32xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 430b06a..6cf76cd 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1033,7 +1033,6 @@ module {
// -----
-// CHECK-LABEL: @scatter_invalid_indices_N
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3
// -----
-// CHECK-LABEL: @scatter_invalid_input_N
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_N
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
@@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_out_K
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
@@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_input_W
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
@@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_input_C
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
@@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2
// -----
-// CHECK-LABEL: @scatter_invalid_out_C
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
@@ -1096,9 +1089,136 @@ func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi
// -----
-// CHECK-LABEL: @scatter_invalid_K_W
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
return
}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_batch_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 5, ?, ? to be compatible with expected output shape 4, 8, ?}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32>
+ return %0 : tensor<5x?x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_height_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x9x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 4, 8, ? to be compatible with expected output shape 4, 9, ?}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x9x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32>
+ return %0 : tensor<4x8x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_width_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<?x?x1xf8E8M0FNU>, %arg2: tensor<?x1x?xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape ?, ?, 10 to be compatible with expected output shape ?, ?, 1}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<*xf8E4M3FN>, tensor<?x?x1xf8E8M0FNU>, tensor<?x1x?xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<?x?x10xf32>
+ return %0 : tensor<?x?x10xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_channel_not_multiple_of_block_size(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected channels of b_data to match size 55, got 32}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<2x16x32xf8E4M3FN>, %arg3: tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+ // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect B matrix batch size to be broadcast compatible with A, got D=2 vs N=4}}
+ %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
+ return %0 : tensor<5x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_scalar(%arg0: tensor<f4E2M1FN>, %arg1: tensor<f8E8M0FNU>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f4E2M1FN>'}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32>
+ return %0 : tensor<4x33xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> {
+ // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}}
+ %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32>
+ return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<f32>) -> (tensor<f4E2M1FN>, tensor<f8E8M0FNU>)
+ return %0#0, %0#1 : tensor<f4E2M1FN>, tensor<f8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
+ // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}}
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
+ return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index e6593320..2004a47 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -1,12 +1,16 @@
// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
-// CHECK-LABEL: func @maskedload0(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: MaskedLoadFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_maskedload_all_true_dynamic(
+// CHECK-SAME: %[[BASE:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_all_true_dynamic(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
@@ -14,13 +18,13 @@ func.func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vect
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedload1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_all_true_static(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_all_true_static(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
@@ -28,11 +32,11 @@ func.func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vec
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedload2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-func.func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_all_false_static(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_maskedload_all_false_static(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
@@ -40,13 +44,13 @@ func.func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vec
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedload3(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-DAG: %[[C:.*]] = arith.constant 8 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_maskedload_dynamic_non_zero_idx(
+// CHECK-SAME: %[[BASE:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[IDX:.*]] = arith.constant 8 : index
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
+func.func @fold_maskedload_dynamic_non_zero_idx(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c8 = arith.constant 8 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.maskedload %base[%c8], %mask, %pass_thru
@@ -54,39 +58,49 @@ func.func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vect
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedstore1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
-// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+//-----------------------------------------------------------------------------
+// [Pattern: MaskedStoreFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_maskedstore_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0 : index
+// CHECK-NEXT: vector.store %[[VALUE]], %[[BASE]][%[[IDX]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return
-func.func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_maskedstore_all_true(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @maskedstore2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_maskedstore_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
-func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_maskedstore_all_false(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @gather1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: GatherFolder]
+//-----------------------------------------------------------------------------
+
+/// There is no alternative (i.e. simpler) Op for this, hence no-fold.
+
+// CHECK-LABEL: func @no_fold_gather_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[BASE]][%[[C]]] [%[[INDICES]]], %[[M]], %[[PASS_THRU]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
-func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+func.func @no_fold_gather_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
@@ -94,12 +108,12 @@ func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @gather2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A2]] : vector<16xf32>
-func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_gather_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_gather_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
@@ -107,15 +121,21 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @scatter1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+//-----------------------------------------------------------------------------
+// [Pattern: ScatterFolder]
+//-----------------------------------------------------------------------------
+
+/// There is no alternative (i.e. simpler) Op for this, hence no-fold.
+
+// CHECK-LABEL: func @no_fold_scatter_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-NEXT: vector.scatter %[[BASE]][%[[C]]] [%[[INDICES]]], %[[M]], %[[VALUE]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
-func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+func.func @no_fold_scatter_all_true(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.scatter %base[%c0][%indices], %mask, %value
@@ -123,12 +143,12 @@ func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vec
return
}
-// CHECK-LABEL: func @scatter2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_scatter_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
-func.func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+func.func @fold_scatter_all_false(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
@@ -137,13 +157,17 @@ func.func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vec
return
}
-// CHECK-LABEL: func @expand1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+//-----------------------------------------------------------------------------
+// [Pattern: ExpandLoadFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_expandload_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[BASE]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
-func.func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+func.func @fold_expandload_all_true(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.expandload %base[%c0], %mask, %pass_thru
@@ -151,11 +175,11 @@ func.func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @expand2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-func.func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+// CHECK-LABEL: func @fold_expandload_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[PASS_THRU]] : vector<16xf32>
+func.func @fold_expandload_all_false(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.expandload %base[%c0], %mask, %pass_thru
@@ -163,24 +187,28 @@ func.func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @compress1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+//-----------------------------------------------------------------------------
+// [Pattern: CompressStoreFolder]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func @fold_compressstore_all_true(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.store %[[VALUE]], %[[BASE]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return
-func.func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_compressstore_all_true(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @compress2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-LABEL: func @fold_compressstore_all_false(
+// CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
-func.func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+func.func @fold_compressstore_all_false(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index fe4f44c..7e742af 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -682,3 +682,73 @@ gpu.module @test_kernel {
gpu.return
}
}
+
+// -----
+gpu.module @test_kernel {
+ // CHECK-LABEL: remove_unit_dim_inst_data
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
+ // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
+ // CHECK: [[cst_1:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+ // CHECK: [[cst_2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+ // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+ // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+ // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+ // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+ gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
+ [0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248]
+ ]]> : vector<1x1x32xindex>
+
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true> : vector<1x1x32xi1>
+ %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+ gpu.return %ld : vector<1x1x32xf32>
+ }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [1, 16]>
+gpu.module @test_kernel {
+ // CHECK-LABEL: load_store_nd_with_offsets
+ // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
+ // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+ // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+ // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+ // CHECK: [[tdesc_b:%.+]] = xegpu.create_nd_tdesc [[arg1]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+ // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+ // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+ // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+ // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+ // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+ // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
+ // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
+ // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
+ // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+ // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
+ // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
+ // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
+ // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+ // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+ // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+ // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+ // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+ gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
+ %c0 = arith.constant 0 : index
+
+ %a_tdesc = xegpu.create_nd_tdesc %A : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+ %b_tdesc = xegpu.create_nd_tdesc %B : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+ %c_tdesc = xegpu.create_nd_tdesc %C : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+
+ %a = xegpu.load_nd %a_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+ %b = xegpu.load_nd %b_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+
+ %result = arith.addf %a, %b {layout_result_0 = #l} : vector<1x32xf32>
+ xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
new file mode 100644
index 0000000..edf8775
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_sycl_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @test attributes {gpu.container_module} {
+ gpu.module @test_module {
+ gpu.func @test_printf(%arg0: i32, %arg1: f32) kernel {
+ gpu.printf "Hello: %d\n", %arg0 : i32
+ gpu.printf "Hello: %f\n", %arg1 : f32
+ gpu.return
+ }
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c11 = arith.constant 11 : i32
+ %c4 = arith.constant 4.0 : f32
+ // CHECK: Hello: 11
+ // CHECK: Hello: 4.000000
+ gpu.launch_func @test_module::@test_printf blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%c11 : i32, %c4 : f32)
+ return
+ }
+}
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index 0ef6ad1..db96c65 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
-// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=-1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
// CHECK-LABEL: scf_loop_unroll_single
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index ced0033..2470380 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,11 +42,10 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
- bool annotateLoopParam, bool unrollFullParam) {
+ bool annotateLoopParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
- unrollFull = unrollFactorParam;
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -54,6 +53,12 @@ struct TestLoopUnrollingPass
}
void runOnOperation() override {
+ if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+ emitError(UnknownLoc::get(&getContext()),
+ "Invalid option: 'unroll-factor' should be greater than 0 or "
+ "equal to -1");
+ return signalPassFailure();
+ }
SmallVector<scf::ForOp, 4> loops;
getOperation()->walk([&](scf::ForOp forOp) {
if (getNestingDepth(forOp) == loopDepth)
@@ -65,15 +70,15 @@ struct TestLoopUnrollingPass
}
};
for (auto loop : loops) {
- if (unrollFull)
+ if (unrollFactor.getValue() == -1)
(void)loopUnrollFull(loop);
else
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
}
}
- Option<uint64_t> unrollFactor{*this, "unroll-factor",
- llvm::cl::desc("Loop unroll factor."),
- llvm::cl::init(1)};
+ Option<int64_t> unrollFactor{*this, "unroll-factor",
+ llvm::cl::desc("Loop unroll factor."),
+ llvm::cl::init(1)};
Option<bool> annotateLoop{*this, "annotate",
llvm::cl::desc("Annotate unrolled iterations."),
llvm::cl::init(false)};
@@ -82,9 +87,6 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
- Option<bool> unrollFull{*this, "unroll-full",
- llvm::cl::desc("Full unroll loops."),
- llvm::cl::init(false)};
};
} // namespace