Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r--  mlir/test/Dialect/Async/canonicalize.mlir  10
-rw-r--r--  mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir  33
-rw-r--r--  mlir/test/Dialect/GPU/invalid.mlir  44
-rw-r--r--  mlir/test/Dialect/GPU/ops.mlir  5
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir  71
-rw-r--r--  mlir/test/Dialect/Linalg/data-layout-propagation.mlir  18
-rw-r--r--  mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir  55
-rw-r--r--  mlir/test/Dialect/Linalg/invalid.mlir  37
-rw-r--r--  mlir/test/Dialect/Linalg/transform-lower-pack.mlir  16
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir  149
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir  24
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir  51
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir  13
-rw-r--r--  mlir/test/Dialect/OpenMP/ops.mlir  10
-rw-r--r--  mlir/test/Dialect/SCF/canonicalize.mlir  34
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir  24
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/bit-ops.mlir  6
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/gl-ops.mlir  38
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/group-ops.mlir  2
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/image-ops.mlir  2
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir  44
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir  10
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/logical-ops.mlir  20
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir  16
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir  12
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/types.mlir  6
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir  24
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/availability.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/canonicalize.mlir  20
-rw-r--r--  mlir/test/Dialect/Tosa/controlflow.mlir  35
-rw-r--r--  mlir/test/Dialect/Tosa/dynamic_extension.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/error_if_check.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/invalid.mlir  20
-rw-r--r--  mlir/test/Dialect/Tosa/invalid_extension.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/level_check.mlir  20
-rw-r--r--  mlir/test/Dialect/Tosa/ops.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir  2
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir  27
-rw-r--r--  mlir/test/Dialect/Tosa/verifier.mlir  85
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize.mlir  256
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir  10
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir  126
-rw-r--r--  mlir/test/Dialect/Vector/int-range-interface.mlir  38
-rw-r--r--  mlir/test/Dialect/Vector/invalid.mlir  56
-rw-r--r--  mlir/test/Dialect/Vector/ops.mlir  32
-rw-r--r--  mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir  16
-rw-r--r--  mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir  24
-rw-r--r--  mlir/test/Dialect/Vector/vector-sink.mlir  139
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir  8
-rw-r--r--  mlir/test/Dialect/XeGPU/invalid.mlir  68
-rw-r--r--  mlir/test/Dialect/XeGPU/ops.mlir  29
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir  139
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir  242
54 files changed, 1397 insertions(+), 781 deletions(-)
diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir
new file mode 100644
index 0000000..1a74eaa
--- /dev/null
+++ b/mlir/test/Dialect/Async/canonicalize.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-NOT: async.execute
+
+func.func @empty_execute() {
+ %token = async.execute {
+ async.yield
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
new file mode 100644
index 0000000..e2ab876
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s
+
+"test.symbol_scope_isolated"() ({
+ // CHECK-LABEL: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+ // CHECK-NOT: copy
+ %f = arith.constant 1.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: memref.store %{{.*}}, %[[arg0]]
+ %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+ %1 = tensor.extract %0[%c1] : tensor<?xf32>
+ // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
+ return %0, %1 : tensor<?xf32>, f32
+ }
+
+ // CHECK-LABEL: func @call_func_with_non_tensor_return(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @call_func_with_non_tensor_return(
+ %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
+ // CHECK-NOT: alloc
+ // CHECK-NOT: copy
+ // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
+ %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+ // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+ return %1, %0 : f32, tensor<?xf32>
+ }
+ "test.finish" () : () -> ()
+}) : () -> ()
+
+
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 162ff06..35381da 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
// -----
func.func @rotate_mismatching_type(%arg0 : f32) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1)
return
}
// -----
func.func @rotate_unsupported_type(%arg0 : index) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
- %rotate, %valid = gpu.rotate %arg0, %offset, %width : index
+ %rotate, %valid = gpu.rotate %arg0, 4, 16 : index
return
}
@@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
- %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
+ %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32>
return
}
// -----
func.func @rotate_unsupported_width(%arg0 : f32) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 15 : i32
- // expected-error@+1 {{op width must be a power of two}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}}
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset(%arg0 : f32) {
- %offset = arith.constant 16 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
- %offset = arith.constant -1 : i32
- %width = arith.constant 16 : i32
- // expected-error@+1 {{op offset must be in the range [0, 16)}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
- return
-}
-
-// -----
-
-func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
- %width = arith.constant 16 : i32
- // expected-error@+1 {{op offset is not a constant value}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
- return
-}
-
-// -----
-
-func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
- %offset = arith.constant 0 : i32
- // expected-error@+1 {{op width is not a constant value}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}}
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1)
return
}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 2aef80f..ee1fdfa 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -140,9 +140,8 @@ module attributes {gpu.container_module} {
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
- // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
- %rotate_width = arith.constant 16 : i32
- %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
+ // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32
+ %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32
"gpu.barrier"() : () -> ()
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4..5c5f7e8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x3xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x3xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x3xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [2]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x4xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x4xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x4xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [1]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
%transpose = linalg.transpose
@@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
-
// -----
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
@@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
- into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f..cc26fa4 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
- %empty = tensor.empty() : tensor<8x4x16x8xf32>
- %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
- %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
- return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a00c798..5f42938 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// -----
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?xf32> to tensor<1x16xf32>
+ return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?xf32> to tensor<16xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+ func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+ %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+ %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x1x61x1xf32>
+ return %1 : tensor<?x1x61x1xf32>
+ }
+}
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32>
// CHECK: }
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
- func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
- %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
- %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x1x61x1xf32>
- return %1 : tensor<?x1x61x1xf32>
- }
-}
-
// -----
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7..40bf4d1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}
// -----
+
func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+ %cst = arith.constant 0.0 : f32
+ // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+ inner_tiles = [8] into %output
+ : tensor<9xf32> -> tensor<3x8xf32>
+ return %0 : tensor<3x8xf32>
+}
+
+// -----
+
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+ // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
// -----
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
- return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+ // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+ %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+ return %0 : tensor<8x7x16x32xf32>
+}
+
+// -----
+
+func.func @unpack_with_artificial_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+ // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+ : tensor<3x8xf32> -> tensor<9xf32>
+ return %0 : tensor<9xf32>
}
// -----
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+ // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8..9e7681d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
- -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+ -> tensor<265x12x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
- // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+ // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
- // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+ // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
- // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
- : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
- return %0 : tensor<265x16x16x1xf32>
+ : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+ return %0 : tensor<265x12x16x1xf32>
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b6..981f5dc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
// CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
- // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+ // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+ // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+ // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+ return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+ // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed..f741876 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
// CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
- // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
// CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
// CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
- // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
//
// CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
// CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
- // CHECK: } -> tensor<8x14x13xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+ // CHECK: } -> tensor<8x14x12xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
//
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
^bb0(%in: f32, %out: f32):
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee892..d7722ea 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 98e8f50..d41d861 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32>
func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
// CHECK: %[[C01:.*]] = arith.constant 0
// CHECK: %[[C02:.*]] = arith.constant 0
-// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST14:.*]] = arith.constant 1
-// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32>
// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
-// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4c50ed3..8c846cd 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
// CHECK-NEXT: (%[[XVAL:.*]]: i1):
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1
// CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1)
- // }
+ // CHECK-NEXT: }
omp.atomic.update %xBool : memref<i1> {
^bb0(%xval: i1):
%newval = llvm.icmp "eq" %xval, %exprBool : i1
@@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
omp.yield(%newval : i32)
}
+ // CHECK: omp.atomic.update %[[X]] : memref<i32> {
+ // CHECK-NEXT: (%[[XVAL:.*]]: i32):
+ // CHECK-NEXT: omp.yield(%{{.+}} : i32)
+ // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>}
+ omp.atomic.update %x : memref<i32> {
+ ^bb0(%xval:i32):
+ omp.yield(%const:i32)
+ } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>}
return
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 12d30e17..308cf150 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() {
// -----
-// CHECK-LABEL: func @execute_region_elim
-func.func @execute_region_elim() {
+// CHECK-LABEL: func @execute_region_inline
+func.func @execute_region_inline() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
@@ -1461,8 +1461,30 @@ func.func @execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim
-func.func @func_execute_region_elim() {
+// CHECK-LABEL: func @execute_region_no_inline
+func.func @execute_region_no_inline() {
+ affine.for %i = 0 to 100 {
+ "test.foo"() : () -> ()
+ %v = scf.execute_region -> i64 no_inline {
+ %x = "test.val"() : () -> i64
+ scf.yield %x : i64
+ }
+ "test.bar"(%v) : (i64) -> ()
+ }
+ return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: scf.execute_region
+// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
+// CHECK-NEXT: scf.yield %[[VAL]] : i64
+// CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_inline
+func.func @func_execute_region_inline() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
@@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim_multi_yield
-func.func @func_execute_region_elim_multi_yield() {
+// CHECK-LABEL: func @func_execute_region_inline_multi_yield
+func.func @func_execute_region_inline_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 3adafc1..c703274 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -13,7 +13,7 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
// -----
func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FAdd %arg, %arg : bf16
return %0 : bf16
}
@@ -33,7 +33,7 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
// -----
func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FDiv %arg, %arg : bf16
return %0 : bf16
}
@@ -53,7 +53,7 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
// -----
func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FMod %arg, %arg : bf16
return %0 : bf16
}
@@ -79,7 +79,7 @@ func.func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> {
// -----
func.func @fmul_i32(%arg: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FMul %arg, %arg : i32
return %0 : i32
}
@@ -87,7 +87,7 @@ func.func @fmul_i32(%arg: i32) -> i32 {
// -----
func.func @fmul_bf16(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FMul %arg, %arg : bf16
return %0 : bf16
}
@@ -95,7 +95,7 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
// -----
func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FMul %arg, %arg : vector<4xbf16>
return %0 : vector<4xbf16>
}
@@ -103,7 +103,7 @@ func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
// -----
func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FMul %arg, %arg : tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -123,7 +123,7 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
// -----
func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FNegate %arg : bf16
return %0 : bf16
}
@@ -143,7 +143,7 @@ func.func @frem_scalar(%arg: f32) -> f32 {
// -----
func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FRem %arg, %arg : bf16
return %0 : bf16
}
@@ -163,7 +163,7 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
// -----
func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.FSub %arg, %arg : bf16
return %0 : bf16
}
@@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
@@ -558,7 +558,7 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
// -----
func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
- // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+ // expected-error @+1 {{operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
%0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
return %0 : vector<4xbf16>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index f3f0ebf..4bdac19 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
// -----
func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
%0 = spirv.BitwiseOr %arg0, %arg1 : f16
return %0 : f16
}
@@ -165,7 +165,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
// -----
func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
%0 = spirv.BitwiseXor %arg0, %arg1 : f16
return %0 : f16
}
@@ -274,7 +274,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
// -----
func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
%0 = spirv.BitwiseAnd %arg0, %arg1 : f16
return %0 : f16
}
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 5c5d94c..fd8a2ff 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @exp(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
%2 = spirv.GL.Exp %arg0 : i32
return
}
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
// -----
func.func @exp(%arg0 : vector<5xf32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}}
%2 = spirv.GL.Exp %arg0 : vector<5xf32>
return
}
@@ -51,7 +51,7 @@ func.func @exp(%arg0 : i32) -> () {
// -----
func.func @exp_bf16(%arg0 : bf16) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}}
%2 = spirv.GL.Exp %arg0 : bf16
return
}
@@ -101,7 +101,7 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
// -----
func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
- // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
%2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
return
@@ -499,7 +499,7 @@ func.func @frexp_struct_mismatch_type(%arg0 : f32) -> () {
// -----
func.func @frexp_struct_wrong_type(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%2 = spirv.GL.FrexpStruct %arg0 : i32 -> !spirv.struct<(i32, i32)>
return
}
@@ -614,7 +614,7 @@ func.func @findimsb_vector_i64(%arg0 : vector<3xi64>) -> () {
// -----
func.func @findimsb_error_scalar_float(%arg0 : f32) -> () {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}}
%2 = spirv.GL.FindILsb %arg0 : f32
return
}
@@ -640,7 +640,7 @@ func.func @findsmsb_vector(%arg0 : vector<3xi32>) -> () {
// -----
func.func @findsmsb_error_scalar_i64(%arg0 : i64) -> () {
- // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+ // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}}
%2 = spirv.GL.FindSMsb %arg0 : i64
return
}
@@ -666,7 +666,7 @@ func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
// -----
func.func @findumsb(%arg0 : i64) -> () {
- // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+ // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}}
%2 = spirv.GL.FindUMsb %arg0 : i64
return
}
@@ -692,7 +692,7 @@ func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
// -----
func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
- // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
return
}
@@ -708,7 +708,7 @@ func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) {
// -----
func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
- // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
return
}
@@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
// -----
func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
- // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+ // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
%0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
return
}
@@ -762,7 +762,7 @@ func.func @normalize_vector(%arg0 : vector<3xf32>) {
// -----
func.func @normalize_invalid_type(%arg0 : i32) {
- // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.GL.Normalize %arg0 : i32
return
}
@@ -788,7 +788,7 @@ func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
// -----
func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
- // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.GL.Reflect %arg0, %arg1 : i32
return
}
@@ -814,7 +814,7 @@ func.func @fractvec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @fract_invalid_type(%arg0 : i32) {
- // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%0 = spirv.GL.Fract %arg0 : i32
return
}
@@ -840,7 +840,7 @@ func.func @log2vec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @log2_invalid_type(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
%0 = spirv.GL.Log2 %arg0 : i32
return
}
@@ -866,7 +866,7 @@ func.func @tanhvec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @tanh_invalid_type(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
%0 = spirv.GL.Tanh %arg0 : i32
return
}
@@ -892,7 +892,7 @@ func.func @exp2vec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @exp2_invalid_type(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
%0 = spirv.GL.Exp2 %arg0 : i32
return
}
@@ -1022,7 +1022,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
// -----
func.func @length_i32_in(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
%0 = spirv.GL.Length %arg0 : i32 -> f32
return
}
@@ -1038,7 +1038,7 @@ func.func @length_f16_in(%arg0 : f16) -> () {
// -----
func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
%0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index d9957ad8..d7a4a6d 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -49,7 +49,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> )
// -----
func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
- // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+ // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}}
%0 = spirv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32>
return %0: f32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index d3aaef7..320a8fa 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -349,7 +349,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr
// -----
func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () {
- // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
+ // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
%0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d01..2e2fb1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
spirv.Return
}
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
// -----
spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
spirv.Return
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
// -----
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
- // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
spirv.Return
}
@@ -65,7 +65,7 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
// -----
spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}
@@ -73,6 +73,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
// -----
//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+ spirv.Return
+}
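+// Note: TF32 keeps binary32's 8-bit exponent but only 10 explicit mantissa
+// bits; the rounded result is still typed as f32, and (as checked above)
+// operand and result vector shapes must match.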
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 61a35b7..491c7a7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -583,7 +583,7 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
// These binary arithmetic instructions do not support coop matrix operands.
spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}
@@ -591,14 +591,14 @@ spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.c
// -----
spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
@@ -606,7 +606,7 @@ spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c
// -----
spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
@@ -614,7 +614,7 @@ spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c
// -----
spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
%p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d6c3464..d7f4ed0 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto
// -----
//===----------------------------------------------------------------------===//
+// spirv.IsFinite
+//===----------------------------------------------------------------------===//
+
+func.func @isfinite_scalar(%arg0: f32) -> i1 {
+ // CHECK: spirv.IsFinite {{.*}} : f32
+ %0 = spirv.IsFinite %arg0 : f32
+ return %0 : i1
+}
+
+func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+ // CHECK: spirv.IsFinite {{.*}} : vector<2xf32>
+ %0 = spirv.IsFinite %arg0 : vector<2xf32>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.IsInf
//===----------------------------------------------------------------------===//
@@ -166,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)
func.func @logicalUnary(%arg0 : i32)
{
- // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
%0 = spirv.LogicalNot %arg0 : i32
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 7ab94f1..bdb2abd 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
// -----
func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
// -----
func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 8f021ed..6aaaa60 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {
// -----
func.func @exp(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%2 = spirv.CL.exp %arg0 : i32
return
}
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
// -----
func.func @exp(%arg0 : vector<5xf32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}}
%2 = spirv.CL.exp %arg0 : vector<5xf32>
return
}
@@ -75,7 +75,7 @@ func.func @fabsf64(%arg0 : f64) -> () {
// -----
func.func @fabs(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
%2 = spirv.CL.fabs %arg0 : i32
return
}
@@ -83,7 +83,7 @@ func.func @fabs(%arg0 : i32) -> () {
// -----
func.func @fabs(%arg0 : vector<5xf32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}}
%2 = spirv.CL.fabs %arg0 : vector<5xf32>
return
}
@@ -137,7 +137,7 @@ func.func @sabsi8(%arg0 : i8) -> () {
// -----
func.func @sabs(%arg0 : f32) -> () {
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}}
%2 = spirv.CL.s_abs %arg0 : f32
return
}
@@ -145,7 +145,7 @@ func.func @sabs(%arg0 : f32) -> () {
// -----
func.func @sabs(%arg0 : vector<5xi32>) -> () {
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
%2 = spirv.CL.s_abs %arg0 : vector<5xi32>
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a654..6d321af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----
// expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07..f3a3218 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
} // end spirv.module
+
+// -----
+
+module {
+ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} {
+ // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ // CHECK: spirv.func @read_image
+ spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ %cst0_i32 = spirv.Constant 0 : i32
+ // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]]
+ %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+ %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32>
+ %cst0_i32_0 = spirv.Constant 0 : i32
+ %cst0_i32_1 = spirv.Constant 0 : i32
+ %cst1_i32 = spirv.Constant 1 : i32
+ %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ spirv.Store "StorageBuffer" %4, %3 : f32
+ spirv.Return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b23766..8d7f3da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes {
// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
// implicitly by v1.5.
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>>
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2..6398161 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if profiles: [ ]
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6b55442..5150ee3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// -----
+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+ return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+ return %0 : tensor<?xi1>
+}
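+// With min_val = false and max_val = true the clamp covers the entire i1
+// range, so both tests above fold to the identity regardless of whether the
+// shape is static or dynamic.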
+
+// -----
+
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000..06312c7
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
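+
+// -----
+
+// A hypothetical, illustrative sketch (added for exposition, not exercised by
+// the RUN line): the custom assembly above is sugar over the generic op form,
+// where the condition is operand #0 followed by the input list, and the two
+// regions are the `then` and `else` graphs.
+func.func @condif_generic_form_sketch(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }, {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}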
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 8739f97..e23ce430 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"
// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index eb25011..fad1bec 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 716362e..3bccb32 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
// CHECK-LABEL: test_mul_non_scalar_shift_2d
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
- // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_scalar_shift_1d
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
- // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+ // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+ return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+ %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+ return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33..3154f54 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
// -----
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26..0184d2b 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
// -----
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
- return %0 : tensor<1x1x1x1x13x21x3xf32>
+ %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+ return %0 : tensor<1x1x1x1x13x21x3xi32>
}
// -----
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
// -----
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %1 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+ %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+ %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
- %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e..30361a8 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8..e957bdd 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
- // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+ // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9d43f89..7b8fc24 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1:
// -----
+// CHECK-LABEL: @test_unranked_scalar_i8_tensor
+func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> {
+ // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8>
+ %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8>
+ // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
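+// Shape inference refines the unranked cast result to tensor<1xi8>, which
+// still satisfies tosa.mul's unranked-or-scalar constraint on the shift
+// operand (see the matching error tests in invalid.mlir).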
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
@@ -1166,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
tosa.yield %a : tensor<f32>
} else {
tosa.yield %b : tensor<f32>
@@ -1180,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<?xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+ // CHECK: -> tensor<?xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
tosa.yield %arg0 : tensor<2xf32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1194,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<*xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+ // CHECK: -> tensor<*xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1208,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b305236..2a937b0 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
// -----
+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso
// -----
+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected ':'}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected non-function type}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
// expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30..f86fb38 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_scalar_from_splat
+// CHECK-LABEL: fold_extract_splatlike
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
- %b = vector.splat %a : vector<1x2x4xf32>
+func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+ %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>) {
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -2033,11 +2063,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve
// -----
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
+// CHECK-LABEL: extract_strided_splatlike
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f16 to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
+func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
@@ -2323,14 +2353,14 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
// -----
-// CHECK-LABEL: func @splat_fold
-func.func @splat_fold() -> vector<4xf32> {
+// CHECK-LABEL: func @splatlike_fold
+// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+// CHECK-NEXT: return [[V]] : vector<4xf32>
+func.func @splatlike_fold() -> vector<4xf32> {
%c = arith.constant 1.0 : f32
- %v = vector.splat %c : vector<4xf32>
+ %v = vector.broadcast %c : f32 to vector<4xf32>
return %v : vector<4xf32>
- // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
- // CHECK-NEXT: return [[V]] : vector<4xf32>
}
// -----
@@ -2469,10 +2499,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
// -----
-// CHECK-LABEL: func @transpose_splat_constant
+// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
-func.func @transpose_splat_constant() -> vector<8x4xf32> {
+func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
%cst = arith.constant dense<5.0> : vector<4x8xf32>
%0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
@@ -2480,13 +2510,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
// -----
-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
-// CHECK: return %[[VAL_1]] : vector<3x4xf32>
-// CHECK: }
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
- %splat = vector.splat %arg : vector<4x3xf32>
+// CHECK-LABEL: func @transpose_splatlike2(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x4xf32>
+// CHECK: }
+func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
+ %splat = vector.broadcast %arg : f32 to vector<4x3xf32>
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
@@ -2562,118 +2592,6 @@ func.func @insert_2d_splat_constant()
// -----
-// CHECK-LABEL: func @insert_element_fold
-// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
-// CHECK: return %[[V]]
-func.func @insert_element_fold() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_element_invalid_fold
-func.func @insert_element_invalid_fold() -> vector<1xf32> {
- // Out-of-bound index here.
- %c26 = arith.constant 26 : index
- %cst_2 = arith.constant 1.60215309E+9 : f32
- %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32>
-// CHECK: vector.insertelement
- %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32>
- return %46 : vector<1xf32>
-}
-
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold1
-// CHECK: vector.insertelement
-func.func @insert_poison_fold1() -> vector<4xi32> {
- %v = ub.poison : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold2
-// CHECK: vector.insertelement
-func.func @insert_poison_fold2() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = ub.poison : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold3
-// CHECK: vector.insertelement
-func.func @insert_poison_fold3() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = ub.poison : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_element_fold
-// CHECK: %[[C:.+]] = arith.constant 5 : i32
-// CHECK: return %[[C]]
-func.func @extract_element_fold() -> i32 {
- %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// CHECK-LABEL: func @extract_element_splat_fold
-// CHECK-SAME: (%[[ARG:.+]]: i32)
-// CHECK: return %[[ARG]]
-func.func @extract_element_splat_fold(%a : i32) -> i32 {
- %v = vector.splat %a : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @extract_element_poison_fold1
-// CHECK: vector.extractelement
-func.func @extract_element_poison_fold1() -> i32 {
- %v = ub.poison : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @extract_element_poison_fold2
-// CHECK: vector.extractelement
-func.func @extract_element_poison_fold2() -> i32 {
- %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
- %i = ub.poison : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
// CHECK-LABEL: func @reduce_one_element_vector_extract
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2781,13 +2699,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
// -----
-// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-LABEL: @insert_strided_slice_splatlike
// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
- %splat0 = vector.splat %x : vector<4x4xf32>
- %splat1 = vector.splat %x : vector<8x16xf32>
+func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) {
+ %splat0 = vector.broadcast %x : f32 to vector<4x4xf32>
+ %splat1 = vector.broadcast %x : f32 to vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
@@ -2860,13 +2778,13 @@ func.func @insert_strided_2d_constant() ->
// -----
-// CHECK-LABEL: func @shuffle_splat
+// CHECK-LABEL: func @shuffle_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
- %v0 = vector.splat %x : vector<4xi32>
- %v1 = vector.splat %x : vector<2xi32>
+func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2xi32>
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}
@@ -2874,13 +2792,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
// -----
-// CHECK-LABEL: func @insert_splat
+// CHECK-LABEL: func @insert_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
- %v0 = vector.splat %x : vector<4x3xi32>
- %v1 = vector.splat %x : vector<2x4x3xi32>
+func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4x3xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32>
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
return %insert : vector<2x4x3xi32>
}
@@ -2933,18 +2851,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
// -----
-// CHECK-LABEL: func.func @fold_extractelement_of_broadcast(
-// CHECK-SAME: %[[f:.*]]: f32
-// CHECK: return %[[f]]
-func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
- %0 = vector.broadcast %f : f32 to vector<15xf32>
- %c5 = arith.constant 5 : index
- %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
- return %1 : f32
-}
-
-// -----
-
// CHECK-LABEL: func.func @fold_0d_vector_reduction
func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
// CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
@@ -3124,11 +3030,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
// -----
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
-// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
- // Splat scalar to 0D and extract scalar.
- %0 = vector.splat %a : vector<f32>
+// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
+// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: vector<f32>, %[[C:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat/broadcast scalar to 0D and extract scalar.
+ %0 = vector.broadcast %a : f32 to vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
// Broadcast scalar to 0D and extract scalar.
@@ -3136,12 +3042,12 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%3 = vector.extract %2[] : f32 from vector<f32>
// Broadcast 0D to 3D and extract scalar.
- // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector<f32>
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
- // Splat scalar to 2D and extract scalar.
- %6 = vector.splat %a : vector<2x3xf32>
+ // Splat/broadcast scalar to 2D and extract scalar.
+ %6 = vector.broadcast %a : f32 to vector<2x3xf32>
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
// Broadcast scalar to 3D and extract scalar.
@@ -3149,14 +3055,14 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
// Extract 2D from 3D that was broadcasted from a scalar.
- // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
// Extract 1D from 2D that was splat'ed from a scalar.
- // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
- // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+ // CHECK: return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
}
@@ -3598,7 +3504,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
%v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
%v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
%v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
- return %v_2 : vector<4x4xf32>
+ return %v_2 : vector<4x4xf32>
}
// -----
@@ -3612,5 +3518,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
%v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
- return %v_1 : vector<4xf32>
+ return %v_1 : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8..f43328f 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
- // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
+ // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
- // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32>
+ // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[SPLAT1]], %[[SPLAT2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
@@ -63,11 +63,11 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>,
// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+ // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
- // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32>
+ // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
@@ -170,7 +170,7 @@ func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> v
// Could match, but handled by `rewriteFromElementsAsSplat`.
// CHECK-LABEL: func @extract_single_elm(
// CHECK-NEXT: vector.extract
-// CHECK-NEXT: vector.splat
+// CHECK-NEXT: vector.broadcast
// CHECK-NEXT: return
func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
new file mode 100644
index 0000000..e4a9391
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file should be removed when vector.splat is removed.
+// This file tests canonicalization/folding with vector.splat.
+// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir.
+
+
+// CHECK-LABEL: fold_extract_splat
+// CHECK-SAME: %[[A:.*]]: f32
+// CHECK: return %[[A]] : f32
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+ %b = vector.splat %a : vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
+ return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_splat
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f16 to vector<2x4xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
+func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.splat %arg0 : vector<16x4xf16>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
+ vector<16x4xf16> to vector<2x4xf16>
+ return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+// CHECK-NEXT: return [[V]] : vector<4xf32>
+func.func @splat_fold() -> vector<4xf32> {
+ %c = arith.constant 1.0 : f32
+ %v = vector.splat %c : vector<4xf32>
+ return %v : vector<4xf32>
+
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_splat2(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x4xf32>
+func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
+ %splat = vector.splat %arg : vector<4x3xf32>
+ %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+ %splat0 = vector.splat %x : vector<4x4xf32>
+ %splat1 = vector.splat %x : vector<8x16xf32>
+ %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+ : vector<4x4xf32> into vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
+func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
+ %v0 = vector.splat %x : vector<4xi32>
+ %v1 = vector.splat %x : vector<2xi32>
+ %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+ return %shuffle : vector<4xi32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.splat %x : vector<4x3xi32>
+ %v1 = vector.splat %x : vector<2x4x3xi32>
+ %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+ return %insert : vector<2x4x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
+// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat scalar to 0D and extract scalar.
+ %0 = vector.splat %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Splat scalar to 0D again and extract scalar.
+ %2 = vector.splat %a : vector<f32>
+ %3 = vector.extract %2[] : f32 from vector<f32>
+
+ // Splat scalar to 2D and extract scalar.
+ %6 = vector.splat %a : vector<2x3xf32>
+ %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+  // Splat scalar to 3D and extract scalar.
+ %8 = vector.splat %a : vector<5x6x7xf32>
+ %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+  // Extract 2D from 3D that was splat'ed from a scalar.
+ // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
+ %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+ // Extract 1D from 2D that was splat'ed from a scalar.
+ // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
+ %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+ // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
+ return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 0263193..b2f16bb 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> {
func.return %2 : vector<4x4xindex>
}
+// CHECK-LABEL: func @vector_transpose
+// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index}
+func.func @vector_transpose() -> vector<2x4xindex> {
+ %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex>
+ %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex>
+ %2 = test.reflect_bounds %1 : vector<2x4xindex>
+ func.return %2 : vector<2x4xindex>
+}
+
// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
@@ -60,16 +69,6 @@ func.func @vector_extract() -> index {
func.return %2 : index
}
-// CHECK-LABEL: func @vector_extractelement
-// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
-func.func @vector_extractelement() -> index {
- %c0 = arith.constant 0 : index
- %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
- %1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
- %2 = test.reflect_bounds %1 : index
- func.return %2 : index
-}
-
// CHECK-LABEL: func @vector_add
// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
func.func @vector_add() -> vector<4xindex> {
@@ -90,17 +89,6 @@ func.func @vector_insert() -> vector<4xindex> {
func.return %3 : vector<4xindex>
}
-// CHECK-LABEL: func @vector_insertelement
-// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
-func.func @vector_insertelement() -> vector<4xindex> {
- %c0 = arith.constant 0 : index
- %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
- %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
- %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
- %3 = test.reflect_bounds %2 : vector<4xindex>
- func.return %3 : vector<4xindex>
-}
-
// CHECK-LABEL: func @test_loaded_vector_extract
// No bounds
// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
@@ -120,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
%2 = test.reflect_bounds %1 : vector<2xi32>
func.return %2 : vector<2xi32>
}
+
+// CHECK-LABEL: func @vector_step
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @vector_step() -> vector<8xindex> {
+ %0 = vector.step : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ca837d3..c21de56 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// -----
-func.func @extract_element(%arg0: vector<f32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position to be empty with 0-D vector}}
- %1 = vector.extractelement %arg0[%c : i32] : vector<f32>
-}
-
-// -----
-
-func.func @extract_element(%arg0: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position for 1-D vector}}
- %1 = vector.extractelement %arg0[] : vector<4xf32>
-}
-
-// -----
-
-func.func @extract_element(%arg0: vector<4x4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{unexpected >1 vector rank}}
- %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
-}
-
-// -----
-
func.func @extract_vector_type(%arg0: index) {
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}}
%1 = vector.extract %arg0[] : index from index
@@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
-func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position to be empty with 0-D vector}}
- %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
-}
-
-// -----
-
-func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position for 1-D vector}}
- %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
-}
-
-// -----
-
-func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{unexpected >1 vector rank}}
- %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
-}
-
-// -----
-
-func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}}
- %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>)
-}
-
-// -----
-
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6a56116..625ffc1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4
return %1 : vector<4xf32>
}
-// CHECK-LABEL: @extract_element_0d
-func.func @extract_element_0d(%a: vector<f32>) -> f32 {
- // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
- %1 = vector.extractelement %a[] : vector<f32>
- return %1 : f32
-}
-
-// CHECK-LABEL: @extract_element
-func.func @extract_element(%a: vector<16xf32>) -> f32 {
- // CHECK: %[[C15:.*]] = arith.constant 15 : i32
- %c = arith.constant 15 : i32
- // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32>
- %1 = vector.extractelement %a[%c : i32] : vector<16xf32>
- return %1 : f32
-}
-
// CHECK-LABEL: @extract_const_idx
func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
-> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
@@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
return %0 : f32
}
-// CHECK-LABEL: @insert_element_0d
-func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
- // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
- %1 = vector.insertelement %a, %b[] : vector<f32>
- return %1 : vector<f32>
-}
-
-// CHECK-LABEL: @insert_element
-func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
- // CHECK: %[[C15:.*]] = arith.constant 15 : i32
- %c = arith.constant 15 : i32
- // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32>
- %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32>
- return %1 : vector<16xf32>
-}
-
// CHECK-LABEL: @insert_const_idx
func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
%res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a5..d5e3443 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955..5a8125e 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdc..ef881ba 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
// -----
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result for arith.cmpf have different types - the broadcast is sunk past the compare.
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
@@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
return %1 : vector<1xf32>
}
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.subi %cst, %0 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+ %2 = arith.mulf %0, %cst : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+ %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+ %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+ %cst = arith.constant dense<3> : vector<1x4xi32>
+ %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+ return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<3> : vector<3x4xi32>
+ %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+ return %2 : vector<3x4xf32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 511ab70..1b54d54 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -284,19 +284,19 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
-// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
+// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
%mask0 = vector.splat %m : vector<14x7xi1>
%0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
-// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1>
+// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1>
%mask1 = vector.splat %m : vector<16x14xi1>
%1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
-// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
+// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
%mask2 = vector.splat %m : vector<14x7xi1>
%2 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
@@ -336,7 +336,7 @@ func.func @transfer_write_permutations_tensor_masked(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+ // CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1>
%mask0 = vector.splat %m : vector<16x14x7x8xi1>
%res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfe..dff3ffa 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) {
}
// -----
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
+ %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<8xi1>
+ // expected-error@+1 {{Mask should match value except the chunk size dim}}
+ %2 = xegpu.load %src[%offsets], %mask
+ : memref<?xf16>, vector<4xindex>, vector<8xi1>
+ -> vector<4x2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{value elements must match chunk size}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_2(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
+ return
+}
+
+// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = arith.constant dense<2.9>: vector<4x2xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969a..6be2371 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4x2xf16>
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ gpu.return
+}
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb4..628a485 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -2,122 +2,117 @@
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
+ // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  // CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @update_nd(%src: memref<24x32xf32>){
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
- // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
- gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+ gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+ -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+ -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x256xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: broadcast
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
+ gpu.func @broadcast(%src: memref<128x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+ -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK-COUNT-3: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<128x1xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<128x1xf32> to vector<128x64xf32>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d511224..d4b0037 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,201 +4,181 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
// CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C12:.*]] = arith.constant 12 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
// CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
- // CHECK: %[[C24:.*]] = arith.constant 24 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
- // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
+ // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+ // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
// CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
- // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-gpu.func @update_nd(%src: memref<24x32xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+gpu.func @update_nd(%src: memref<256x128xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: dpas_no_sg_data
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: xegpu.prefetch_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas_with_no_create_nd_desc
- gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
- // CHECK-NOT: vector<12x12xf32>
+ gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+ // CHECK-NOT: vector<32x32xf32>
%dpas = xegpu.dpas %a, %b
{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim1
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
- %broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<256x1xf32> to vector<256x32xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim0
- // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
- gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
- -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
- -> vector<1x32xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
- : vector<1x32xf32> to vector<12x32xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<1x128xf32> to vector<32x128xf32>
gpu.return
}