Diffstat (limited to 'mlir/test')
-rw-r--r--  mlir/test/Conversion/ConvertToSPIRV/vector.mlir | 2
-rw-r--r--  mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir | 32
-rw-r--r--  mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir | 17
-rw-r--r--  mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 147
-rw-r--r--  mlir/test/Dialect/AMDGPU/inlining.mlir | 14
-rw-r--r--  mlir/test/Dialect/Math/canonicalize_ipowi.mlir | 4
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize.mlir | 2
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir | 126
-rw-r--r--  mlir/test/Dialect/Vector/int-range-interface.mlir | 2
-rw-r--r--  mlir/test/Dialect/Vector/invalid.mlir | 43
-rw-r--r--  mlir/test/Dialect/Vector/linearize.mlir | 27
-rw-r--r--  mlir/test/Dialect/Vector/ops.mlir | 36
-rw-r--r--  mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir | 4
-rw-r--r--  mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir | 63
-rw-r--r--  mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir | 83
-rw-r--r--  mlir/test/IR/invalid-ops.mlir | 2
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir | 8
-rw-r--r--  mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 31
-rw-r--r--  mlir/test/mlir-runner/utils.mlir | 2
-rw-r--r--  mlir/test/python/rewrite.py | 69
20 files changed, 430 insertions, 284 deletions
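
Most of the vector-dialect churn in this diff replaces vector.splat with the equivalent vector.broadcast, mirroring the op's pending removal. A minimal sketch of that mapping, using a hypothetical scalar %f : f32 and a 4-element result (illustrative only, not copied from any file below):

  // Old form: splat a scalar across all lanes of the result vector.
  %v = vector.splat %f : vector<4xf32>
  // New form: the same value written as a scalar-to-vector broadcast,
  // which spells out both the source and result types.
  %v = vector.broadcast %f : f32 to vector<4xf32>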
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index eb9feaa..a75f30d 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -86,7 +86,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
func.func @splat(%f : f32) -> vector<4xf32> {
- %splat = vector.splat %f : vector<4xf32>
+ %splat = vector.broadcast %f : f32 to vector<4xf32>
return %splat : vector<4xf32>
}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index c8a434b..1735e08 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -429,38 +429,6 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
return
}
-//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_i32(
-// CHECK-SAME: %[[SRC:.*]]: i32) {
-// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[VSCALE:.*]] = vector.vscale
-// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
-// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
-// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
-func.func @splat_vec2d_from_i32(%arg0: i32) {
- %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
- "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
- return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @splat_vec2d_from_f16(
-// CHECK-SAME: %[[SRC:.*]]: f16) {
-// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
-// CHECK: scf.for
-// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
-func.func @splat_vec2d_from_f16(%arg0: f16) {
- %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
- "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
- return
-}
//===----------------------------------------------------------------------===//
// vector.transpose
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 5973c2ba..cb48ca3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2217,23 +2217,6 @@ func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vecto
// -----
//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// vector.splat is converted to vector.broadcast. Then, vector.broadcast is converted to LLVM.
-// CHECK-LABEL: @splat_0d
-// CHECK-NOT: splat
-// CHECK: return
-func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
- %a = vector.splat %elt : vector<f32>
- %b = vector.splat %elt : vector<4xf32>
- %c = vector.splat %elt : vector<[4]xf32>
- return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
// vector.scalable_insert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index b31a973..72e70ff 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
- // CHECK: xevm.DecorationCacheControl =
+ // CHECK: xevm.DecorationCacheControl =
// CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
// CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
@@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
llvm.return
}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.x() -> i32 {
+llvm.func @local_id.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.local_id.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.y() -> i32 {
+llvm.func @local_id.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.local_id.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.z() -> i32 {
+llvm.func @local_id.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.local_id.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.x() -> i32 {
+llvm.func @local_size.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.local_size.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.y() -> i32 {
+llvm.func @local_size.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.local_size.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.z() -> i32 {
+llvm.func @local_size.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.local_size.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.x() -> i32 {
+llvm.func @group_id.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.group_id.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.y() -> i32 {
+llvm.func @group_id.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.group_id.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.z() -> i32 {
+llvm.func @group_id.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.group_id.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.x() -> i32 {
+llvm.func @group_count.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.group_count.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.y() -> i32 {
+llvm.func @group_count.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.group_count.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.z() -> i32 {
+llvm.func @group_count.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.group_count.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @lane_id() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.lane_id : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_size() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.subgroup_size : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_id() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.subgroup_id : i32
+ llvm.return %1 : i32
+}
diff --git a/mlir/test/Dialect/AMDGPU/inlining.mlir b/mlir/test/Dialect/AMDGPU/inlining.mlir
new file mode 100644
index 0000000..33fb635
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/inlining.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --inline | FileCheck %s
+
+func.func @calee(%arg0 : f32) -> f32 {
+ %0 = amdgpu.permlane_swap %arg0 32 : f32
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func @caller
+func.func @caller(%arg0 : f32) -> f32 {
+ // CHECK-NOT: call
+ // CHECK: amdgpu.permlane_swap
+ %0 = call @calee(%arg0) : (f32) -> f32
+ func.return %0 : f32
+}
diff --git a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
index 9e65a96..6812092 100644
--- a/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
+++ b/mlir/test/Dialect/Math/canonicalize_ipowi.mlir
@@ -105,9 +105,9 @@ func.func @ipowi32_fold(%result : memref<?xi32>) {
// --- Test vector folding ---
%arg11_base = arith.constant 2 : i32
- %arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
+ %arg11_base_vec = vector.broadcast %arg11_base : i32 to vector<2x2xi32>
%arg11_power = arith.constant 30 : i32
- %arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
+ %arg11_power_vec = vector.broadcast %arg11_power : i32 to vector<2x2xi32>
%res11_vec = math.ipowi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
%i11 = arith.constant 11 : index
%res11 = vector.extract %res11_vec[1, 1] : i32 from vector<2x2xi32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index eb369c0..59774f9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -837,7 +837,7 @@ func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.splat %a : vector<1x2x4xf32>
+ %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
deleted file mode 100644
index e4a9391..0000000
--- a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir
+++ /dev/null
@@ -1,126 +0,0 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-
-// This file should be removed when vector.splat is removed.
-// This file tests canonicalization/folding with vector.splat.
-// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
-
-
-// CHECK-LABEL: fold_extract_splat
-// CHECK-SAME: %[[A:.*]]: f32
-// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
- %b = vector.splat %a : vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
- return %r : f32
-}
-
-// -----
-
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
-// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
- %1 = vector.extract_strided_slice %0
- {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
- vector<16x4xf16> to vector<2x4xf16>
- return %1 : vector<2x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @splat_fold
-// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-// CHECK-NEXT: return [[V]] : vector<4xf32>
-func.func @splat_fold() -> vector<4xf32> {
- %c = arith.constant 1.0 : f32
- %v = vector.splat %c : vector<4xf32>
- return %v : vector<4xf32>
-
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
-// CHECK: return %[[VAL_1]] : vector<3x4xf32>
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
- %splat = vector.splat %arg : vector<4x3xf32>
- %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
- return %0 : vector<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_strided_slice_splat
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
- %splat0 = vector.splat %x : vector<4x4xf32>
- %splat1 = vector.splat %x : vector<8x16xf32>
- %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
- : vector<4x4xf32> into vector<8x16xf32>
- return %0 : vector<8x16xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @shuffle_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
- %v0 = vector.splat %x : vector<4xi32>
- %v1 = vector.splat %x : vector<2xi32>
- %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
- return %shuffle : vector<4xi32>
-}
-
-
-// -----
-
-// CHECK-LABEL: func @insert_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
-// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
- %v0 = vector.splat %x : vector<4x3xi32>
- %v1 = vector.splat %x : vector<2x4x3xi32>
- %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
- return %insert : vector<2x4x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
-// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
- // Splat scalar to 0D and extract scalar.
- %0 = vector.splat %a : vector<f32>
- %1 = vector.extract %0[] : f32 from vector<f32>
-
- // Broadcast scalar to 0D and extract scalar.
- %2 = vector.splat %a : vector<f32>
- %3 = vector.extract %2[] : f32 from vector<f32>
-
- // Splat scalar to 2D and extract scalar.
- %6 = vector.splat %a : vector<2x3xf32>
- %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-
- // Broadcast scalar to 3D and extract scalar.
- %8 = vector.splat %a : vector<5x6x7xf32>
- %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
-
- // Extract 2D from 3D that was broadcasted from a scalar.
- // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
- %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
-
- // Extract 1D from 2D that was splat'ed from a scalar.
- // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
- %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
-
- // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
- return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
-}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index b2f16bb..4da8d8a 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -28,7 +28,7 @@ func.func @float_constant_splat() -> vector<8xf32> {
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
- %1 = vector.splat %0 : vector<4xindex>
+ %1 = vector.broadcast %0 : index to vector<4xindex>
%2 = test.reflect_bounds %1 : vector<4xindex>
func.return %2 : vector<4xindex>
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6ee70fd..5f035e3 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -320,7 +320,7 @@ func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>)
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@@ -330,7 +330,7 @@ func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
}
@@ -414,7 +414,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
// expected-note@+1 {{prior use here}}
- %mask = vector.splat %c1 : vector<3x8x7xi1>
+ %mask = vector.broadcast %c1 : i1 to vector<3x8x7xi1>
// expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
}
@@ -424,7 +424,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@@ -434,7 +434,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<6xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@@ -444,7 +444,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<2x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
// expected-error@+1 {{ expects the in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@@ -454,8 +454,8 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<2x3xf32>
- %mask = vector.splat %c1 : vector<2x3xi1>
+ %vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
+ %mask = vector.broadcast %c1 : f32 to vector<2x3xi1>
// expected-error@+1 {{does not support masks with vector element type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@@ -492,7 +492,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
}
@@ -502,7 +502,7 @@ func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
func.func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4x3xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@@ -1981,29 +1981,6 @@ func.func @invalid_step_2d() {
// -----
//===----------------------------------------------------------------------===//
-// vector.splat
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @vector_splat_invalid_result(%v : f32) {
- // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'memref<8xf32>'}}
- vector.splat %v : memref<8xf32>
- return
-}
-
-// -----
-
-// expected-note @+1 {{prior use here}}
-func.func @vector_splat_type_mismatch(%a: f32) {
- // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
- %0 = vector.splat %a : vector<1xi32>
- return
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fe697c8..ee5cfbcd 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,33 +428,6 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
// -----
-// CHECK-LABEL: linearize_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
-func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
-
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
- // CHECK: return %[[CAST]] : vector<4x2xi32>
- %0 = vector.splat %arg0 : vector<4x2xi32>
- return %0 : vector<4x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: linearize_scalable_vector_splat
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
-func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
-
- // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
- // CHECK: return %[[CAST]] : vector<4x[2]xi32>
- %0 = vector.splat %arg0 : vector<4x[2]xi32>
- return %0 : vector<4x[2]xi32>
-
-}
-
-// -----
-
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 550e52a..da9a1a8 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -45,11 +45,11 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : i1
- %vf0 = vector.splat %f0 : vector<4x3xf32>
- %v0 = vector.splat %c0 : vector<4x3xi32>
- %vi0 = vector.splat %i0 : vector<4x3xindex>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+ %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+ %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
- %m2 = vector.splat %i1 : vector<4x5xi1>
+ %m2 = vector.broadcast %i1 : i1 to vector<4x5xi1>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@@ -106,9 +106,9 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
- %vf0 = vector.splat %f0 : vector<4x3xf32>
- %v0 = vector.splat %c0 : vector<4x3xi32>
- %vi0 = vector.splat %i0 : vector<4x3xindex>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
+ %v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
+ %vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
//
// CHECK: vector.transfer_read
@@ -922,28 +922,6 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
return %2#0 : vector<4x8x16x32xf32>
}
-// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
-func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
- // CHECK: vector.splat %[[s]] : vector<8xf32>
- %v = vector.splat %s : vector<8xf32>
-
- // CHECK: vector.splat %[[s]] : vector<4xf32>
- %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
-
- // CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
- %w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
- return
-}
-
-// CHECK-LABEL: func @vector_splat_0d(
-func.func @vector_splat_0d(%a: f32) -> vector<f32> {
- // CHECK: vector.splat %{{.*}} : vector<f32>
- %0 = vector.splat %a : vector<f32>
- return %0 : vector<f32>
-}
-
-
// CHECK-LABEL: func @vector_mask
func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index e74eb08..6e5d68c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -49,7 +49,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
- %pass_thru = vector.splat %s : vector<4xf32>
+ %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
@@ -65,7 +65,7 @@ func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
- %pass_thru = vector.splat %s : vector<4xf32>
+ %pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
diff --git a/mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir b/mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir
new file mode 100644
index 0000000..d289d73
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -test-xegpu-move-func-to-warp-op -split-input-file --allow-unregistered-dialect %s | FileCheck %s
+
+gpu.module @test {
+gpu.func @empty() {
+ gpu.return
+}
+}
+
+// CHECK-LABEL: gpu.func @empty() {
+// CHECK-NEXT: gpu.return
+// CHECK-NEXT: }
+
+// -----
+gpu.module @test {
+gpu.func @gemm(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1[%c0, %c0] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+}
+}
+
+// CHECK-LABEL: gpu.func @gemm(
+// CHECK: %[[ARG0:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<16x16xf16>,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<8x16xf32>) {
+// CHECK: %[[LANEID:.*]] = gpu.lane_id
+// CHECK-NEXT: gpu.warp_execute_on_lane_0(%[[LANEID]])[16]
+// CHECK-SAME: args(%[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) {
+// CHECK: ^bb0(%[[ARG3:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG4:[a-zA-Z0-9]+]]: memref<16x16xf16>,
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: memref<8x16xf32>):
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG3]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG4]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T1]][{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[T2]][{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: %[[T5:.*]] = xegpu.dpas %[[T3]], %[[T4]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG5]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T5]], %[[T6]][%{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: gpu.return
+
+// -----
+gpu.module @test {
+gpu.func @already_in_warp_op() {
+ %laneid = gpu.lane_id
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ "some_op"() : () -> ()
+ gpu.yield
+ }
+ gpu.return
+}
+}
+
+// CHECK-LABEL: gpu.func @already_in_warp_op() {
+// CHECK: %[[LANEID:.*]] = gpu.lane_id
+// CHECK: gpu.warp_execute_on_lane_0(%[[LANEID]])[16] {
+// CHECK: "some_op"() : () -> ()
+// CHECK: }
+// CHECK: gpu.return
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 40b66d1..f233dff 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -530,7 +530,7 @@ gpu.module @xevm_module{
// CHECK-NEXT: }
// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
gpu.module @xevm_module{
- gpu.func @vector_transpose(%arg0: memref<2x16xf32>, %laneid: index) {
+ gpu.func @vector_transpose(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
@@ -556,7 +556,7 @@ gpu.module @xevm_module{
// CHECK: }
// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
gpu.module @xevm_module{
- gpu.func @vector_bitcast(%arg0: memref<4x16xi16>, %laneid: index) {
+ gpu.func @vector_bitcast(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
@@ -573,3 +573,82 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
+// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
+// CHECK: }
+// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
+gpu.module @xevm_module {
+ gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
+ }
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
+// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
+// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
+// CHECK: }
+// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
+gpu.module @xevm_module {
+ gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<1x16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }
+ : vector<1x16xf32> to vector<16xf32>
+ gpu.yield %cast : vector<16xf32>
+ }
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
+//
+// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
+// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
+// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK: }
+// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK: gpu.return
+gpu.module @xevm_module {
+ gpu.func @vector_shapecast_unsupported(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
+ }
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
+ }
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 12a911c..0c5fec8c 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -107,7 +107,7 @@ func.func @return_not_in_function() {
// -----
func.func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
- vector.splat %v : vector<8xf64>
+ vector.broadcast %v : f64 to vector<8xf64>
// expected-error@-1 {{expects different type than prior uses}}
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
index 6ec1031..6c2183f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
@@ -21,13 +21,6 @@ func.func @print_vector_0d(%a: vector<f32>) {
return
}
-func.func @splat_0d(%a: f32) {
- %1 = vector.splat %a : vector<f32>
- // CHECK: ( 42 )
- vector.print %1: vector<f32>
- return
-}
-
func.func @broadcast_0d(%a: f32) {
%1 = vector.broadcast %a : f32 to vector<f32>
// CHECK: ( 42 )
@@ -147,7 +140,6 @@ func.func @entry() {
// Warning: these must be called in their textual order of definition in the
// file to not mess up FileCheck.
- call @splat_0d(%4) : (f32) -> ()
call @broadcast_0d(%4) : (f32) -> ()
call @bitcast_0d() : () -> ()
call @constant_mask_0d() : () -> ()
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 6ba7a00..76d4611 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -247,6 +247,36 @@ struct TestXeGPUSGDistribute
}
};
+struct TestXeGPUMoveFuncBodyToWarpOp
+ : public PassWrapper<TestXeGPUMoveFuncBodyToWarpOp,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUMoveFuncBodyToWarpOp)
+
+ StringRef getArgument() const final {
+ return "test-xegpu-move-func-to-warp-op";
+ }
+
+ StringRef getDescription() const final {
+ return "Test the implementation of XeGPU move gpu function body to "
+ "WarpExecuteOnLane0 op.";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<gpu::GPUDialect>();
+ }
+
+ TestXeGPUMoveFuncBodyToWarpOp() = default;
+ TestXeGPUMoveFuncBodyToWarpOp(const TestXeGPUMoveFuncBodyToWarpOp &pass) =
+ default;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestXeGPULayoutInterface
: public PassWrapper<TestXeGPULayoutInterface,
OperationPass<gpu::GPUModuleOp>> {
@@ -312,6 +342,7 @@ void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
PassRegistration<TestXeGPULayoutInterface>();
PassRegistration<TestXeGPUSGDistribute>();
+ PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/mlir-runner/utils.mlir b/mlir/test/mlir-runner/utils.mlir
index 0c25078..d3fc23b 100644
--- a/mlir/test/mlir-runner/utils.mlir
+++ b/mlir/test/mlir-runner/utils.mlir
@@ -56,7 +56,7 @@ func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interf
func.func @vector_splat_2d() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32
- %vf10 = vector.splat %f10: !vector_type_C
+ %vf10 = vector.broadcast %f10: f32 to !vector_type_C
%C = memref.alloc() : !matrix_type_CC
memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000..acf7db2
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,69 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+ def to_muli(op, rewriter):
+ with rewriter.ip:
+ new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ rewriter.replace_op(op, new_op.owner)
+
+ def constant_1_to_2(op, rewriter):
+ c = op.attributes["value"].value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.result.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.AddIOp, to_muli)
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 3 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c3_i64 = arith.constant 3 : i64
+ # CHECK: return %c2_i64, %c3_i64 : i64, i64
+ print(module)