Diffstat (limited to 'mlir/test/Dialect')
54 files changed, 1397 insertions, 781 deletions
diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir new file mode 100644 index 0000000..1a74eaa --- /dev/null +++ b/mlir/test/Dialect/Async/canonicalize.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-NOT: async.execute + +func.func @empty_execute() { + %token = async.execute { + async.yield + } + return +} diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir new file mode 100644 index 0000000..e2ab876 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor<?xf32> + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor<?xf32> + // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32 + return %0, %1 : tensor<?xf32>, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @call_func_with_non_tensor_return( + %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + return %1, %0 : f32, tensor<?xf32> + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 162ff06..35381da 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a // ----- func.func @rotate_mismatching_type(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1) return } // ----- func.func @rotate_unsupported_type(%arg0 : index) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : index + %rotate, %valid = gpu.rotate %arg0, 4, 16 : index return } @@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) { %offset = arith.constant 4 : i32 %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}} - %rotate, %valid = gpu.rotate %arg0, 
%offset, %width : vector<[4]xf32> + %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32> return } // ----- func.func @rotate_unsupported_width(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 15 : i32 - // expected-error@+1 {{op width must be a power of two}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset(%arg0 : f32) { - %offset = arith.constant 16 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset_minus(%arg0 : f32) { - %offset = arith.constant -1 : i32 - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) { - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) { - %offset = arith.constant 0 : i32 - // expected-error@+1 {{op width is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1) return } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 2aef80f..ee1fdfa 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -140,9 +140,8 @@ module attributes {gpu.container_module} { // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32 %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32 - // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32 - %rotate_width = arith.constant 16 : i32 - %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32 + // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32 + %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32 "gpu.barrier"() : () -> () diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cbb56e4..5c5f7e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return 
%[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose @@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) { // CHECK-LABEL: @recursive_effect // CHECK: linalg.map +// ----- + //===----------------------------------------------------------------------===// // linalg.pack //===----------------------------------------------------------------------===// // CHECK-LABEL: func @fold_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32> %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 1.000000e-01 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst padding_value(%pad : f32) outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } - // ----- // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat // CHECK: arith.constant 
dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] - into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 6fc8d9f..cc26fa4 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate( // ----- -func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { - %empty = tensor.empty() : tensor<8x4x16x8xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> - %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> - return %pack : tensor<8x4x16x8xf32> -} -// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] -// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> -// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]] -// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> -// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> - -// ----- - func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { %6 = tensor.empty(%dim) : tensor<?x256xf32> %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index a00c798..5f42938 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // ----- +func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %cst : f32 + } : tensor<1x?xf32> to tensor<1x16xf32> + return %padded : tensor<1x16xf32> +} +// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : 
tensor<1x?xf32> into tensor<?xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] { +// CHECK: ^bb0(%[[IDX:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<?xf32> to tensor<16xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32> +// CHECK: return %[[EXPANDED]] : tensor<1x16xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +module { + func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> + %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<?x1x61x1xf32> + return %1 : tensor<?x1x61x1xf32> + } +} // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)> // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()> @@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32> // CHECK: } -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> -module { - func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> - %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.mulf %in, %in_0 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } -> tensor<?x1x61x1xf32> - return %1 : tensor<?x1x61x1xf32> - } -} - // ----- func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index da1dfc7..40bf4d1 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf } // ----- + func.func @pack_mismatch_inner_tile_size_and_output_shape( %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> { // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} @@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t // ----- +func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> { + %cst = arith.constant 0.0 : f32 + // expected-error@+1 {{expected 
'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %output + : tensor<9xf32> -> tensor<3x8xf32> + return %0 : tensor<3x8xf32> +} + +// ----- + // The outer dims in the output tensor are incorrectly/unexpectedly transposed. // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose). func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}} + // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}} %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32> return %0 : tensor<4x16x32x16xf32> } // ----- -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> +func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> { + // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}} + %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32> + return %0 : tensor<8x7x16x32xf32> +} + +// ----- + +func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> { + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output + : tensor<3x8xf32> -> tensor<9xf32> + return %0 : tensor<9xf32> } // ----- -func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. 
Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> +func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> } diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 81fd7a8..9e7681d 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func.func @pack_with_pad( -func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>) - -> tensor<265x16x16x1xf32> { +func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>) + -> tensor<265x12x16x1xf32> { // CHECK: tensor.pad {{.*}} low[0, 0] - // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32> + // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32> // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]] - // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32> + // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32> // CHECK: linalg.transpose - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>) // CHECK-SAME: permutation = [0, 2, 1, 3] %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dest - : tensor<4225x12xf32> -> tensor<265x16x16x1xf32> - return %0 : tensor<265x16x16x1xf32> + : tensor<4225x12xf32> -> tensor<265x12x16x1xf32> + return %0 : tensor<265x12x16x1xf32> } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 78619b6..981f5dc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -52,22 +52,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. 
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0] // CHECK: : tensor<7x5xf32> to tensor<9x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] { - // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -83,7 +83,7 @@ module { // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: pad_conv +func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> + +// CHECK-LABEL: pad_conv_dynamic +func.func 
@pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> { + + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32> + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12] + // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0] + // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> + return %0 : tensor<1x14x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_strided +func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12] + // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : 
f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_dilated +func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index 26c03ed..f741876 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -69,22 +69,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. 
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0] // CHECK: : tensor<7x5xf32> to tensor<8x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] { - // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -102,7 +102,7 @@ module { // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -127,13 +127,13 @@ module { // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32> // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]] // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] { - // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32> // // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32> // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]] - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) { - // CHECK: } -> tensor<8x14x13xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32> + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) { + // CHECK: } -> tensor<8x14x12xf32> + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32> // %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir index c3ee892..d7722ea 100644 --- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[PV:.*]] = ub.poison : i32 -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex> // CHECK-DAG: %[[CST_1:.*]] = arith.constant 
dense<true> : vector<4x7x3x2xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> // CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> -// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex> +// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> -// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> -// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], 
%[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // ----- @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) // CHECK-LABEL: func.func @index_from_output_column_vector_gather_load( // CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> { -// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex> -// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex> -// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // CHECK: return %[[RES]] : tensor<8x1xf32> @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex> // CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex> // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex> -// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex> // CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, 
tensor<1x4xf32> // CHECK: return %[[VAL_14]] : tensor<1x4xf32> @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather( // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex> -// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> +// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_10]] : tensor<1x4xf32> // CHECK: } @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> // CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> -// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> -// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index +// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex> // CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 98e8f50..d41d861 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32> func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> { // CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : 
tensor<?x?xf32> -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32> -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 // CHECK: %[[C01:.*]] = arith.constant 0 // CHECK: %[[C02:.*]] = arith.constant 0 -// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32> -// CHECK: %[[CNST14:.*]] = arith.constant 1 -// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32> +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32> // CHECK: %[[CNST16:.*]] = arith.constant 16 : index // CHECK: %[[CNST2:.*]] = arith.constant 2 : index -// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> +// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4c50ed3..8c846cd 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, // CHECK-NEXT: (%[[XVAL:.*]]: i1): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1 // CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1) - // } + // CHECK-NEXT: } omp.atomic.update %xBool : memref<i1> { ^bb0(%xval: i1): %newval = llvm.icmp "eq" %xval, %exprBool : i1 @@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, omp.yield(%newval : i32) } + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%{{.+}} : i32) + // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%const:i32) + } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} return } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 12d30e17..308cf150 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() { // ----- -// CHECK-LABEL: func @execute_region_elim -func.func @execute_region_elim() { +// CHECK-LABEL: func @execute_region_inline +func.func @execute_region_inline() { affine.for %i = 0 to 100 { "test.foo"() : () -> () %v = scf.execute_region -> i64 { @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim -func.func @func_execute_region_elim() { +// CHECK-LABEL: func @execute_region_no_inline +func.func @execute_region_no_inline() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 no_inline { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + 
"test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: scf.execute_region +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: scf.yield %[[VAL]] : i64 +// CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @func_execute_region_inline +func.func @func_execute_region_inline() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim_multi_yield -func.func @func_execute_region_elim_multi_yield() { +// CHECK-LABEL: func @func_execute_region_inline_multi_yield +func.func @func_execute_region_inline_multi_yield() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index 3adafc1..c703274 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -13,7 +13,7 @@ func.func @fadd_scalar(%arg: f32) -> f32 { // ----- func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FAdd %arg, %arg : bf16 return %0 : bf16 } @@ -33,7 +33,7 @@ func.func @fdiv_scalar(%arg: f32) -> f32 { // ----- func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FDiv %arg, %arg : bf16 return %0 : bf16 } @@ -53,7 +53,7 @@ func.func @fmod_scalar(%arg: f32) -> f32 { // ----- func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMod %arg, %arg : bf16 return %0 : bf16 } @@ -79,7 +79,7 @@ func.func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func.func @fmul_i32(%arg: i32) -> i32 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : i32 return %0 : i32 } @@ -87,7 +87,7 @@ func.func @fmul_i32(%arg: i32) -> i32 { // ----- func.func @fmul_bf16(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : bf16 return %0 : bf16 } @@ -95,7 +95,7 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 { // ----- func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : vector<4xbf16> return %0 : vector<4xbf16> } @@ -103,7 +103,7 @@ func.func 
@fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {

// -----

func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.FMul %arg, %arg : tensor<4xf32>
  return %0 : tensor<4xf32>
}
@@ -123,7 +123,7 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {

// -----

func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
-  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.FNegate %arg : bf16
  return %0 : bf16
}
@@ -143,7 +143,7 @@ func.func @frem_scalar(%arg: f32) -> f32 {

// -----

func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
-  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.FRem %arg, %arg : bf16
  return %0 : bf16
}
@@ -163,7 +163,7 @@ func.func @fsub_scalar(%arg: f32) -> f32 {

// -----

func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
-  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.FSub %arg, %arg : bf16
  return %0 : bf16
}
@@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {

// -----

func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
  %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
  return %0 : i32
}
@@ -558,7 +558,7 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3

// -----

func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
-  // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
  %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
  return %0 : vector<4xbf16>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index f3f0ebf..4bdac19 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {

// -----

func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %0 = spirv.BitwiseOr %arg0, %arg1 : f16
  return %0 : f16
}
@@ -165,7 +165,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {

// -----

func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %0 = spirv.BitwiseXor %arg0, %arg1 : f16
  return %0 : f16
}
@@ -274,7 +274,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {

// -----

func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %0 = spirv.BitwiseAnd %arg0, %arg1 : f16
  return %0 : f16
}
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 5c5d94c..fd8a2ff 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @exp(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
  %2 = spirv.GL.Exp %arg0 : i32
  return
}
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {

// -----

func.func @exp(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}}
  %2 = spirv.GL.Exp %arg0 : vector<5xf32>
  return
}
@@ -51,7 +51,7 @@ func.func @exp(%arg0 : i32) -> () {

// -----

func.func @exp_bf16(%arg0 : bf16) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}}
  %2 = spirv.GL.Exp %arg0 : bf16
  return
}
@@ -101,7 +101,7 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {

// -----

func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
-  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
  %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
  return
@@ -499,7 +499,7 @@ func.func @frexp_struct_mismatch_type(%arg0 : f32) -> () {

// -----

func.func @frexp_struct_wrong_type(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %2 = spirv.GL.FrexpStruct %arg0 : i32 -> !spirv.struct<(i32, i32)>
  return
}
@@ -614,7 +614,7 @@ func.func @findimsb_vector_i64(%arg0 : vector<3xi64>) -> () {

// -----

func.func @findimsb_error_scalar_float(%arg0 : f32) -> () {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}}
  %2 = spirv.GL.FindILsb %arg0 : f32
  return
}
@@ -640,7 +640,7 @@ func.func @findsmsb_vector(%arg0 : vector<3xi32>) -> () {

// -----

func.func @findsmsb_error_scalar_i64(%arg0 : i64) -> () {
-  // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+  // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}}
  %2 = spirv.GL.FindSMsb %arg0 : i64
  return
}
@@ -666,7 +666,7 @@ func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {

// -----

func.func @findumsb(%arg0 : i64) -> () {
-  // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+  // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}}
  %2 = spirv.GL.FindUMsb %arg0 : i64
  return
}
@@ -692,7 +692,7 @@ func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {

// -----

func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
-  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
  %0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
  return
}
@@ -708,7 +708,7 @@ func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) {

// -----

func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
-  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
  %0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
  return
}
@@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {

// -----

func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
-  // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+  // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
  %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
  return
}
@@ -762,7 +762,7 @@ func.func @normalize_vector(%arg0 : vector<3xf32>) {

// -----

func.func @normalize_invalid_type(%arg0 : i32) {
-  // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.GL.Normalize %arg0 : i32
  return
}
@@ -788,7 +788,7 @@ func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {

// -----

func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
-  // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.GL.Reflect %arg0, %arg1 : i32
  return
}
@@ -814,7 +814,7 @@ func.func @fractvec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @fract_invalid_type(%arg0 : i32) {
-  // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %0 = spirv.GL.Fract %arg0 : i32
  return
}
@@ -840,7 +840,7 @@ func.func @log2vec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @log2_invalid_type(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
  %0 = spirv.GL.Log2 %arg0 : i32
  return
}
@@ -866,7 +866,7 @@ func.func @tanhvec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @tanh_invalid_type(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
  %0 = spirv.GL.Tanh %arg0 : i32
  return
}
@@ -892,7 +892,7 @@ func.func @exp2vec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @exp2_invalid_type(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}}
  %0 = spirv.GL.Exp2 %arg0 : i32
  return
}
@@ -1022,7 +1022,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () {

// -----

func.func @length_i32_in(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
  %0 = spirv.GL.Length %arg0 : i32 -> f32
  return
}
@@ -1038,7 +1038,7 @@ func.func @length_f16_in(%arg0 : f16) -> () {

// -----

func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
  %0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
  return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index d9957ad8..d7a4a6d 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -49,7 +49,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> )

// -----

func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
-  // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}}
  %0 = spirv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32>
  return %0: f32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index d3aaef7..320a8fa 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -349,7 +349,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr

// -----

func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () {
-  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
  %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32>
  spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d01..2e2fb1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {

// -----

spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
-  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
  spirv.Return
}
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {

// -----

spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
-  // expected-error @+1 {{operand and result must have same number of elements}}
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
  spirv.Return
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {

// -----

spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
-  // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
  spirv.Return
}
@@ -65,7 +65,7 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {

// -----

spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
-  // expected-error @+1 {{operand and result must have same number of elements}}
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
  spirv.Return
}
@@ -73,6 +73,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {

// -----

//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
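[Editorial sketch, not part of the diff.] Based solely on the tests above, spirv.INTEL.RoundFToTF32 rounds Float32 values (scalars or fixed-length vectors of length 2/3/4/8/16) to TF32 precision while keeping the f32 storage type, and operand and result shapes must match. A minimal hedged usage sketch; the function and value names are illustrative:

// Illustrative only; assumes the assembly format shown in the tests above.
spirv.func @round_to_tf32_sketch(%x : f32, %v : vector<4xf32>) "None" {
  %0 = spirv.INTEL.RoundFToTF32 %x : f32 to f32
  %1 = spirv.INTEL.RoundFToTF32 %v : vector<4xf32> to vector<4xf32>
  spirv.Return
}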
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 61a35b7..491c7a7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -583,7 +583,7 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
// These binary arithmetic instructions do not support coop matrix operands.

spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
  %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
  spirv.Return
}
@@ -591,14 +591,14 @@ spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.c
// -----

spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}}
  %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
  spirv.Return
}

// -----

spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
  %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
  spirv.Return
}
@@ -606,7 +606,7 @@ spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c

// -----

spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
  %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
  spirv.Return
}
@@ -614,7 +614,7 @@ spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c

// -----

spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
  %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
  spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d6c3464..d7f4ed0 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto

// -----

//===----------------------------------------------------------------------===//
+// spirv.IsFinite
+//===----------------------------------------------------------------------===//
+
+func.func @isfinite_scalar(%arg0: f32) -> i1 {
+  // CHECK: spirv.IsFinite {{.*}} : f32
+  %0 = spirv.IsFinite %arg0 : f32
+  return %0 : i1
+}
+
+func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spirv.IsFinite {{.*}} : vector<2xf32>
+  %0 = spirv.IsFinite %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.IsInf
//===----------------------------------------------------------------------===//
@@ -166,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)

func.func @logicalUnary(%arg0 : i32) {
-  // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
  %0 = spirv.LogicalNot %arg0 : i32
  return
}
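[Editorial sketch, not part of the diff.] For context on the spirv.IsFinite tests added above: the op classifies each element of a float scalar or fixed-length float vector and yields i1 results of the same shape. A minimal hedged sketch; the names are illustrative, not from the diff:

func.func @isfinite_sketch(%x: f32, %v: vector<3xf32>) -> (i1, vector<3xi1>) {
  // Scalar float -> i1.
  %0 = spirv.IsFinite %x : f32
  // Element-wise on a fixed-length vector -> vector<3xi1>.
  %1 = spirv.IsFinite %v : vector<3xf32>
  return %0, %1 : i1, vector<3xi1>
}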
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 7ab94f1..bdb2abd 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto

// -----

func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
  %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
  return %0: bf16
}
@@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {

// -----

func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
  %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
  return %0: bf16
}
@@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {

// -----

func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
  %0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
  return %0: i1
}
@@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {

// -----

func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
  %0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
  return %0: i1
}
@@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {

// -----

func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
  %0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
  return %0: i1
}
@@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {

// -----

func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
  %0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
  return %0: i32
}
@@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {

// -----

func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
  %0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
  return %0: i32
}
@@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {

// -----

func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
  %0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
  return %0: i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 8f021ed..6aaaa60 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {

// -----

func.func @exp(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %2 = spirv.CL.exp %arg0 : i32
  return
}
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {

// -----

func.func @exp(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}}
  %2 = spirv.CL.exp %arg0 : vector<5xf32>
  return
}
@@ -75,7 +75,7 @@ func.func @fabsf64(%arg0 : f64) -> () {

// -----

func.func @fabs(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}}
  %2 = spirv.CL.fabs %arg0 : i32
  return
}
@@ -83,7 +83,7 @@ func.func @fabs(%arg0 : i32) -> () {

// -----

func.func @fabs(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}}
  %2 = spirv.CL.fabs %arg0 : vector<5xf32>
  return
}
@@ -137,7 +137,7 @@ func.func @sabsi8(%arg0 : i8) -> () {

// -----

func.func @sabs(%arg0 : f32) -> () {
-  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}}
  %2 = spirv.CL.s_abs %arg0 : f32
  return
}
@@ -145,7 +145,7 @@ func.func @sabs(%arg0 : f32) -> () {

// -----

func.func @sabs(%arg0 : vector<5xi32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %2 = spirv.CL.s_abs %arg0 : vector<5xi32>
  return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a654..6d321af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)

+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----

// expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07..f3a3218 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s
  // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
  // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
} // end spirv.module
+
+// -----
+
+module {
+  spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} {
+    // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: spirv.func @read_image
+    spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+      %cst0_i32 = spirv.Constant 0 : i32
+      // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]]
+      %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+      %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+      %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+      %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32>
+      %cst0_i32_0 = spirv.Constant 0 : i32
+      %cst0_i32_1 = spirv.Constant 0 : i32
+      %cst1_i32 = spirv.Constant 1 : i32
+      %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+      spirv.Store "StorageBuffer" %4, %3 : f32
+      spirv.Return
+    }
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b23766..8d7f3da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes {

// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
// implicitly by v1.5.
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>>
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2..6398161 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {

func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // CHECK: tosa.cond_if profiles: [ ]
  // CHECK: tosa.cond_if extensions: [ [controlflow] ]
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6b55442..5150ee3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {

// -----

+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+  return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+  return %0 : tensor<?xi1>
+}
+
+// -----
+
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
  // CHECK: return %arg0
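[Editorial sketch, not part of the diff.] The two canonicalize.mlir tests added above encode a new fold: a boolean tosa.clamp spanning the full [false, true] range cannot change any element, so the canonicalizer drops it. A hedged before/after sketch with an illustrative input name:

// Before -canonicalize (illustrative):
%0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
// After -canonicalize the clamp folds away and %arg0 is returned directly,
// which is what the CHECK/CHECK-NOT lines above verify.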
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000..06312c7
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  // CHECK: } else {
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  // CHECK: } else {
+  // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
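[Editorial sketch, not part of the diff.] The new controlflow.mlir file above documents the two accepted spellings of tosa.cond_if after this change: the condition type is written explicitly after ':', and operands may optionally be bound to explicit block arguments. A hedged sketch combining both forms, using only syntax that appears in the tests (function and value names are illustrative):

func.func @cond_if_syntax_sketch(%c: tensor<i1>, %a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
  // Simple form: condition type, then result type.
  %r = tosa.cond_if %c : tensor<i1> -> tensor<f32> {
    tosa.yield %a : tensor<f32>
  } else {
    tosa.yield %b : tensor<f32>
  }
  // Block-argument form: operands are bound to region arguments inline.
  %s = tosa.cond_if %c (%x = %a, %y = %r) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
  ^bb0(%x: tensor<f32>, %y: tensor<f32>):
    %t = tosa.add %x, %y : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %t : tensor<f32>
  } else {
  ^bb0(%x: tensor<f32>, %y: tensor<f32>):
    %t = tosa.sub %x, %y : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %t : tensor<f32>
  }
  return %s : tensor<f32>
}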
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 8739f97..e23ce430 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------

-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"

// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index eb25011..fad1bec 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:

func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
    tosa.yield %arg0 : tensor<f32>
  } else {
    tosa.yield %arg1 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 716362e..3bccb32 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
// CHECK-LABEL: test_mul_non_scalar_shift_2d
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
-  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}
@@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_scalar_shift_1d
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
-  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
  return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+  // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+  return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+  %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+  return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33..3154f54 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32

// -----

func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26..0184d2b 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"

func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
  // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens

// -----

-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
  // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
}
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:

// -----

func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
-      %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
-        %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
-          %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+    %1 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+      %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+        %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+          %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
            // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
-            %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+            %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
              %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
              tosa.yield %res : tensor<f32>
            } else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e..30361a8 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {

// -----

// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8..e957bdd 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
-  // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
    // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9d43f89..7b8fc24 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1:

// -----

+// CHECK-LABEL: @test_unranked_scalar_i8_tensor
+func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> {
+  // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8>
+  %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8>
+  // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
  // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
@@ -1166,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
  %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>

  // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    tosa.yield %a : tensor<f32>
  } else {
    tosa.yield %b : tensor<f32>
@@ -1180,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens

// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<?xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+  // CHECK: -> tensor<?xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
    tosa.yield %arg0 : tensor<2xf32>
  } else {
    tosa.yield %arg1 : tensor<3xf32>
@@ -1194,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :

// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<*xf32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+  // CHECK: -> tensor<*xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
    tosa.yield %arg0 : tensor<f32>
  } else {
    tosa.yield %arg1 : tensor<3xf32>
@@ -1208,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :

// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
  // CHECK: tosa.cond_if
-  // CHECK: -> (tensor<f32>)
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  // CHECK: -> tensor<f32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
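[Editorial sketch, not part of the diff.] As the invalid.mlir and tosa-infer-shapes.mlir hunks above show, the shift operand of tosa.mul now admits an unranked i8 tensor in addition to a rank-1 scalar tensor, and shape inference later refines the unranked type. A hedged sketch of the accepted pattern, with illustrative value names:

// tensor<*xi8> is accepted before inference; -tosa-infer-shapes refines it to tensor<1xi8>.
%shift = tosa.cast %s : (tensor<1xi8>) -> tensor<*xi8>
%prod = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>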
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b305236..2a937b0 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar

// -----

+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg

func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
@@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a

func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
-  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1 : tensor<f32>
  } else {
@@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg

func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
  // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
-  %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+  %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso

// -----

+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+2 {{expected ':'}}
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+2 {{expected non-function type}}
+  // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}}
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
  // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30..f86fb38 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32

// -----

-// CHECK-LABEL: fold_extract_scalar_from_splat
+// CHECK-LABEL: fold_extract_splatlike
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
-  %b = vector.splat %a : vector<1x2x4xf32>
+func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
  return %r : f32
}
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {

// -----

-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
  %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
  %cst_1 = arith.constant dense<1> : vector<12x2xi32>
  %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {

// -----

+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+  %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+  %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+  return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast fold's method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+  %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+  %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+  return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -2033,11 +2063,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve

// -----

-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
+// CHECK-LABEL: extract_strided_splatlike
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
-  %0 = vector.splat %arg0 : vector<16x4xf16>
+func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
  %1 = vector.extract_strided_slice %0
    {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
      : vector<16x4xf16> to vector<2x4xf16>
@@ -2323,14 +2353,14 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)

// -----

-// CHECK-LABEL: func @splat_fold
-func.func @splat_fold() -> vector<4xf32> {
+// CHECK-LABEL: func @splatlike_fold
+// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+// CHECK-NEXT: return [[V]] : vector<4xf32>
+func.func @splatlike_fold() -> vector<4xf32> {
  %c = arith.constant 1.0 : f32
-  %v = vector.splat %c : vector<4xf32>
+  %v = vector.broadcast %c : f32 to vector<4xf32>
  return %v : vector<4xf32>
-  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-  // CHECK-NEXT: return [[V]] : vector<4xf32>
}
@@ -2469,10 +2499,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5

// -----

-// CHECK-LABEL: func @transpose_splat_constant
+// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
-func.func @transpose_splat_constant() -> vector<8x4xf32> {
+func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
  %cst = arith.constant dense<5.0> : vector<4x8xf32>
  %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
  return %0 : vector<8x4xf32>
@@ -2480,13 +2510,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {

// -----

-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
-// CHECK: return %[[VAL_1]] : vector<3x4xf32>
-// CHECK: }
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
-  %splat = vector.splat %arg : vector<4x3xf32>
+// CHECK-LABEL: func @transpose_splatlike2(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x4xf32>
+// CHECK: }
+func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
+  %splat = vector.broadcast %arg : f32 to vector<4x3xf32>
  %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
  return %0 : vector<3x4xf32>
}
@@ -2562,118 +2592,6 @@ func.func @insert_2d_splat_constant()

// -----

-// CHECK-LABEL: func @insert_element_fold
-// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
-// CHECK: return %[[V]]
-func.func @insert_element_fold() -> vector<4xi32> {
-  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
-  %s = arith.constant 7 : i32
-  %i = arith.constant 2 : i32
-  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
-  return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_element_invalid_fold
-func.func @insert_element_invalid_fold() -> vector<1xf32> {
-  // Out-of-bound index here.
-  %c26 = arith.constant 26 : index
-  %cst_2 = arith.constant 1.60215309E+9 : f32
-  %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32>
-// CHECK: vector.insertelement
-  %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32>
-  return %46 : vector<1xf32>
-}
-
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold1
-// CHECK: vector.insertelement
-func.func @insert_poison_fold1() -> vector<4xi32> {
-  %v = ub.poison : vector<4xi32>
-  %s = arith.constant 7 : i32
-  %i = arith.constant 2 : i32
-  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
-  return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold2
-// CHECK: vector.insertelement
-func.func @insert_poison_fold2() -> vector<4xi32> {
-  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
-  %s = ub.poison : i32
-  %i = arith.constant 2 : i32
-  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
-  return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold3
-// CHECK: vector.insertelement
-func.func @insert_poison_fold3() -> vector<4xi32> {
-  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
-  %s = arith.constant 7 : i32
-  %i = ub.poison : i32
-  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
-  return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_element_fold
-// CHECK: %[[C:.+]] = arith.constant 5 : i32
-// CHECK: return %[[C]]
-func.func @extract_element_fold() -> i32 {
-  %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
-  %i = arith.constant 2 : i32
-  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
-  return %1 : i32
-}
-
-// CHECK-LABEL: func @extract_element_splat_fold
-// CHECK-SAME: (%[[ARG:.+]]: i32)
-// CHECK: return %[[ARG]]
-func.func @extract_element_splat_fold(%a : i32) -> i32 {
-  %v = vector.splat %a : vector<4xi32>
-  %i = arith.constant 2 : i32
-  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
-  return %1 : i32
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @extract_element_poison_fold1
-// CHECK: vector.extractelement
-func.func @extract_element_poison_fold1() -> i32 {
-  %v = ub.poison : vector<4xi32>
-  %i = arith.constant 2 : i32
-  %1 = vector.extractelement
%v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold2 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold2() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = ub.poison : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // CHECK-LABEL: func @reduce_one_element_vector_extract // CHECK-SAME: (%[[V:.+]]: vector<1xf32>) // CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32> @@ -2781,13 +2699,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> { // ----- -// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-LABEL: @insert_strided_slice_splatlike // CHECK-SAME: (%[[ARG:.*]]: f32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> // CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> -func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { - %splat0 = vector.splat %x : vector<4x4xf32> - %splat1 = vector.splat %x : vector<8x16xf32> +func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.broadcast %x : f32 to vector<4x4xf32> + %splat1 = vector.broadcast %x : f32 to vector<8x16xf32> %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<8x16xf32> return %0 : vector<8x16xf32> @@ -2860,13 +2778,13 @@ func.func @insert_strided_2d_constant() -> // ----- -// CHECK-LABEL: func @shuffle_splat +// CHECK-LABEL: func @shuffle_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> -func.func @shuffle_splat(%x : i32) -> vector<4xi32> { - %v0 = vector.splat %x : vector<4xi32> - %v1 = vector.splat %x : vector<2xi32> +func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> { + %v0 = vector.broadcast %x : i32 to vector<4xi32> + %v1 = vector.broadcast %x : i32 to vector<2xi32> %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> return %shuffle : vector<4xi32> } @@ -2874,13 +2792,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> { // ----- -// CHECK-LABEL: func @insert_splat +// CHECK-LABEL: func @insert_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> -func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { - %v0 = vector.splat %x : vector<4x3xi32> - %v1 = vector.splat %x : vector<2x4x3xi32> +func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.broadcast %x : i32 to vector<4x3xi32> + %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32> %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> return %insert : vector<2x4x3xi32> } @@ -2933,18 +2851,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // ----- -// CHECK-LABEL: func.func @fold_extractelement_of_broadcast( -// CHECK-SAME: %[[f:.*]]: f32 -// CHECK: return %[[f]] -func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { - %0 = vector.broadcast %f : f32 to vector<15xf32> - %c5 = arith.constant 5 : index - %1 = vector.extractelement %0 [%c5 
: index] : vector<15xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 { // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32> @@ -3124,11 +3030,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3 // ----- -// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>) -func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { - // Splat scalar to 0D and extract scalar. - %0 = vector.splat %a : vector<f32> +// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: vector<f32>, %[[C:.*]]: vector<2xf32>) +func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat/broadcast scalar to 0D and extract scalar. + %0 = vector.broadcast %a : f32 to vector<f32> %1 = vector.extract %0[] : f32 from vector<f32> // Broadcast scalar to 0D and extract scalar. @@ -3136,12 +3042,12 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %3 = vector.extract %2[] : f32 from vector<f32> // Broadcast 0D to 3D and extract scalar. - // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector<f32> %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32> %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> - // Splat scalar to 2D and extract scalar. - %6 = vector.splat %a : vector<2x3xf32> + // Splat/broadcast scalar to 2D and extract scalar. + %6 = vector.broadcast %a : f32 to vector<2x3xf32> %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> // Broadcast scalar to 3D and extract scalar. @@ -3149,14 +3055,14 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> // Extract 2D from 3D that was broadcasted from a scalar. - // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> // Extract 1D from 2D that was splat'ed from a scalar. 
- // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] + // CHECK: return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> } @@ -3598,7 +3504,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index %v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32> %v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32> %v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32> - return %v_2 : vector<4x4xf32> + return %v_2 : vector<4x4xf32> } // ----- @@ -3612,5 +3518,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> { %v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32> %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32> - return %v_1 : vector<4xf32> + return %v_1 : vector<4xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index fdab2a8..f43328f 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32 // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> + // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32> %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> // CHECK: return %[[SPLAT1]], %[[SPLAT2]] return %1, %2 : vector<3xf32>, vector<3xf32> @@ -63,11 +63,11 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, // CHECK-LABEL: func @from_elements_to_splat( // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) { - // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> + // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32> + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector<f32> %2 = vector.from_elements %a : vector<f32> // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]] return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32> @@ -170,7 +170,7 @@ func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> v // Could match, but handled by `rewriteFromElementsAsSplat`. 
// CHECK-LABEL: func @extract_single_elm( // CHECK-NEXT: vector.extract -// CHECK-NEXT: vector.splat +// CHECK-NEXT: vector.broadcast // CHECK-NEXT: return func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir new file mode 100644 index 0000000..e4a9391 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file should be removed when vector.splat is removed. +// This file tests canonicalization/folding with vector.splat. +// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir + + +// CHECK-LABEL: fold_extract_splat +// CHECK-SAME: %[[A:.*]]: f32 +// CHECK: return %[[A]] : f32 +func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> + return %r : f32 +} + +// ----- + +// CHECK-LABEL: extract_strided_splat +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16> +// CHECK-NEXT: return %[[B]] : vector<2x4xf16> +func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { + %0 = vector.splat %arg0 : vector<16x4xf16> + %1 = vector.extract_strided_slice %0 + {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} : + vector<16x4xf16> to vector<2x4xf16> + return %1 : vector<2x4xf16> +} + +// ----- + +// CHECK-LABEL: func @splat_fold +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> +func.func @splat_fold() -> vector<4xf32> { + %c = arith.constant 1.0 : f32 + %v = vector.splat %c : vector<4xf32> + return %v : vector<4xf32> + +} + +// ----- + +// CHECK-LABEL: func @transpose_splat2( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.splat %arg : vector<4x3xf32> + %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> +// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> +func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.splat %x : vector<4x4xf32> + %splat1 = vector.splat %x : vector<8x16xf32> + %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} + : vector<4x4xf32> into vector<8x16xf32> + return %0 : vector<8x16xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> +func.func @shuffle_splat(%x : i32) -> vector<4xi32> { + %v0 = vector.splat %x : vector<4xi32> + %v1 = vector.splat %x : vector<2xi32> + %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> + return %shuffle : vector<4xi32> +} + + +// ----- + +// CHECK-LABEL: func @insert_splat +// CHECK-SAME: 
(%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> +func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.splat %x : vector<4x3xi32> + %v1 = vector.splat %x : vector<2x4x3xi32> + %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> + return %insert : vector<2x4x3xi32> +} + +// ----- + +// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression +// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>) +func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat scalar to 0D and extract scalar. + %0 = vector.splat %a : vector<f32> + %1 = vector.extract %0[] : f32 from vector<f32> + + // Broadcast scalar to 0D and extract scalar. + %2 = vector.splat %a : vector<f32> + %3 = vector.extract %2[] : f32 from vector<f32> + + // Splat scalar to 2D and extract scalar. + %6 = vector.splat %a : vector<2x3xf32> + %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + + // Broadcast scalar to 3D and extract scalar. + %8 = vector.splat %a : vector<5x6x7xf32> + %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> + + // Extract 2D from 3D that was broadcasted from a scalar. + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> + %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> + + // Extract 1D from 2D that was splat'ed from a scalar. + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> + %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> + + // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] + return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> +} diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 0263193..b2f16bb 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> { func.return %2 : vector<4x4xindex> } +// CHECK-LABEL: func @vector_transpose +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index} +func.func @vector_transpose() -> vector<2x4xindex> { + %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex> + %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %2 = test.reflect_bounds %1 : vector<2x4xindex> + func.return %2 : vector<2x4xindex> +} + // CHECK-LABEL: func @vector_extract // CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} func.func @vector_extract() -> index { @@ -60,16 +69,6 @@ func.func @vector_extract() -> index { func.return %2 : index } -// CHECK-LABEL: func @vector_extractelement -// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} -func.func @vector_extractelement() -> index { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> - %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> - %2 = test.reflect_bounds %1 : index - func.return %2 : index -} - // CHECK-LABEL: func @vector_add // CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax 
= 12 : index, umin = 10 : index} func.func @vector_add() -> vector<4xindex> { @@ -90,17 +89,6 @@ func.func @vector_insert() -> vector<4xindex> { func.return %3 : vector<4xindex> } -// CHECK-LABEL: func @vector_insertelement -// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} -func.func @vector_insertelement() -> vector<4xindex> { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> - %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index - %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> - %3 = test.reflect_bounds %2 : vector<4xindex> - func.return %3 : vector<4xindex> -} - // CHECK-LABEL: func @test_loaded_vector_extract // No bounds // CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 @@ -120,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ca837d3..c21de56 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- -func.func @extract_element(%arg0: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %1 = vector.extractelement %arg0[%c : i32] : vector<f32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %1 = vector.extractelement %arg0[] : vector<4xf32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> -} - -// ----- - func.func @extract_vector_type(%arg0: index) { // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}} %1 = vector.extract %arg0[] : index from index @@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @insert_element(%arg0: f32, %arg1: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32> -} - -// ----- - -func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // 
expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}} - %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>) -} - -// ----- - func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 6a56116..625ffc1 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4 return %1 : vector<4xf32> } -// CHECK-LABEL: @extract_element_0d -func.func @extract_element_0d(%a: vector<f32>) -> f32 { - // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> - %1 = vector.extractelement %a[] : vector<f32> - return %1 : f32 -} - -// CHECK-LABEL: @extract_element -func.func @extract_element(%a: vector<16xf32>) -> f32 { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.extractelement %a[%c : i32] : vector<16xf32> - return %1 : f32 -} - // CHECK-LABEL: @extract_const_idx func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { @@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { return %0 : f32 } -// CHECK-LABEL: @insert_element_0d -func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> - %1 = vector.insertelement %a, %b[] : vector<f32> - return %1 : vector<f32> -} - -// CHECK-LABEL: @insert_element -func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32> - return %1 : vector<16xf32> -} - // CHECK-LABEL: @insert_const_idx func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 8e167a5..d5e3443 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @broadcast_vec1d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { // CHECK-LABEL: func @broadcast_vec2d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> // CHECK: return %[[T0]] : vector<2x3xf32> func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: 
f32) -> vector<2x3xf32> { // CHECK-LABEL: func @broadcast_vec3d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32> // CHECK: return %[[T0]] : vector<2x3x4xf32> func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3 // CHECK-LABEL: func @broadcast_stretch // CHECK-SAME: %[[A:.*0]]: vector<1xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32> // CHECK: return %[[T1]] : vector<4xf32> func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> +// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955..5a8125e 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -5,11 +5,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: 
return %[[T7]] : vector<2x3xf32> @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> // CHECK: return %[[T7]] : vector<2x3xi32> @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>, // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: 
%[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { // CHECK-LABEL: func @axpy_fp_add( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32, // CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32> // CHECK-LABEL: func @axpy_int( // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: return %[[T1]] : vector<16xi32> func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32, // CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> // CHECK: return %[[T2]] : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index b826cdc..ef881ba 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> { // ----- -// The source and the result for arith.cmp have different types - not supported - -// CHECK-LABEL: func.func @negative_source_and_result_mismatch -// CHECK: %[[BROADCAST:.+]] = vector.broadcast -// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]] -// CHECK: return %[[RETURN]] -func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> { +// The source and the result of arith.cmpf have different types; the broadcast is still sunk past the compare. + +// CHECK-LABEL: func.func @source_and_result_mismatch( +// CHECK-SAME: %[[ARG0:.+]]: f32) +// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] +// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1> +// CHECK: return %[[BROADCAST]] +func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> { %0 = vector.broadcast %arg0 : f32 to vector<1xf32> %1 = arith.cmpf uno, %0, %0 : vector<1xf32> return %1 : vector<1xi1> @@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { return %1 : vector<1xf32> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + 
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.subi %cst, %0 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32> + %2 = arith.mulf %0, %cst : vector<3x4xf32> + return %2 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> + +func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16> + %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16> + %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32> + return %1 
: vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32 +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32> + %cst = arith.constant dense<3> : vector<1x4xi32> + %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32> + return %2 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32> +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<3> : vector<3x4xi32> + %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32> + return %2 : vector<3x4xf32> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderCastOpsOnBroadcast] // diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 511ab70..1b54d54 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -284,19 +284,19 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref< %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index -// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask0 = vector.splat %m : vector<14x7xi1> %0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1> +// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1> %mask1 = vector.splat %m : vector<16x14xi1> %1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask2 = vector.splat %m : vector<14x7xi1> %2 = 
vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32> @@ -336,7 +336,7 @@ func.func @transfer_write_permutations_tensor_masked( // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1> + // CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1> %mask0 = vector.splat %m : vector<16x14x7x8xi1> %res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32> // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 0160bfe..dff3ffa 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) { } // ----- +func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) { + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex> + return +} + +// ----- +func.func @load_gather_offset_sg(%src: memref<?xf16>) { + %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<8xi1> + // expected-error@+1 {{Mask should match value except the chunk size dim}} + %2 = xegpu.load %src[%offsets], %mask + : memref<?xf16>, vector<4xindex>, vector<8xi1> + -> vector<4x2xf16> + return +} + +// ----- +func.func @load_gather_offset_wi(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32> + return +} + +// ----- +func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{value elements must match chunk size}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @load_gather_offset_wi_2(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16> + return +} + +// ----- +func.func @load_gather_offset_wi_1(%src: 
memref<4x4xf32>) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32> + return +} + +// ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) { %0 = arith.constant dense<1>: vector<4xi1> %1 = arith.constant dense<2.9>: vector<4x2xf32> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 3ebb1b969a..6be2371 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) { + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4x2xf16> + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) { gpu.return } +// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) { +gpu.func @prefetch_offset(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex> + xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex> + gpu.return +} // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_update_tdesc(%src: ui64) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index d67bdb4..628a485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -2,122 +2,117 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func 
@create_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.load_nd %{{.*}} - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-12: -> vector<2x2xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.load_nd %{{.*}} + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME-COUNT-4: -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} + // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT : xegpu.store_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @update_nd(%src: memref<24x32xf32>){ - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16] - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @update_nd(%src: memref<256x128xf32>){ + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.update_nd_offset %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas - // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>) - gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) + gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32> + // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16> + -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16> + -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x256xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} - // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} + // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = 
[8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: broadcast - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32> + gpu.func @broadcast(%src: memref<128x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32> + -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK-COUNT-3: vector.broadcast {{.*}} - // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32> + : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<128x1xf32> + // CHECK-COUNT-2: vector.broadcast {{.*}} + // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<128x1xf32> to vector<128x64xf32> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index d511224..d4b0037 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -4,201 +4,181 @@ //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> gpu.module @test_1_1_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id - // CHECK: %[[C12:.*]] = arith.constant 12 : index - // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[C32_0:.*]] = arith.constant 32 : index + // CHECK: %[[C4_1:.*]] = arith.constant 4 : index // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]] // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]] - // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]] - // CHECK: %[[C24:.*]] = arith.constant 24 : index - // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]] + // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]] + // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]] - // CHECK: %[[C32:.*]] = arith.constant 32 : index - // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]] - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: 
%[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]] - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK: %[[C256:.*]] = arith.constant 256 : index + // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]] + // CHECK: %[[C0_2:.*]] = arith.constant 0 : index + // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]] + // CHECK: %[[C0_3:.*]] = arith.constant 0 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]] + // CHECK: %[[C0_4:.*]] = arith.constant 0 : index + // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]] + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: gpu.return - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: 
memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] - // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -gpu.func @update_nd(%src: memref<24x32xf32>){ - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> +// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> +gpu.func @update_nd(%src: memref<256x128xf32>){ + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], 
lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: dpas_no_sg_data -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> 
-gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // 
CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: xegpu.prefetch_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas_with_no_create_nd_desc - gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) { - // CHECK-NOT: vector<12x12xf32> + gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) { + // CHECK-NOT: vector<32x32xf32> %dpas = xegpu.dpas %a, %b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: broadcast_dim1 - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast_dim1(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32> + gpu.func @broadcast_dim1(%src: memref<256x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32> + -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32> - %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<256x1xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32> + %broadcast = vector.broadcast %load + {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<256x1xf32> to vector<256x32xf32> gpu.return } // CHECK-LABEL: 
broadcast_dim0 - // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32> - gpu.func @broadcast_dim0(%src: memref<1x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32> - -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32> + gpu.func @broadcast_dim0(%src: memref<1x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32> + -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> - -> vector<1x32xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>} - // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32> + : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> - -> vector<1x128xf32> + -> vector<1x128xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32> %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>} - : vector<1x32xf32> to vector<12x32xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<1x128xf32> to vector<32x128xf32> gpu.return }
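
Note on the CHECK sequences above: the per-subgroup offset arithmetic verified in @create_nd_tdesc always follows the same pattern, namely decompose gpu.subgroup_id over sg_layout, scale by sg_data, wrap with index.remu, and add the tensor-descriptor base offset. The following is a minimal standalone sketch of that pattern, not part of any test file; the module and function names are illustrative, and it assumes the sg_layout = [8, 4], sg_data = [32, 32] configuration over a 256x128 tile used in @create_nd_tdesc.

gpu.module @offset_sketch {
  gpu.func @subgroup_offsets() {
    // Decompose the linear subgroup id into an [8, 4] grid position.
    %sgid = gpu.subgroup_id : index
    %row_id = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%sgid]
    %col_id = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%sgid]
    // Scale the grid position by sg_data (32 in both dims; the pass
    // materializes a separate constant per dim before CSE).
    %c32 = arith.constant 32 : index
    %row_off = index.mul %row_id, %c32
    %col_off = index.mul %col_id, %c32
    // Wrap around the workgroup tile (256 rows, 128 columns) and add the
    // base offset of the original create_nd_tdesc, which is [0, 0] here.
    %c256 = arith.constant 256 : index
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %row_wrap = index.remu %row_off, %c256
    %row = index.add %row_wrap, %c0
    %col_wrap = index.remu %col_off, %c128
    %col = index.add %col_wrap, %c0
    gpu.return
  }
}

The resulting %row and %col are the two indices fed to the distributed xegpu.create_nd_tdesc, matching the %[[ADD1]]/%[[ADD2]] captures in the CHECK lines.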