diff options
Diffstat (limited to 'mlir/test')
177 files changed, 4451 insertions, 3678 deletions
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index ac8b44f5..89568e7 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -68,6 +68,7 @@ endif() llvm_canonicalize_cmake_booleans( LLVM_BUILD_EXAMPLES LLVM_HAS_NVPTX_TARGET + LLVM_INCLUDE_SPIRV_TOOLS_TESTS MLIR_ENABLE_BINDINGS_PYTHON MLIR_ENABLE_CUDA_RUNNER MLIR_ENABLE_ROCM_CONVERSIONS @@ -217,6 +218,11 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) ) endif() +if (LLVM_INCLUDE_SPIRV_TOOLS_TESTS) + list(APPEND MLIR_TEST_DEPENDS spirv-as) + list(APPEND MLIR_TEST_DEPENDS spirv-val) +endif() + # This target can be used to just build the dependencies # for the check-mlir target without executing the tests. # This is useful for bots when splitting the build step diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir index b980451..1d36be1 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir @@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector< // CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]] // CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]] // CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} -// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: %[[IN_SLICE_CAST:.+]] = vector.shape_cast // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0] -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} -// CHECK-NEXT: amdgpu.scaled_ext_packed -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} -// CHECK-NEXT: amdgpu.scaled_ext_packed -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} +// CHECK-NEXT: %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0] +// CHECK-NEXT: vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1] +// CHECK-NEXT: vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]} // CHECK-NEXT: vector.shape_cast // CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]} // CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} // CHECK-NEXT: vector.shape_cast // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0] -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} // CHECK-NEXT: amdgpu.scaled_ext_packed // CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} // CHECK-NEXT: amdgpu.scaled_ext_packed // CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} // CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> { %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU> %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2> @@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8 // CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32> // CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2> // CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32> -// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32> -// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2> -// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> // CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> // CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32> // CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2> // CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32> -// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32> -// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2> -// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> // CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> // CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32> @@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8 // CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU> // CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32> // CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32> -// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> -// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32> // CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32> func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> { @@ -261,3 +251,27 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 { %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32 return %ext : f32 } + +// ----- + +// CHECK-LABEL: @long_fp4_broadcast +// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3] +// CHECK-NOT: amdgpu.scaled_ext_packed +// CHECK: return +func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> { + %splat = vector.broadcast %scale : f32 to vector<32xf32> + %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32> + return %ext : vector<32xf32> +} + +// ----- + +// CHECK-LABEL: @long_fp8_broadcast +// CHECK-COUNT-8: amdgpu.scaled_ext_packed %{{.+}}[1] +// CHECK-NOT: amdgpu.scaled_ext_packed +// CHECK: return +func.func @long_fp8_broadcast(%in: vector<32xf8E4M3FN>, %scale: f32) -> vector<32xf32> { + %splat = vector.broadcast %scale : f32 to vector<32xf32> + %ext = arith.scaling_extf %in, %splat : vector<32xf8E4M3FN>, vector<32xf32> to vector<32xf32> + return %ext : vector<32xf32> +} diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir index 488e75c..90a8608 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir @@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M // CHECK-NEXT: vector.shape_cast // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0] // CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} -// CHECK-NEXT: amdgpu.packed_scaled_trunc -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: %[[P1:.+]] = amdgpu.packed_scaled_trunc // CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} -// CHECK-NEXT: amdgpu.packed_scaled_trunc -// CHECK-NEXT: vector.extract_strided_slice -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} -// CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]} +// CHECK-NEXT: %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1] +// CHECK-NEXT: %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2> +// CHECK-NEXT: vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]} // CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} // CHECK-NEXT: vector.shape_cast // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0] // CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} // CHECK-NEXT: amdgpu.packed_scaled_trunc -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} // CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} // CHECK-NEXT: amdgpu.packed_scaled_trunc -// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} // CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> { %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU> %cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32> @@ -122,7 +114,7 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F // ----- // CHECK-LABEL: @conversion_broadcast_odd -// CHECK-NEXT: %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2> +// CHECK-NEXT: %[[CST4:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2> // CHECK-NEXT: %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2> // CHECK-NEXT: %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU> // CHECK-NEXT: %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU> @@ -130,24 +122,18 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F // CHECK-NEXT: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> // CHECK-NEXT: %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32> // CHECK-NEXT: %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32> -// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into %[[CST4]][0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2> // CHECK-NEXT: %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32> -// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2> -// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into %[[PACKED0_PART0]][1], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2> // CHECK-NEXT: %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2> // CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> // CHECK-NEXT: %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32> // CHECK-NEXT: %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32> -// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into %[[CST4]][0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2> // CHECK-NEXT: %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32> -// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2> -// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into %[[PACKED1_PART0]][1], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2> // CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2> // CHECK-NEXT: return %[[FINAL_RESULT]] : vector<6xf8E5M2> func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> { @@ -165,14 +151,10 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F // CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32> // CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32> // CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> -// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2> +// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into %[[CST]][0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> // CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> -// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> -// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> -// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2> -// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf8E5M2> +// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into %[[PACKED0]][1], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> +// CHECK-NEXT: return %[[PACKED1]] : vector<4xf8E5M2> func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> { %splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU> %ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2> @@ -191,3 +173,27 @@ func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 { %ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2 return %ext : f8E5M2 } + +// ----- + +// CHECK-LABEL: @long_fp4_broadcast +// CHECK-COUNT-4: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[3] +// CHECK-NOT: amdgpu.packed_scaled_trunc +// CHECK: return +func.func @long_fp4_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf4E2M1FN> { + %splat = vector.broadcast %scale : f32 to vector<32xf32> + %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf4E2M1FN> + return %trunc : vector<32xf4E2M1FN> +} + +// ----- + +// CHECK-LABEL: @long_fp8_broadcast +// CHECK-COUNT-8: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[1] +// CHECK-NOT: amdgpu.packed_scaled_trunc +// CHECK: return +func.func @long_fp8_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf8E4M3FN> { + %splat = vector.broadcast %scale : f32 to vector<32xf32> + %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf8E4M3FN> + return %trunc : vector<32xf8E4M3FN> +} diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 1abe0fd..6e2352e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -559,6 +559,23 @@ func.func @constant() { return } +// CHECK-LABEL: @constant_8bit_float +func.func @constant_8bit_float() { + // CHECK: spirv.Constant 56 : i8 + %cst = arith.constant 1.0 : f8E4M3 + // CHECK: spirv.Constant 56 : i8 + %cst_i8 = arith.bitcast %cst : f8E4M3 to i8 + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8> + return +} + // CHECK-LABEL: @constant_16bit func.func @constant_16bit() { // CHECK: spirv.Constant 4 : i16 diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index bae7c59..ae59f28 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -2,8 +2,26 @@ // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32 +// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64> // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32> // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64> //CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { @@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { return %rf, %rd : f32, f64 } +//CHECK-LABEL: @angle_caller +func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { + // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}}) + %af = complex.angle %f : complex<f32> + // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}}) + %ad = complex.angle %d : complex<f64> + // CHECK: return %[[AF]], %[[AD]] + return %af, %ad : f32, f64 +} + +//CHECK-LABEL: @cos_caller +func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}}) + %cf = complex.cos %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}}) + %cd = complex.cos %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf, %cd : complex<f32>, complex<f64> +} + //CHECK-LABEL: @exp_caller func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) @@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp // CHECK: return %[[EF]], %[[ED]] return %ef, %ed : complex<f32>, complex<f64> } + +//CHECK-LABEL: @log_caller +func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}}) + %lf = complex.log %f : complex<f32> + // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}}) + %ld = complex.log %d : complex<f64> + // CHECK: return %[[LF]], %[[LD]] + return %lf, %ld : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @conj_caller +func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}}) + %cf2 = complex.conj %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}}) + %cd2 = complex.conj %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf2, %cd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @pow_caller +func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}}) + %pf = complex.pow %f, %f : complex<f32> + // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}}) + %pd = complex.pow %d, %d : complex<f64> + // CHECK: return %[[PF]], %[[PD]] + return %pf, %pd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sin_caller +func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) + %sf2 = complex.sin %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}}) + %sd2 = complex.sin %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf2, %sd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sqrt_caller +func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}}) + %sf = complex.sqrt %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}}) + %sd = complex.sqrt %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf, %sd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tan_caller +func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}}) + %tf2 = complex.tan %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}}) + %td2 = complex.tan %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf2, %td2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tanh_caller +func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}}) + %tf = complex.tanh %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}}) + %td = complex.tanh %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf, %td : complex<f32>, complex<f64> +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir index 00bbd1c..96ad107 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir @@ -85,11 +85,10 @@ module attributes { // CHECK: spirv.Load "StorageBuffer" %val = memref.load %arg0[%idx0] : memref<2xi32> // CHECK: spirv.CompositeInsert - %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32> + %vec = vector.insert %val, %vec0[%idx0] : i32 into vector<2xi32> // CHECK: spirv.VectorShuffle %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32> - // CHECK: spirv.CompositeExtract - %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> + %res = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> // CHECK: spirv.AccessChain // CHECK: spirv.Store "StorageBuffer" memref.store %res, %arg1[%idx0]: memref<4xi32> @@ -102,9 +101,9 @@ module attributes { // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32> // CHECK: arith.constant // CHECK: memref.load - // CHECK: vector.insertelement + // CHECK: vector.insert // CHECK: vector.shuffle - // CHECK: vector.extractelement + // CHECK: vector.extract // CHECK: memref.store // CHECK: gpu.return } diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index fb14feb..eb9feaa 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -51,108 +51,6 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME:(%[[S:.+]]: f32, -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32> // CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32> diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 1737f4a..0c77c88 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \ +// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT //===----------------------------------------------------------------------===// // Integer types @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } } // end module + + +// ----- + +// Check that 8-bit float types are emulated as i8. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>> +} { + + // CHECK: spirv.func @float8_to_integer8 + // CHECK-SAME: (%arg0: i8 + // CHECK-SAME: %arg1: i8 + // CHECK-SAME: %arg2: i8 + // CHECK-SAME: %arg3: i8 + // CHECK-SAME: %arg4: i8 + // CHECK-SAME: %arg5: i8 + // CHECK-SAME: %arg6: i8 + // CHECK-SAME: %arg7: i8 + // CHECK-SAME: %arg8: vector<4xi8> + // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer> + // CHECK-SAME: %arg10: !spirv.array<4 x i8> + // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8 + // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2 + // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3 + // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN + // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4 + // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU + // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ> + // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>> + // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2> + // UNSUPPORTED_FLOAT-SAME: ) { + + func.func @float8_to_integer8( + %arg0: f8E5M2, // CHECK-NOT: f8E5M2 + %arg1: f8E4M3, // CHECK-NOT: f8E4M3 + %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN + %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ + %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ + %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ + %arg6: f8E3M4, // CHECK-NOT: f8E3M4 + %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU + %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ> + %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref + %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor + ) { + // CHECK: spirv.Return + return + } +} diff --git a/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir new file mode 100644 index 0000000..983747b --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s + +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + gpu.module @kernels [#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- +// Checks that the `-convert-gpu-to-spirv` pass selects the first +// `spirv.target_env` from the `targets` array attribute attached to `gpu.module`. +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + // CHECK-SAME: #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]> + gpu.module @kernels [ + #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +} diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir index b96dd37..c71d220 100644 --- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir @@ -10,16 +10,14 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate() gpu.func @rotate() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %{{.+}} = spirv.Constant true - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 16 : f32 gpu.return } } @@ -38,18 +36,16 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size() gpu.func @rotate_width_less_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 8 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]] - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 8 : f32 gpu.return } } @@ -67,34 +63,10 @@ module attributes { gpu.module @kernels { gpu.func @rotate_with_bigger_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 32 : i32 %val = arith.constant 42.0 : f32 // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 - gpu.return - } -} - -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, - #spirv.resource_limits<subgroup_size = 16>> -} { - -gpu.module @kernels { - gpu.func @rotate_non_const_width(%width: i32) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %val = arith.constant 42.0 : f32 - - // expected-error @+1 {{'gpu.rotate' op width is not a constant value}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 32 : f32 gpu.return } } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir new file mode 100644 index 0000000..3e5f592 --- /dev/null +++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>> +} { + + // CHECK-LABEL: @fpclassify + func.func @fpclassify(%x: f32, %v: vector<4xf32>) { + // CHECK: spirv.IsFinite %{{.*}} : f32 + %0 = math.isfinite %x : f32 + // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32> + %1 = math.isfinite %v : vector<4xf32> + + // CHECK: spirv.IsNan %{{.*}} : f32 + %2 = math.isnan %x : f32 + // CHECK: spirv.IsNan %{{.*}} : vector<4xf32> + %3 = math.isnan %v : vector<4xf32> + + // CHECK: spirv.IsInf %{{.*}} : f32 + %4 = math.isinf %x : f32 + // CHECK: spirv.IsInf %{{.*}} : vector<4xf32> + %5 = math.isinf %v : vector<4xf32> + + return + } + +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir new file mode 100644 index 0000000..e391a89 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @alloc() { + %alloc = memref.alloc() : memref<999xi32> + return +} + +// CPP: module { +// CPP-NEXT: emitc.include <"cstdlib"> +// CPP-LABEL: alloc() +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"stdlib.h"> +// NOCPP-LABEL: alloc() +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + +func.func @alloc_aligned() { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32> + return +} + +// CPP-LABEL: alloc_aligned +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// CPP-NEXT: return + +// NOCPP-LABEL: alloc_aligned +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// NOCPP-NEXT: return + +func.func @allocating_multi() { + %alloc_5 = memref.alloc() : memref<7x999xi32> + return +} + +// CPP-LABEL: allocating_multi +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void"> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP-LABEL: allocating_multi +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce..580b09d 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @stmatrix( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, -// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) -llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32 - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l" diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index d54d003..5e20b5a 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -1,14 +1,14 @@ -// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-shard-to-mpi -canonicalize -split-input-file | FileCheck %s // ----- -// CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 3x4x5) +// CHECK: shard.grid @grid0 +shard.grid @grid0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { // CHECK: mpi.comm_rank // CHECK-DAG: %[[v4:.*]] = arith.remsi // CHECK-DAG: %[[v0:.*]] = arith.remsi // CHECK-DAG: %[[v1:.*]] = arith.remsi - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } @@ -17,7 +17,7 @@ func.func @process_multi_index() -> (index, index, index) { func.func @process_linear_index() -> index { // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index - %0 = mesh.process_linear_index on @mesh0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[cast]] : index return %0 : index } @@ -29,7 +29,7 @@ func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index - %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [0] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } @@ -41,7 +41,7 @@ func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index - %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [1] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } @@ -53,20 +53,20 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) { %c4 = arith.constant 4 : index // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index - %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index + %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [2] : index, index // CHECK: return [[down]], [[up]] : index, index return %idx#0, %idx#1 : index, index } // ----- -// CHECK: mesh.mesh @mesh0 +// CHECK: shard.grid @grid0 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } @@ -74,7 +74,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { // CHECK: %[[c24:.*]] = arith.constant 24 : index - %0 = mesh.process_linear_index on @mesh0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[c24]] : index return %0 : index } @@ -82,7 +82,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func.func @allreduce_tensor( func.func @allreduce_tensor( // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> @@ -97,7 +97,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> // CHECK: return [[v2]] : tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -114,7 +114,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> // CHECK: return [[valloc]] : memref<3x4xf32> return %0 : memref<3x4xf32> } @@ -131,14 +131,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64> // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> // CHECK: return [[valloc]] : memref<3x4xf64> return %0 : memref<3x4xf64> } } // ----- -mesh.mesh @mesh0(shape = 3x4x5) +shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func @update_halo_1d_first func.func @update_halo_1d_first( // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8> @@ -155,14 +155,14 @@ func.func @update_halo_1d_first( // CHECK: mpi.recv( // CHECK-SAME: : memref<3x120x120xi8>, i32, i32 // CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8 - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8> // CHECK: return [[res:%.*]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { - mesh.mesh @mesh0(shape = 4) + shard.grid @grid0(shape = 4) // CHECK-LABEL: func @update_halo_1d_with_zero func.func @update_halo_1d_with_zero ( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> @@ -179,7 +179,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> // CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } @@ -187,7 +187,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { // ----- module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - mesh.mesh @mesh0(shape = 3x4x5) + shard.grid @grid0(shape = 3x4x5) // CHECK-LABEL: func @update_halo_3d func.func @update_halo_3d( // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8> @@ -236,7 +236,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32 // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } @@ -291,18 +291,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32 // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> // CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> - %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> + %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> // CHECK: return [[v4]] : tensor<120x120x120xi8> return %res : tensor<120x120x120xi8> } } // ----- -mesh.mesh @mesh0(shape = 2x2x4) +shard.grid @grid0(shape = 2x2x4) // CHECK-LABEL: func.func @return_sharding( // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) { -func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) { - %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding +func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 @@ -316,13 +316,13 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh // CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64> // CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64> // CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64> - return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding + return %arg0, %sharding : tensor<2x4xf32>, !shard.sharding } // CHECK-LABEL: func.func @return_sharding_halos( // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) { -func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) { - %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding +func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> @@ -336,13 +336,13 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m // CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64> // CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64> // CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64> - return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding + return %arg0, %sharding : tensor<6x8xf32>, !shard.sharding } // CHECK-LABEL: func.func @return_sharding_offs( // CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) { -func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) { - %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding +func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) { + %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 @@ -362,5 +362,5 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !me // CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64> // CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64> // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64> - return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding + return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding } diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir index 156bbfb..9729d2b 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir @@ -1,21 +1,21 @@ -// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s +// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { - // CHECK: mesh.mesh @mesh0 - mesh.mesh @mesh0(shape = 3x4x5) + // CHECK: shard.grid @grid0 + shard.grid @grid0(shape = 3x4x5) - // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0 + // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0 // all shards are equal // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) { func.func @shard_shape_equal() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -23,13 +23,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // last shard in last dim gets an extra element // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) { func.func @shard_shape_odd_1() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -37,11 +37,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // In the second dimension the shard sizes are now [3 4 4 4] // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) { func.func @shard_shape_odd_2() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index - %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -49,11 +49,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // In the first dimension the shard sizes are now [3 4 4] // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) { func.func @shard_shape_odd_3() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index - %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } @@ -61,14 +61,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // extract from sharded_dims_offsets // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { - %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] - sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] + sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index %c9 = arith.constant 9 : index %c12 = arith.constant 12 : index // CHECK: [[vc3:%.*]] = arith.constant 3 : index // CHECK: [[vc2:%.*]] = arith.constant 2 : index - %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index return %1#0, %1#1, %1#2 : index, index, index } diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index fa7a91c..b6f2383 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) { func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) { // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]] // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // CHECK: scf.yield [[ARG0]] tosa.yield %arg0 : tensor<f32> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 8c135d5..31e17fb 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -274,73 +274,6 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3 // ----- //===----------------------------------------------------------------------===// -// vector.extractelement -//===----------------------------------------------------------------------===// - -func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 { - %1 = vector.extractelement %arg0[] : vector<f32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_0d_f32 -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- -func.func @extractelement_from_vec_1d_f32_idx_as_index(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_index_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -//===----------------------------------------------------------------------===// // vector.extract //===----------------------------------------------------------------------===// @@ -592,81 +525,6 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : // ----- //===----------------------------------------------------------------------===// -// vector.insertelement -//===----------------------------------------------------------------------===// - -func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> { - %1 = vector.insertelement %arg0, %arg1[] : vector<f32> - return %1 : vector<f32> -} -// CHECK-LABEL: @insertelement_into_vec_0d_f32 -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : -// CHECK: vector<f32> to vector<1xf32> -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -//===----------------------------------------------------------------------===// // vector.insert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f43a41a..8918f91 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -400,67 +400,6 @@ func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size5_vector -func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { - // CHECK: vector.extractelement - %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME: (%[[S:.+]]: f32 -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - // CHECK-LABEL: @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]], %[[ARG]] : vector<4xf32>, vector<4xf32> -> vector<2xf32> @@ -473,67 +412,6 @@ func.func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector // ----- -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size5_vector -func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { - // CHECK: vector.insertelement - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - return %0 : vector<5xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32> - // CHECK: return %[[V]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32> - // CHECK: return %[[V]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]], %[[PART]] : vector<4xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir index 4559e39..5501ad4 100644 --- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir +++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir @@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) { amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32 func.return } + +// ----- + +// CHECK-LABEL: func @fold_gather_to_lds_of_cast +func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) { +// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1> + %c0 = arith.constant 0 : index + %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1> + // CHECK: amdgpu.gather_to_lds %[[GLOBAL]] + // CHECK-SAME: : f32, memref<128x72xf32, 1> + amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0] + : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3> + func.return +} + +// ----- + +// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest +func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) { +// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1> +// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3> + %c0 = arith.constant 0 : index + %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3> + // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]] + // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3> + amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0] + : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3> + func.return +} diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir deleted file mode 100644 index 6b55dd5..0000000 --- a/mlir/test/Dialect/Arith/mesh-spmdize.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh4x4(shape = 4x4) - -// CHECK-LABEL: func @test_spmdize_constant -// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : -// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : -// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> -func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} { - %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32> - %ci = arith.constant 434 : i32 - return %sharding_annotated_1 : tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir new file mode 100644 index 0000000..be89427 --- /dev/null +++ b/mlir/test/Dialect/Arith/shard-partition.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid4x4(shape = 4x4) + +// CHECK-LABEL: func @test_partition_constant +// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : +// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : +// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32> +func.func @test_partition_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32> + %ci = arith.constant 434 : i32 + return %sharded_1 : tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir index 19eb340..762620d 100644 --- a/mlir/test/Dialect/Arith/sharding-propagation.mlir +++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir @@ -1,54 +1,54 @@ // RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s -mesh.mesh @mesh4x4(shape = 4x4) +shard.grid @grid4x4(shape = 4x4) // CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> -// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +// CHECK-NEXT: return [[vsharded_8]] : tensor<1024x1024xf32> func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> %ci = arith.constant 43.4e+00 : f32 %o1 = tensor.empty() : tensor<1024x1024xf32> - %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %res = linalg.add ins(%sharded_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> return %res : tensor<1024x1024xf32> } // CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} { // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> -// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding -// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32> +// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding +// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32> func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> %ci = arith.constant 43.4e+00 : f32 %o1 = tensor.empty() : tensor<1024x1024xf32> %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32> - return %sharding_annotated_1 : tensor<1024x1024xf32> + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_1 = shard.shard %res to %sharding_1 : tensor<1024x1024xf32> + return %sharded_1 : tensor<1024x1024xf32> } diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir new file mode 100644 index 0000000..1a74eaa --- /dev/null +++ b/mlir/test/Dialect/Async/canonicalize.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-NOT: async.execute + +func.func @empty_execute() { + %token = async.execute { + async.yield + } + return +} diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir new file mode 100644 index 0000000..e2ab876 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor<?xf32> + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor<?xf32> + // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32 + return %0, %1 : tensor<?xf32>, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @call_func_with_non_tensor_return( + %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + return %1, %0 : f32, tensor<?xf32> + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir index f44e290..2acd194 100644 --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> { func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) -> memref<?x?x16x32xi8> { %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8> - %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8> + %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8> return %1 : memref<?x?x16x32xi8> } -// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8> +// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8> // CHECK: %[[M1:.+]] = memref.cast %[[M]] // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8> // CHECK: return %[[M1]] : memref<?x?x16x32xi8> // ----- +// CHECK-LABEL: func @tensor_cast_to_buffer +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8> +func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) -> + memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> { + %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8> + %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> + return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> +} +// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8> +// CHECK: %[[M1:.+]] = memref.cast %[[M]] +// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> +// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> +// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> + +// ----- + // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx) // CHECK-LABEL: func @load_from_buffer_cast( func.func @load_from_buffer_cast(%arg0: index, %arg1: index, diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir index c67a0c1..029fa78 100644 --- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s +// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s module attributes { } { emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 162ff06..35381da 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a // ----- func.func @rotate_mismatching_type(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1) return } // ----- func.func @rotate_unsupported_type(%arg0 : index) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : index + %rotate, %valid = gpu.rotate %arg0, 4, 16 : index return } @@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) { %offset = arith.constant 4 : i32 %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32> + %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32> return } // ----- func.func @rotate_unsupported_width(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 15 : i32 - // expected-error@+1 {{op width must be a power of two}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset(%arg0 : f32) { - %offset = arith.constant 16 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset_minus(%arg0 : f32) { - %offset = arith.constant -1 : i32 - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) { - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) { - %offset = arith.constant 0 : i32 - // expected-error@+1 {{op width is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1) return } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 2aef80f..ee1fdfa 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -140,9 +140,8 @@ module attributes {gpu.container_module} { // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32 %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32 - // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32 - %rotate_width = arith.constant 16 : i32 - %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32 + // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32 + %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32 "gpu.barrier"() : () -> () diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 7284ae7..5c5f7e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose @@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) { // CHECK-LABEL: @recursive_effect // CHECK: linalg.map +// ----- + //===----------------------------------------------------------------------===// // linalg.pack //===----------------------------------------------------------------------===// // CHECK-LABEL: func @fold_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32> %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 1.000000e-01 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst padding_value(%pad : f32) outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } - // ----- // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] - into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- @@ -1889,31 +1936,84 @@ func.func @fold_cast_unpack_dynamic_tile_size( // linalg.unpack + tensor.extract_slice //===----------------------------------------------------------------------===// -func.func @fold_extract_slice_into_unpack( - %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index -) -> tensor<28x28x?xf32> { +func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> { %unpack = linalg.unpack %src outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] - into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32> + into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> %extracted_slice = tensor.extract_slice %unpack - [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32> - return %extracted_slice : tensor<28x28x?xf32> + [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32> + return %extracted_slice : tensor<28x28x10xf32> } +// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]] +// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1] +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] +// CHECK-SAME: into %[[DEST_SLICE]] +// CHECK: return %[[UNPACK]] -// CHECK-LABEL: func @fold_extract_slice_into_unpack -// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32> -// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32> -// CHECK-SAME: %[[SIZE:.+]]: index +// ----- + +// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2. + +func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> { + %unpack = linalg.unpack %src + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32> + return %extracted_slice : tensor<28x17x15xf32> +} +// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]] -// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1] +// CHECK-SAME: [0, 0, 0] [28, 17, 15] [1, 1, 1] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] // CHECK-SAME: into %[[DEST_SLICE]] // CHECK: return %[[UNPACK]] // ----- +// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2. + +func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> { + %unpack = linalg.unpack %src + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32> + return %extracted_slice : tensor<28x16x15xf32> +} +// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding +// CHECK: linalg.unpack +// CHECK: tensor.extract_slice + +// ----- + +func.func @no_fold_extract_slice_into_unpack_dynamic( + %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index +) -> tensor<28x28x?xf32> { + %unpack = linalg.unpack %src + outer_dims_perm = [0, 1, 2] + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32> + return %extracted_slice : tensor<28x28x?xf32> +} +// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic +// CHECK: linalg.unpack +// CHECK: tensor.extract_slice + +// ----- + func.func @no_fold_extract_slice_into_unpack_rank_reducing( %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32> ) -> tensor<28xf32> { diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 6fc8d9f..cc26fa4 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate( // ----- -func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { - %empty = tensor.empty() : tensor<8x4x16x8xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> - %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> - return %pack : tensor<8x4x16x8xf32> -} -// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] -// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> -// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]] -// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> -// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> - -// ----- - func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { %6 = tensor.empty(%dim) : tensor<?x256xf32> %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index a00c798..5f42938 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // ----- +func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %cst : f32 + } : tensor<1x?xf32> to tensor<1x16xf32> + return %padded : tensor<1x16xf32> +} +// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] { +// CHECK: ^bb0(%[[IDX:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<?xf32> to tensor<16xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32> +// CHECK: return %[[EXPANDED]] : tensor<1x16xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +module { + func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> + %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<?x1x61x1xf32> + return %1 : tensor<?x1x61x1xf32> + } +} // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)> // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()> @@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32> // CHECK: } -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> -module { - func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> - %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.mulf %in, %in_0 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } -> tensor<?x1x61x1xf32> - return %1 : tensor<?x1x61x1xf32> - } -} - // ----- func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index da1dfc7..40bf4d1 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf } // ----- + func.func @pack_mismatch_inner_tile_size_and_output_shape( %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> { // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} @@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t // ----- +func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> { + %cst = arith.constant 0.0 : f32 + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %output + : tensor<9xf32> -> tensor<3x8xf32> + return %0 : tensor<3x8xf32> +} + +// ----- + // The outer dims in the output tensor are incorrectly/unexpectedly transposed. // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose). func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}} + // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}} %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32> return %0 : tensor<4x16x32x16xf32> } // ----- -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> +func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> { + // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}} + %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32> + return %0 : tensor<8x7x16x32xf32> +} + +// ----- + +func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> { + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output + : tensor<3x8xf32> -> tensor<9xf32> + return %0 : tensor<9xf32> } // ----- -func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> +func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> } diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir deleted file mode 100644 index 5297eeb..0000000 --- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir +++ /dev/null @@ -1,42 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --verify-each \ -// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_2(shape = 2) - -// CHECK-LABEL: func @matmul_shard_prallel_axis -func.func @matmul_shard_prallel_axis( - // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>, - %arg0 : tensor<2x3xf32>, - // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>, - %arg1 : tensor<3x2xf32>, - // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32> - %out_dps: tensor<2x2xf32> -) -> tensor<2x2xf32> { - // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32> - // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32> - // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding - // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32> - // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32> - %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> - - // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>) - // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32> - %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) - outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> - - // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding - // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32> - // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding - // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32> - %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding - %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32> - - // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32> - return %res_sharded : tensor<2x2xf32> -} diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir index ce12b29..aee9707 100644 --- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir +++ b/mlir/test/Dialect/Linalg/shard-partition.mlir @@ -1,15 +1,15 @@ // RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ // RUN: --split-input-file \ // RUN: %s | FileCheck %s // CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)> #map_identity_1d = affine_map<(d0) -> (d0)> -mesh.mesh @mesh_1d(shape = 2) +shard.grid @grid_1d(shape = 2) -// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor -func.func @elementwise_static_1d_mesh_static_1d_tensor( +// CHECK-LABEL: func @elementwise_static_1d_grid_static_1d_tensor +func.func @elementwise_static_1d_grid_static_1d_tensor( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>, %in1: tensor<2xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>, @@ -18,13 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor( %dps_out: tensor<2xi8> // CHECK-SAME: -> tensor<1xi8> { ) -> tensor<2xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8> - %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8> - %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8> - %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8> - %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8> - %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in1_sharded1 = shard.shard %in1 to %sharding : tensor<2xi8> + %in1_sharded2 = shard.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %in2_sharded1 = shard.shard %in2 to %sharding : tensor<2xi8> + %in2_sharded2 = shard.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %dps_out_sharded1 = shard.shard %dps_out to %sharding : tensor<2xi8> + %dps_out_shared2 = shard.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8> // CHECK: %[[RES:.*]] = linalg.generic { // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]], // CHECK-SAME: iterator_types = ["parallel"]} @@ -39,18 +39,18 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor( %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8 linalg.yield %res_scalar : i8 } -> tensor<2xi8> - %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8> - %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8> + %res_sharded1 = shard.shard %res to %sharding : tensor<2xi8> + %res_shared2 = shard.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8> // CHECK: return %[[RES]] : tensor<1xi8> return %res_shared2 : tensor<2xi8> } // ----- -mesh.mesh @mesh_1d(shape = 4) +shard.grid @grid_1d(shape = 4) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding -func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_sharding +func.func @matmul_1d_grid_static_tensors_parallel_iterator_sharding( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>, %in1: tensor<4x3xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>, @@ -59,32 +59,32 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<1x8xi8> { ) -> tensor<4x8xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8> - %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8> - %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8> - %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8> - %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8> + %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x3xi8> + %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<3x8xi8> + %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8> + %dps_out_shared1 = shard.shard %dps_out to %sharding : tensor<4x8xi8> + %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8> // CHECK: %[[RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>) // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>) // CHECK-SAME: -> tensor<1x8xi8> %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>) outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> - %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8> - %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8> + %res_shared1 = shard.shard %res to %sharding : tensor<4x8xi8> + %res_shared2 = shard.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8> // CHECK: return %[[RES]] : tensor<1x8xi8> return %res_shared2 : tensor<4x8xi8> } // ----- -mesh.mesh @mesh_1d(shape = 3) +shard.grid @grid_1d(shape = 3) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding -func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_reduction_iterator_sharding +func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, %in1: tensor<4x6xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, @@ -93,19 +93,19 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<4x8xi8> { ) -> tensor<4x8xi8> { - %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8> - %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8> - %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8> - %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8> - %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> + %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x6xi8> + %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<6x8xi8> + %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8> + %sharding3 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %dps_out_shared1 = shard.shard %dps_out to %sharding3 : tensor<4x8xi8> + %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 - // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK-DAG: %[[PROCESS_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index + // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> @@ -117,21 +117,21 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( // CHECK: } // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> - // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> + // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_MATMUL]] on @grid_1d grid_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> - %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8> - %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> + %res_shared1 = shard.shard %res to %sharding3 : tensor<4x8xi8> + %res_shared2 = shard.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8> // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8> return %res_shared2 : tensor<4x8xi8> } // ----- -mesh.mesh @mesh_1d(shape = 4) +shard.grid @grid_1d(shape = 4) -// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis -func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis( +// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis +func.func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>, %in1: tensor<4x6xi8>, // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>, @@ -140,25 +140,25 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis( %dps_out: tensor<4x8xi8> // CHECK-SAME: -> tensor<4x8xi8> { ) -> tensor<4x8xi8> { - %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding - %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8> - %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8> - // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8> - %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8> - // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8> - %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8> + %sharding1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding + %in1_replicated1 = shard.shard %in1 to %sharding1 : tensor<4x6xi8> + %in1_replicated2 = shard.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8> + // CHECK: %[[ALL_SLICE1:.*]] = shard.all_slice %[[IN2]] on @grid_1d grid_axes = [0] slice_axis = 1 + %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x8xi8> + %sharding2 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8> + // CHECK: %[[ALL_SLICE2:.*]] = shard.all_slice %[[DPS_OUT]] on @grid_1d grid_axes = [0] slice_axis = 1 + %dps_out_replicated = shard.shard %dps_out to %sharding1 : tensor<4x8xi8> + %dps_out_sharded = shard.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8> // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>) // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>) // CHECK-SAME: -> tensor<4x2xi8> %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>) outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8> - %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8> - %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8> + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[MATMUL_RES]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8> + %res_sharded = shard.shard %res to %sharding2 : tensor<4x8xi8> + %res_replicated = shard.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8> // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8> return %res_replicated : tensor<4x8xi8> } diff --git a/mlir/test/Dialect/Linalg/sharding-propagation.mlir b/mlir/test/Dialect/Linalg/sharding-propagation.mlir new file mode 100644 index 0000000..e0ecefc --- /dev/null +++ b/mlir/test/Dialect/Linalg/sharding-propagation.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt \ +// RUN: --verify-each \ +// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_2(shape = 2) + +// CHECK-LABEL: func @matmul_shard_prallel_axis +func.func @matmul_shard_prallel_axis( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>, + %arg0 : tensor<2x3xf32>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>, + %arg1 : tensor<3x2xf32>, + // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32> + %out_dps: tensor<2x2xf32> +) -> tensor<2x2xf32> { + // CHECK: %[[SIN1_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = shard.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32> + // CHECK: %[[SIN1_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = shard.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32> + // CHECK: %[[SIN2_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding + // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = shard.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32> + // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = shard.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32> + %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> + + // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>) + // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32> + %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) + outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> + + // CHECK: %[[SRES_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding + // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = shard.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32> + // CHECK: %[[SRES_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding + // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = shard.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32> + %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding + %res_sharded = shard.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32> + + // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32> + return %res_sharded : tensor<2x2xf32> +} diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 81fd7a8..9e7681d 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func.func @pack_with_pad( -func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>) - -> tensor<265x16x16x1xf32> { +func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>) + -> tensor<265x12x16x1xf32> { // CHECK: tensor.pad {{.*}} low[0, 0] - // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32> + // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32> // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]] - // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32> + // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32> // CHECK: linalg.transpose - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>) // CHECK-SAME: permutation = [0, 2, 1, 3] %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dest - : tensor<4225x12xf32> -> tensor<265x16x16x1xf32> - return %0 : tensor<265x16x16x1xf32> + : tensor<4225x12xf32> -> tensor<265x12x16x1xf32> + return %0 : tensor<265x12x16x1xf32> } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 78619b6..981f5dc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -52,22 +52,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0] // CHECK: : tensor<7x5xf32> to tensor<9x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] { - // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -83,7 +83,7 @@ module { // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: pad_conv +func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> + +// CHECK-LABEL: pad_conv_dynamic +func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> { + + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32> + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12] + // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0] + // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> + return %0 : tensor<1x14x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_strided +func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12] + // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_dilated +func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index 26c03ed..f741876 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -69,22 +69,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0] // CHECK: : tensor<7x5xf32> to tensor<8x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] { - // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -102,7 +102,7 @@ module { // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -127,13 +127,13 @@ module { // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32> // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]] // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] { - // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32> // // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32> // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]] - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) { - // CHECK: } -> tensor<8x14x13xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32> + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) { + // CHECK: } -> tensor<8x14x12xf32> + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32> // %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir index c3ee892..d7722ea 100644 --- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[PV:.*]] = ub.poison : i32 -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> // CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> -// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex> +// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> -// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> -// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // ----- @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) // CHECK-LABEL: func.func @index_from_output_column_vector_gather_load( // CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> { -// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex> -// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex> -// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // CHECK: return %[[RES]] : tensor<8x1xf32> @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex> // CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex> // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex> -// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex> // CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_14]] : tensor<1x4xf32> @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather( // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex> -// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> +// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_10]] : tensor<1x4xf32> // CHECK: } @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> // CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> -// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> -// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index +// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex> // CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 98e8f50..d41d861 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32> func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> { // CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32> -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32> -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 // CHECK: %[[C01:.*]] = arith.constant 0 // CHECK: %[[C02:.*]] = arith.constant 0 -// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32> -// CHECK: %[[CNST14:.*]] = arith.constant 1 -// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32> +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32> // CHECK: %[[CNST16:.*]] = arith.constant 16 : index // CHECK: %[[CNST2:.*]] = arith.constant 2 : index -// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> +// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir deleted file mode 100644 index aff07bb..0000000 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ /dev/null @@ -1,248 +0,0 @@ -// RUN: mlir-opt --canonicalize %s | FileCheck %s - -mesh.mesh @mesh0(shape = 2x4) - -// CHECK-LABEL: func @all_reduce_empty_mesh_axes -func.func @all_reduce_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.all_reduce - %0 = mesh.all_reduce %arg0 on @mesh0 - mesh_axes = [] - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type -func.func @all_reduce_empty_mesh_axes_different_return_type( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { -// CHECK: mesh.all_reduce - %0 = mesh.all_reduce %arg0 on @mesh0 -// CHECK-NOT: mesh_axes - mesh_axes = [] - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @all_reduce_default_reduction -func.func @all_reduce_default_reduction( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { - %0 = mesh.all_reduce %arg0 on @mesh0 - mesh_axes = [0] -// CHECK-NOT: reduction - reduction = sum - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @all_to_all_empty_mesh_axes -func.func @all_to_all_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32> - %arg0 : tensor<8xf32>) -> tensor<8xf32> { -// CHECK-NOT: mesh.all_to_all - %0 = mesh.all_to_all %arg0 on @mesh0 - mesh_axes = [] - split_axis = 0 - concat_axis = 0 - : tensor<8xf32> -> tensor<8xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<8xf32> -} - -// CHECK-LABEL: func @all_gather_empty_mesh_axes -func.func @all_gather_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.all_gather - %0 = mesh.all_gather %arg0 on @mesh0 - mesh_axes = [] - gather_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @all_slice_empty_mesh_axes -func.func @all_slice_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.scatter - %0 = mesh.all_slice %arg0 on @mesh0 - mesh_axes = [] - slice_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @broadcast_empty_mesh_axes -func.func @broadcast_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.broadcast - %0 = mesh.broadcast %arg0 on @mesh0 - mesh_axes = [] - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @gather_empty_mesh_axes -func.func @gather_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.gather - %0 = mesh.gather %arg0 on @mesh0 - mesh_axes = [] - gather_axis = 0 - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @receive_empty_mesh_axes -func.func @receive_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.recv - %0 = mesh.recv %arg0 on @mesh0 - mesh_axes = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_empty_mesh_axes -func.func @reduce_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.reduce - %0 = mesh.reduce %arg0 on @mesh0 - mesh_axes = [] - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes -func.func @reduce_scatter_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.reduce_scatter - %0 = mesh.reduce_scatter %arg0 on @mesh0 - mesh_axes = [] - scatter_axis = 0 - : tensor<4xf32> -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type -func.func @reduce_scatter_empty_mesh_axes_different_return_type( - %arg0 : tensor<4xf32>) -> tensor<4xf64> { -// CHECK: mesh.reduce_scatter - %0 = mesh.reduce_scatter %arg0 on @mesh0 -// CHECK-NOT: mesh_axes - mesh_axes = [] - scatter_axis = 0 - : tensor<4xf32> -> tensor<4xf64> - return %0 : tensor<4xf64> -} - -// CHECK-LABEL: func @reduce_scatter_default_reduction -func.func @reduce_scatter_default_reduction( - %arg0 : tensor<4xf32>) -> tensor<2xf64> { - %0 = mesh.reduce_scatter %arg0 on @mesh0 - mesh_axes = [0] -// CHECK-NOT: reduction - reduction = sum - scatter_axis = 0 - : tensor<4xf32> -> tensor<2xf64> - return %0 : tensor<2xf64> -} - -// CHECK-LABEL: func @scatter_empty_mesh_axes -func.func @scatter_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.scatter - %0 = mesh.scatter %arg0 on @mesh0 - mesh_axes = [] - scatter_axis = 0 - root = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @send_empty_mesh_axes -func.func @send_empty_mesh_axes( -// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> - %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NOT: mesh.send - %0 = mesh.send %arg0 on @mesh0 - mesh_axes = [] - destination = [] - : (tensor<4xf32>) -> tensor<4xf32> -// CHECK: return %[[ARG]] - return %0 : tensor<4xf32> -} - -mesh.mesh @mesh4x4(shape = 4x4) -// CHECK-LABEL: func @test_halo_sizes -func.func @test_halo_sizes() -> !mesh.sharding { - %c2_i64 = arith.constant 2 : i64 - // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding - return %sharding : !mesh.sharding -} - -// CHECK-LABEL: func @test_shard_offs -func.func @test_shard_offs() -> !mesh.sharding { - %c2_i64 = arith.constant 2 : i64 - // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding - return %sharding : !mesh.sharding -} - -// CHECK-LABEL: func @test_duplicate_shardops -func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> - // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> - return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> -} - -// CHECK-LABEL: func @test_duplicate_shardops_diff -func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { - // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding - %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding - %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding - // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> - %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> - %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding - %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> - // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32> - %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> - // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> - return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> -} diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir deleted file mode 100644 index 369f316d..0000000 --- a/mlir/test/Dialect/Mesh/folding.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s - -mesh.mesh @mesh0(shape = 4x?x2) -mesh.mesh @mesh1(shape = 2x3) - -// CHECK-LABEL: func.func @mesh_shape_op_folding -func.func @mesh_shape_op_folding() -> (index, index) { - // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index - // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index - %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index - // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] - return %0#0, %0#1 : index, index -} - -// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh -func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) { - // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index - // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index - %0:2 = mesh.mesh_shape @mesh1 : index, index - // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] - return %0#0, %0#1 : index, index -} diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir deleted file mode 100644 index 6ab711b..0000000 --- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} - func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} { - %c1_i32 = arith.constant 1 : i32 - // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32> - %0 = tensor.empty() : tensor<6x6xi32> - // CHECK: [[v1:%.*]] = linalg.fill ins - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32> - %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32> - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32> - // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32> - %3 = tensor.empty() : tensor<6x6xi32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32> - // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) { - // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32> - %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated - : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) { - ^bb0(%in: i32, %in_2: i32, %out: i32): - %9 = arith.addi %in, %in_2 : i32 - linalg.yield %9 : i32 - } -> tensor<6x6xi32> - %c0_i32 = arith.constant 0 : i32 - %6 = tensor.empty() : tensor<i32> - %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32> - // CHECK: [[vreduced:%.*]] = linalg.reduce ins - // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding - // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32> - %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] - (%in: i32, %init: i32) { - %9 = arith.addi %in, %init : i32 - linalg.yield %9 : i32 - } - // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding - %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding - // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32> - %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32> - return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32> - } -} diff --git a/mlir/test/Dialect/Mesh/inlining.mlir b/mlir/test/Dialect/Mesh/inlining.mlir deleted file mode 100644 index c41a709..0000000 --- a/mlir/test/Dialect/Mesh/inlining.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-opt -inline %s | FileCheck %s - -mesh.mesh @mesh0(shape = 4x?x2) - -func.func private @mesh_to_inline() -> (index, index) { - %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index - return %0#0, %0#1 : index, index -} -// CHECK-LABEL: func.func @main -func.func @main() -> (index, index) { - // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index - %0:2 = func.call @mesh_to_inline() : () -> (index, index) - // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1 - return %0#0, %0#1 : index, index -} diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir deleted file mode 100644 index e23cfd7..0000000 --- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s - -mesh.mesh @mesh2d(shape = ?x?) - -// CHECK-LABEL: func.func @multi_index_2d_mesh -func.func @multi_index_2d_mesh() -> (index, index) { - // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index - // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index - %0:2 = mesh.process_multi_index on @mesh2d : index, index - // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index - return %0#0, %0#1 : index, index -} - -// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis -func.func @multi_index_2d_mesh_single_inner_axis() -> index { - // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index - // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index - %0 = mesh.process_multi_index on @mesh2d axes = [0] : index - // CHECK: return %[[MULTI_IDX]]#0 : index - return %0 : index -} diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir deleted file mode 100644 index 5e62c92..0000000 --- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir +++ /dev/null @@ -1,168 +0,0 @@ -// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s - -mesh.mesh @mesh_1d(shape = 2) -mesh.mesh @mesh_1d_dynamic(shape = ?) - -// CHECK-LABEL: func @same_source_and_target_sharding -func.func @same_source_and_target_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> - %arg0: tensor<2xf32> -) -> tensor<2xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32> - // CHECK: return %[[ARG]] - return %1 : tensor<2xf32> -} - -// CHECK-LABEL: func @identical_source_and_target_sharding -func.func @identical_source_and_target_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> - %arg0: tensor<2xf32> -) -> tensor<2xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xf32> - %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32> - // CHECK: return %[[ARG]] - return %1 : tensor<2xf32> -} - -// CHECK-LABEL: func @split_replicated_tensor_axis -func.func @split_replicated_tensor_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32> - %arg0: tensor<3x14xf32> -) -> tensor<3x14xf32> { - // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1 - // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32> - // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32> - // CHECK: return %[[RESULT]] : tensor<3x14xf32> - return %1 : tensor<3x14xf32> -} - -// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic -func.func @split_replicated_tensor_axis_dynamic( - // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32> - %arg0: tensor<?x3x?xf32> -) -> tensor<?x3x?xf32> { - // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0 - // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32> - %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32> - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32> - // CHECK: return %[[RESULT]] : tensor<?x3x?xf32> - return %1 : tensor<?x3x?xf32> -} - -// CHECK-LABEL: func @move_split_axis -func.func @move_split_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> - // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32> - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @move_split_axis_dynamic_mesh -func.func @move_split_axis_dynamic_mesh( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32> - // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32> - // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32> - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @move_split_dynamic_axis -func.func @move_split_dynamic_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32> - %arg0: tensor<?x14xf32> -) -> tensor<?x14xf32> { - // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32> - // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32> - // CHECK: return %[[RES]] : tensor<?x14xf32> - return %1 : tensor<?x14xf32> -} - -// CHECK-LABEL: func @unshard_static_axis -func.func @unshard_static_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @unshard_static_last_axis -func.func @unshard_static_last_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} - -// CHECK-LABEL: func @unshard_dynamic_axis -func.func @unshard_dynamic_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32> - %arg0: tensor<?x14xf32> -) -> tensor<?x14xf32> { - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32> - return %1 : tensor<?x14xf32> -} - -// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis -func.func @unshard_static_axis_on_dynamic_mesh_axis( -// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> - %arg0: tensor<10x14xf32> -) -> tensor<10x14xf32> { - // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32> - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32> - // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32> - %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32> - %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> - // CHECK: return %[[RES]] : tensor<10x14xf32> - return %1 : tensor<10x14xf32> -} diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir deleted file mode 100644 index 0881d994..0000000 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ /dev/null @@ -1,301 +0,0 @@ -// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s - -mesh.mesh @mesh_2(shape = 2) -mesh.mesh @mesh_1d(shape = ?) -mesh.mesh @mesh_2d(shape = 2x4) -mesh.mesh @mesh_3d(shape = ?x?x?) - -// CHECK-LABEL: func.func @element_wise_empty_sharding_info -func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: tosa.sigmoid - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: return - return %0 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_def -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V2]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_use -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: return %[[V2]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_graph_output -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @element_wise_on_graph_input -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]] - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @arrow_structure -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]] - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32> - %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]] - // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> - %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]] - // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]] : tensor<8x16xf32> - %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> - %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V6]], %[[V8]] - return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32> -} - -// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32> - // CHECK-NEXT: return %[[V3]] - return %1 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n -// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> -func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32> - // CHECK-NEXT: return [[vsharded_5]] - return %1 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k -// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding - %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32> - %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32> - // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> - // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding - // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - // CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32> - %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK: return [[vsharded_6]] - return %0 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32> - // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding - // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> - // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> - // CHECK-NEXT: return %[[V3]] - return %2 : tensor<2x16x32xf32> -} - -// CHECK-LABEL: func.func @resolve_conflicting_annotations -func.func @resolve_conflicting_annotations( - // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>, - %arg0: tensor<2x3xf32>, - // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>, - %arg1: tensor<3x2xf32>, - // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32> - %out_dps: tensor<2x2xf32> -// CHECK-SAME: ) -> tensor<2x2xf32> { -) -> tensor<2x2xf32> { - // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32> - // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32> - // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32> - // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32> - %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> - // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>) - // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32> - %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) - outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32> - %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding - %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32> - // CHECK: return %[[RES]] : tensor<2x2xf32> - return %res_sharded : tensor<2x2xf32> -} - -// https://arxiv.org/abs/2211.05102 Figure 2(a) -// The sharding propagation results in unnecessary reshards, -// an optimization pass should be able to remove them. -// CHECK-LABEL: func.func @mlp_1d_weight_stationary -// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> -func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32> - %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32> - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> - // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32> - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32> - // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32> - %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> - %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32> - %sharded2 = mesh.shard %arg2 to %sharding : tensor<2x32x8xf32> - // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> - // CHECK: [[v2:%.*]] = tosa.matmul - %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> - // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> - %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32> - // CHECK: return [[vsharded_12]] - return %4 : tensor<2x4x8xf32> -} - -// https://arxiv.org/abs/2211.05102 Figure 2(b) -// The sharding propagation results in unnecessary reshards, -// an optimization pass should be able to remove them. -// CHECK-LABEL: func.func @mlp_2d_weight_stationary -// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> -func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { - // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding - %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding - // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> - %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32> - // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding - %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding - // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32> - %arg1_s = mesh.shard %arg1 to %s1 : tensor<2x8x32xf32> - // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32> - // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32> - // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> - // CHECK: [[v0:%.*]] = tosa.matmul - %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32> - %2 = mesh.shard %1 to %s0 : tensor<2x4x32xf32> - // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[v1:%.*]] = tosa.sigmoid - // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32> - %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> - // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding - %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding - // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32> - %arg2_s = mesh.shard %arg2 to %s2 : tensor<2x32x8xf32> - // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32> - // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> - // CHECK: [[v2:%.*]] = tosa.matmul - %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> - // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> - %5 = mesh.shard %4 to %s0 : tensor<2x4x8xf32> - // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> - %6 = mesh.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32> - // CHECK: return [[vsharded_14]] - return %6 : tensor<2x4x8xf32> -} - -// CHECK-LABEL: func.func @elementwise_duplicated_chain -// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] - %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]] - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]] : tensor<8x16xf32> - %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding - %2 = mesh.shard %1 to %s0 : tensor<8x16xf32> - // CHECK-NEXT: return %[[V5]] - return %2 : tensor<8x16xf32> -} diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir deleted file mode 100644 index 701898c..0000000 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ /dev/null @@ -1,317 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_1d(shape = 2) - -// CHECK-LABEL: func @return_sharding -func.func @return_sharding( - // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> - %arg0: tensor<2xf32> -// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) { -) -> (tensor<2xf32>, !mesh.sharding) { - %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> - // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding - %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding - // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding - return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding -} - -// CHECK-LABEL: func @full_replication -func.func @full_replication( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<2xi8> { -) -> tensor<2xi8> { - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: return %[[ARG]] : tensor<2xi8> - return %1 : tensor<2xi8> -} - -// CHECK-LABEL: func @sharding_triplet -func.func @sharding_triplet( - // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32> - %arg0: tensor<2xf32> -// CHECK-SAME: ) -> tensor<2xf32> { -) -> tensor<2xf32> { - // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32> - %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32> - %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32> - %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32> - // CHECK: return %[[ALL_GATHER]] : tensor<2xf32> - return %sharding_annotated_1 : tensor<2xf32> -} - - -// CHECK-LABEL: func @move_split_axis -func.func @move_split_axis( - // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8> - %arg0: tensor<2x2xi8> -// CHECK-SAME: -> tensor<2x1xi8> { -) -> tensor<2x2xi8> { - // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8> - // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8> - return %1 : tensor<2x2xi8> -} - -// CHECK-LABEL: func @non_tensor_value -func.func @non_tensor_value( - // CHECK-SAME: %[[ARG:.*]]: i8 - %arg0: i8 -// CHECK-SAME: -> i8 { -) -> i8 { - // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8 - %0 = arith.addi %arg0, %arg0 : i8 - // CHECK: return %[[RES]] : i8 - return %0 : i8 -} - -// CHECK-LABEL: func @unary_elementwise -func.func @unary_elementwise( - // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<1xi8> - return %4 : tensor<2xi8> -} - -// full replication -> shard axis -> abs -> shard axis -> full replication -// CHECK-LABEL: func @unary_elementwise_with_resharding -func.func @unary_elementwise_with_resharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<2xi8> { -) -> tensor<2xi8> { - // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<2xi8> - return %4 : tensor<2xi8> -} - -// CHECK-LABEL: func @binary_elementwise -func.func @binary_elementwise( - // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>, - %arg0: tensor<2xi8>, - // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8> - %arg1: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8> - %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8> - %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8> - %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8> - // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> - %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8> - %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8> - %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8> - // CHECK: return %[[RES]] : tensor<1xi8> - return %res : tensor<2xi8> -} - -// reshard -// abs -// reshard -// abs -// reshard -// CHECK-LABEL: func @multiple_chained_ops -func.func @multiple_chained_ops( - // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> - %arg0: tensor<2xi8> -// CHECK-SAME: -> tensor<1xi8> { -) -> tensor<2xi8> { - // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0 - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<2xi8> - %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8> - %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d - // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> - %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %3 = mesh.shard %2 to %s3 : tensor<2xi8> - %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> - // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> - %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : - // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> - %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding - %6 = mesh.shard %5 to %s6 : tensor<2xi8> - %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8> - // CHECK: return %[[RESHARD3]] : tensor<1xi8> - return %7 : tensor<2xi8> -} - -// CHECK-LABEL: func @incomplete_sharding -func.func @incomplete_sharding( - // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32> - %arg0: tensor<8x16xf32> -// CHECK-SAME: -> tensor<4x16xf32> { -) -> tensor<8x16xf32> { - %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> - // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32> - %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding - %2 = mesh.shard %1 to %s2 : tensor<8x16xf32> - // CHECK: return %[[RES]] : tensor<4x16xf32> - return %2 : tensor<8x16xf32> -} - -mesh.mesh @mesh_1d_4(shape = 4) - -// CHECK-LABEL: func @ew_chain_with_halo -func.func @ew_chain_with_halo( - // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32> - %arg0: tensor<8x16xf32>, - // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32> - %arg1: tensor<1xf32>, - // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32> - %arg2: tensor<1xf32>) - // CHECK-SAME: -> tensor<5x16xf32> - -> tensor<8x16xf32> { - %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32> - // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> - %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32> - %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32> - %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> - %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32> - %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32> - %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32> - %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding - %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32> - %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32> - %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> - %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32> - %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding - %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32> - return %sharding_annotated_6 : tensor<8x16xf32> -} - -// CHECK-LABEL: func @test_shard_update_halo -// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64> -func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding - // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> - // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> - %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> - %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> - %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> - // CHECK: return %[[UH]] : tensor<304x1200xi64> - return %sharding_annotated_3 : tensor<1200x1200xi64> -} - -mesh.mesh @mesh4x4(shape = 4x4) -// CHECK-LABEL: func @test_shard_update_halo2d -// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64> -func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { - %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding - // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> - // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> - // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> - %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64> - %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding - %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64> - %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> - // CHECK: return %[[UH]] : tensor<303x307xi64> - return %sharding_annotated_3 : tensor<1200x1200xi64> -} - -mesh.mesh @mesh(shape = 2) -// CHECK-LABEL: func.func @test_reduce_0d( -// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> -func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) { - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> - %4 = tensor.empty() : tensor<i32> - %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding - %sharded_out = mesh.shard %4 to %sharding_out : tensor<i32> - %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> - // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) - %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1] - (%in: i32, %init: i32) { - %6 = arith.addi %in, %init : i32 - linalg.yield %6 : i32 - } - // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor<i32> -> tensor<i32> - %sharded_red = mesh.shard %reduced to %sharding_out : tensor<i32> - %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32> - // CHECK: return %[[all_reduce]] : tensor<i32> - return %sharded_ret : tensor<i32> -} - -// CHECK-LABEL: func.func @test_reduce_1d( -// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> -func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> - %4 = tensor.empty() : tensor<6xi32> - %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32> - %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> - // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) - %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] - (%in: i32, %init: i32) { - %6 = arith.addi %in, %init : i32 - linalg.yield %6 : i32 - } - // CHECK-NOT: mesh.all_reduce - %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32> - %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32> - // CHECK: return %[[reduced]] : tensor<3xi32> - return %sharded_ret : tensor<6xi32> -} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4c50ed3..8c846cd 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, // CHECK-NEXT: (%[[XVAL:.*]]: i1): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1 // CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1) - // } + // CHECK-NEXT: } omp.atomic.update %xBool : memref<i1> { ^bb0(%xval: i1): %newval = llvm.icmp "eq" %xval, %exprBool : i1 @@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, omp.yield(%newval : i32) } + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%{{.+}} : i32) + // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%const:i32) + } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} return } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 12d30e17..308cf150 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() { // ----- -// CHECK-LABEL: func @execute_region_elim -func.func @execute_region_elim() { +// CHECK-LABEL: func @execute_region_inline +func.func @execute_region_inline() { affine.for %i = 0 to 100 { "test.foo"() : () -> () %v = scf.execute_region -> i64 { @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim -func.func @func_execute_region_elim() { +// CHECK-LABEL: func @execute_region_no_inline +func.func @execute_region_no_inline() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 no_inline { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + "test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: scf.execute_region +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: scf.yield %[[VAL]] : i64 +// CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @func_execute_region_inline +func.func @func_execute_region_inline() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim_multi_yield -func.func @func_execute_region_elim_multi_yield() { +// CHECK-LABEL: func @func_execute_region_inline_multi_yield +func.func @func_execute_region_inline_multi_yield() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index d6c3464..58b8288 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto // ----- //===----------------------------------------------------------------------===// +// spirv.IsFinite +//===----------------------------------------------------------------------===// + +func.func @isfinite_scalar(%arg0: f32) -> i1 { + // CHECK: spirv.IsFinite {{.*}} : f32 + %0 = spirv.IsFinite %arg0 : f32 + return %0 : i1 +} + +func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spirv.IsFinite {{.*}} : vector<2xf32> + %0 = spirv.IsFinite %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + +// ----- + +//===----------------------------------------------------------------------===// // spirv.IsInf //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 5d05a654..6d321af 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve // CHECK: func private @struct_empty(!spirv.struct<()>) func.func private @struct_empty(!spirv.struct<()>) +// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) +func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) + +// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) +func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) + // ----- // expected-error @+1 {{offset specification must be given for all members}} diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir index bd51a07..f3a3218 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]] // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 } // end spirv.module + +// ----- + +module { + spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} { + // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + // CHECK: spirv.func @read_image + spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} { + // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + %cst0_i32 = spirv.Constant 0 : i32 + // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]] + %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32> + %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32> + %cst0_i32_0 = spirv.Constant 0 : i32 + %cst0_i32_1 = spirv.Constant 0 : i32 + %cst1_i32 = spirv.Constant 1 : i32 + %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + spirv.Store "StorageBuffer" %4, %3 : f32 + spirv.Return + } + } +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b23766..8d7f3da 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes { // Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled // implicitly by v1.5. -// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> +// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> spirv.module Logical Vulkan attributes { spirv.target_env = #spirv.target_env< #spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>> diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir index 4f54607a..bc91121 100644 --- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir +++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir @@ -1,43 +1,43 @@ -// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s +// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s -mesh.mesh @mesh_1d(shape = ?) +shard.grid @grid_1d(shape = ?) -// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh -func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid +func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid( // CHECK: %[[ARG:.*]]: tensor<?xf16> %arg0: tensor<?xf16> // CHECK-SAME: -> tensor<?xf16> { ) -> tensor<?xf16> { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK-DAG: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index + // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor<?xf16> - // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index // CHECK: cf.assert %[[AXIS_SIZE_CHECK]] - // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index + // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16> - %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16> + %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16> // CHECK: return %[[RESULT]] : tensor<?xf16> return %0 : tensor<?xf16> } // ----- -mesh.mesh @mesh_1d(shape = 2) +shard.grid @grid_1d(shape = 2) -// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh -func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid +func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid( // CHECK: %[[ARG:.*]]: tensor<2xf16> %arg0: tensor<2xf16> // CHECK-SAME: -> tensor<1xf16> { ) -> tensor<1xf16> { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16> // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16> - %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16> + %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16> // CHECK: return %[[RESULT]] : tensor<1xf16> return %0 : tensor<1xf16> } @@ -46,18 +46,18 @@ func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh( // CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)> -mesh.mesh @mesh_4d(shape = ?x?x?x?) +shard.grid @grid_4d(shape = ?x?x?x?) -// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh -func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh( +// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid +func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid( // CHECK: %[[ARG:.*]]: tensor<?x?xf16> %arg0 : tensor<?x?xf16> // CHECK-SAME: -> tensor<?x?xf16> { ) -> tensor<?x?xf16> { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index - // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index + // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index + // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16> // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index @@ -68,7 +68,7 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh( // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16> // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16> - %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16> + %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16> // CHECK: return %[[RESULT]] : tensor<?x?xf16> return %0 : tensor<?x?xf16> } diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir index 4223d01..8894c4a 100644 --- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir +++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir @@ -2,17 +2,17 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} + shard.grid @grid(shape = 1) {sym_visibility = "private"} func.func @test_forward() -> tensor<6x6xi32> { %c1_i32 = arith.constant 1 : i32 // CHECK: tensor.empty() %0 = tensor.empty() : tensor<6x6xi32> - %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - // CHECK-COUNT-2: mesh.shard - %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32> - %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32> + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + // CHECK-COUNT-2: shard.shard + %sharded = shard.shard %0 to %sharding : tensor<6x6xi32> + %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharded : tensor<6x6xi32>) -> tensor<6x6xi32> // CHECK: tensor.empty() - // CHECK-NOT: mesh.shard @ + // CHECK-NOT: shard.shard @ %2 = tensor.empty() : tensor<6x6xi32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1 : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) { diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir new file mode 100644 index 0000000..ed40dfb --- /dev/null +++ b/mlir/test/Dialect/Shard/canonicalization.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt --canonicalize %s | FileCheck %s + +shard.grid @grid0(shape = 2x4) + +// CHECK-LABEL: func @all_reduce_empty_grid_axes +func.func @all_reduce_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.all_reduce + %0 = shard.all_reduce %arg0 on @grid0 + grid_axes = [] + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @all_reduce_empty_grid_axes_different_return_type +func.func @all_reduce_empty_grid_axes_different_return_type( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { +// CHECK: shard.all_reduce + %0 = shard.all_reduce %arg0 on @grid0 +// CHECK-NOT: grid_axes + grid_axes = [] + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @all_reduce_default_reduction +func.func @all_reduce_default_reduction( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { + %0 = shard.all_reduce %arg0 on @grid0 + grid_axes = [0] +// CHECK-NOT: reduction + reduction = sum + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @all_to_all_empty_grid_axes +func.func @all_to_all_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32> + %arg0 : tensor<8xf32>) -> tensor<8xf32> { +// CHECK-NOT: shard.all_to_all + %0 = shard.all_to_all %arg0 on @grid0 + grid_axes = [] + split_axis = 0 + concat_axis = 0 + : tensor<8xf32> -> tensor<8xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<8xf32> +} + +// CHECK-LABEL: func @all_gather_empty_grid_axes +func.func @all_gather_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.all_gather + %0 = shard.all_gather %arg0 on @grid0 + grid_axes = [] + gather_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @all_slice_empty_grid_axes +func.func @all_slice_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.scatter + %0 = shard.all_slice %arg0 on @grid0 + grid_axes = [] + slice_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @broadcast_empty_grid_axes +func.func @broadcast_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.broadcast + %0 = shard.broadcast %arg0 on @grid0 + grid_axes = [] + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @gather_empty_grid_axes +func.func @gather_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.gather + %0 = shard.gather %arg0 on @grid0 + grid_axes = [] + gather_axis = 0 + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @receive_empty_grid_axes +func.func @receive_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.recv + %0 = shard.recv %arg0 on @grid0 + grid_axes = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @reduce_empty_grid_axes +func.func @reduce_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.reduce + %0 = shard.reduce %arg0 on @grid0 + grid_axes = [] + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @reduce_scatter_empty_grid_axes +func.func @reduce_scatter_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.reduce_scatter + %0 = shard.reduce_scatter %arg0 on @grid0 + grid_axes = [] + scatter_axis = 0 + : tensor<4xf32> -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @reduce_scatter_empty_grid_axes_different_return_type +func.func @reduce_scatter_empty_grid_axes_different_return_type( + %arg0 : tensor<4xf32>) -> tensor<4xf64> { +// CHECK: shard.reduce_scatter + %0 = shard.reduce_scatter %arg0 on @grid0 +// CHECK-NOT: grid_axes + grid_axes = [] + scatter_axis = 0 + : tensor<4xf32> -> tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @reduce_scatter_default_reduction +func.func @reduce_scatter_default_reduction( + %arg0 : tensor<4xf32>) -> tensor<2xf64> { + %0 = shard.reduce_scatter %arg0 on @grid0 + grid_axes = [0] +// CHECK-NOT: reduction + reduction = sum + scatter_axis = 0 + : tensor<4xf32> -> tensor<2xf64> + return %0 : tensor<2xf64> +} + +// CHECK-LABEL: func @scatter_empty_grid_axes +func.func @scatter_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.scatter + %0 = shard.scatter %arg0 on @grid0 + grid_axes = [] + scatter_axis = 0 + root = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @send_empty_grid_axes +func.func @send_empty_grid_axes( +// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> + %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NOT: shard.send + %0 = shard.send %arg0 on @grid0 + grid_axes = [] + destination = [] + : (tensor<4xf32>) -> tensor<4xf32> +// CHECK: return %[[ARG]] + return %0 : tensor<4xf32> +} + +shard.grid @grid4x4(shape = 4x4) +// CHECK-LABEL: func @test_halo_sizes +func.func @test_halo_sizes() -> !shard.sharding { + %c2_i64 = arith.constant 2 : i64 + // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !shard.sharding + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !shard.sharding + return %sharding : !shard.sharding +} + +// CHECK-LABEL: func @test_shard_offs +func.func @test_shard_offs() -> !shard.sharding { + %c2_i64 = arith.constant 2 : i64 + // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !shard.sharding + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !shard.sharding + return %sharding : !shard.sharding +} + +// CHECK-LABEL: func @test_duplicate_shardops +func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32> + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharded]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} + +// CHECK-LABEL: func @test_duplicate_shardops_diff +func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} { + // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding + %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharding_0:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding + %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding + // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32> + %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> + %sharding_3 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding + %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32> + // CHECK-NEXT: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] : tensor<1024x1024xf32> + %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32> + // CHECK-NEXT: return [[vsharded_1]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32> + return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32> +} diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir new file mode 100644 index 0000000..5a0f35b --- /dev/null +++ b/mlir/test/Dialect/Shard/folding.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s + +shard.grid @grid0(shape = 4x?x2) +shard.grid @grid1(shape = 2x3) + +// CHECK-LABEL: func.func @grid_shape_op_folding +func.func @grid_shape_op_folding() -> (index, index) { + // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index + %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index + // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 : index, index +} + +// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid +func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) { + // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index + // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index + %0:2 = shard.grid_shape @grid1 : index, index + // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] + return %0#0, %0#1 : index, index +} diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir index dd2eee2..0d8d997 100644 --- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir +++ b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir @@ -2,25 +2,25 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { - mesh.mesh @mesh(shape = 1) {sym_visibility = "private"} + shard.grid @grid(shape = 1) {sym_visibility = "private"} func.func @test_forward() -> tensor<6x6xi32> { %c1_i32 = arith.constant 1 : i32 // CHECK: tensor.empty() %0 = tensor.empty() : tensor<6x6xi32> - // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]] - %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding - %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32> + // CHECK-COUNT-3: shard.sharding @grid split_axes = {{\[\[0}}]] + %sharding_row = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %annotated_row = shard.shard %0 to %sharding_row : tensor<6x6xi32> %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32> %2 = tensor.empty() : tensor<6x6xi32> - // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]] + // CHECK-COUNT-4: shard.sharding @grid split_axes = {{\[\[1}}]] %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1 : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) { ^bb0(%in: i32, %in_2: i32, %out: i32): %9 = arith.addi %in, %in_2 : i32 linalg.yield %9 : i32 } -> tensor<6x6xi32> - %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding - %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32> + %sharding_col = shard.sharding @grid split_axes = [[1]] : !shard.sharding + %annotated_col = shard.shard %3 to %sharding_col : tensor<6x6xi32> // CHECK: return return %annotated_col : tensor<6x6xi32> } diff --git a/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir new file mode 100644 index 0000000..3cda9ea --- /dev/null +++ b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} { + shard.grid @grid(shape = 1) {sym_visibility = "private"} + func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} { + %c1_i32 = arith.constant 1 : i32 + // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32> + %0 = tensor.empty() : tensor<6x6xi32> + // CHECK: [[v1:%.*]] = linalg.fill ins + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32> + %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32> + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %1 to %sharding : tensor<6x6xi32> + // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32> + %3 = tensor.empty() : tensor<6x6xi32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32> + // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} + // CHECK-SAME: ins([[vsharded_3]], [[vsharded_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharded_5]] : tensor<6x6xi32>) { + // CHECK: [[vsharding_6:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK: [[vsharded_7:%.*]] = shard.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32> + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharded, %sharded + : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) { + ^bb0(%in: i32, %in_2: i32, %out: i32): + %9 = arith.addi %in, %in_2 : i32 + linalg.yield %9 : i32 + } -> tensor<6x6xi32> + %c0_i32 = arith.constant 0 : i32 + %6 = tensor.empty() : tensor<i32> + %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32> + // CHECK: [[vreduced:%.*]] = linalg.reduce ins + // CHECK: [[vsharding_12:%.*]] = shard.sharding @grid split_axes = [] : !shard.sharding + // CHECK: [[vsharded_13:%.*]] = shard.shard [[vreduced]] to [[vsharding_12]] : tensor<i32> + %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] + (%in: i32, %init: i32) { + %9 = arith.addi %in, %init : i32 + linalg.yield %9 : i32 + } + // CHECK: [[vsharding_14:%.*]] = shard.sharding @grid split_axes = {{\[\[}}]] : !shard.sharding + %sharding_0 = shard.sharding @grid split_axes = [[]] : !shard.sharding + // CHECK: [[vsharded_15:%.*]] = shard.shard [[vsharded_13]] to [[vsharding_14]] annotate_for_users : tensor<i32> + %sharded_1 = shard.shard %reduced to %sharding_0 annotate_for_users : tensor<i32> + return %sharded, %4, %sharded_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32> + } +} diff --git a/mlir/test/Dialect/Shard/inlining.mlir b/mlir/test/Dialect/Shard/inlining.mlir new file mode 100644 index 0000000..ce664b3 --- /dev/null +++ b/mlir/test/Dialect/Shard/inlining.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -inline %s | FileCheck %s + +shard.grid @grid0(shape = 4x?x2) + +func.func private @grid_to_inline() -> (index, index) { + %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index + return %0#0, %0#1 : index, index +} +// CHECK-LABEL: func.func @main +func.func @main() -> (index, index) { + // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = shard.grid_shape @grid0 axes = [2, 1] : index + %0:2 = func.call @grid_to_inline() : () -> (index, index) + // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1 + return %0#0, %0#1 : index, index +} diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir index 2656332..6acac97 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Shard/invalid.mlir @@ -1,55 +1,55 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s -// expected-error@+1 {{rank of mesh is expected to be a positive integer}} -mesh.mesh @mesh0(shape = []) +// expected-error@+1 {{rank of grid is expected to be a positive integer}} +shard.grid @grid0(shape = []) // ----- -// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} -mesh.mesh @mesh0(shape = -1) +// expected-error@+1 {{custom op 'shard.grid' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} +shard.grid @grid0(shape = -1) // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_duplicated_different_subarray( +func.func @grid_axis_duplicated_different_subarray( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis duplicated}} - %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis duplicated}} + %s = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_duplicated_same_subarray( +func.func @grid_axis_duplicated_same_subarray( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis duplicated}} - %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis duplicated}} + %s = shard.sharding @grid0 split_axes = [[0, 0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_axis_negtive_in_split_part( +func.func @grid_axis_negtive_in_split_part( %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // expected-error@+1 {{mesh axis is expected to be non-negative}} - %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{grid axis is expected to be non-negative}} + %s = shard.sharding @grid0 split_axes = [[-1]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // ----- func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { - // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}} - %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + // expected-error@+1 {{custom op 'shard.sharding' invalid kind of attribute specified}} + %s = shard.sharding @a::@b split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } @@ -57,8 +57,8 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{halo sizes must be specified for all split axes}} - %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } @@ -66,292 +66,292 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) { func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}} - %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh_dyn(shape = ?x?) -func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) { - // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}} - %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +shard.grid @grid_dyn(shape = ?x?) +func.func @sharding_dyn_grid_and_sizes(%arg0 : tensor<4x8xf32>) { + // expected-error@+1 {{sharded dims offsets are not allowed for device grids with dynamic shape}} + %s = shard.sharding @grid_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{sharded dims offsets has wrong size}} - %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 4) +shard.grid @grid0(shape = 4) func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) { // expected-error@+1 {{sharded dims offsets must be non-decreasing}} - %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %s = shard.sharding @grid0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !shard.sharding + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index +func.func @grid_shape_grid_axis_out_of_bounds() -> (index, index) { + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0:2 = shard.grid_shape @grid0 axes = [0, 2] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index +func.func @grid_shape_duplicate_grid_axis() -> (index, index, index) { + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0:3 = shard.grid_shape @grid0 axes = [0, 2, 0] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @mesh_shape_wrong_number_of_results() -> (index, index) { +func.func @grid_shape_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} - %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index + %0:2 = shard.grid_shape @grid0 axes = [0] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) { +func.func @grid_shape_wrong_number_of_results_empty_grid_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 3.}} - %0:2 = mesh.mesh_shape @mesh0 : index, index + %0:2 = shard.grid_shape @grid0 : index, index return %0#0, %0#1 : index, index } // ----- -func.func @mesh_shape_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index +func.func @grid_shape_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.grid_shape @this_grid_symbol_does_not_exist : index return %0#0 : index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index +func.func @process_multi_index_grid_axis_out_of_bounds() -> (index, index) { + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0:2 = shard.process_multi_index on @grid0 axes = [0, 2] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index +func.func @process_multi_index_duplicate_grid_axis() -> (index, index, index) { + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0:3 = shard.process_multi_index on @grid0 axes = [0, 2, 0] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @process_multi_index_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} - %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index + %0:2 = shard.process_multi_index on @grid0 axes = [0] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.mesh @mesh0(shape = 1x2x3) +shard.grid @grid0(shape = 1x2x3) -func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) { +func.func @process_multi_index_wrong_number_of_results_empty_grid_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 3.}} - %0:2 = mesh.process_multi_index on @mesh0 : index, index + %0:2 = shard.process_multi_index on @grid0 : index, index return %0#0, %0#1 : index, index } // ----- -func.func @process_multi_index_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index +func.func @process_multi_index_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.process_multi_index on @this_grid_symbol_does_not_exist : index return %0 : index } // ----- -func.func @process_linear_index_invalid_mesh_name() -> (index) { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index +func.func @process_linear_index_invalid_grid_name() -> (index) { + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.process_linear_index on @this_grid_symbol_does_not_exist : index return %0 : index } // ----- -func.func @all_reduce_invalid_mesh_symbol( +func.func @all_reduce_invalid_grid_symbol( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_reduce %arg0 on @this_grid_symbol_does_not_exist reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_invalid_mesh_axis( +func.func @all_reduce_invalid_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [2] reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_duplicate_mesh_axis( +func.func @all_reduce_duplicate_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1, 0] reduction = sum : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @all_reduce_invalid_tensor_dimension_size( %arg0 : tensor<4xf32>) -> tensor<5xf64> { - // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}} - %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64> + // expected-error@+1 {{'shard.all_reduce' op requires the same shape for all operands and results}} + %0 = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<5xf64> return %0 : tensor<5xf64> } // ----- -func.func @all_gather_invalid_mesh_symbol( +func.func @all_gather_invalid_grid_symbol( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0 + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_gather %arg0 on @this_grid_symbol_does_not_exist gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_gather_invalid_mesh_axis( +func.func @all_gather_invalid_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0 + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @all_reduce_duplicate_mesh_axis( +func.func @all_reduce_duplicate_grid_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0 + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2, 2] gather_axis = 0 : tensor<4xf32> -> tensor<4xf32> return %0 : tensor<4xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 : tensor<3x4xf32> -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1x2) +shard.grid @grid0(shape = 1x2) func.func @all_gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}} - %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 gather_axis = 0 : tensor<?xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 : tensor<3xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @all_gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}} - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 : tensor<3xf32> -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @all_slice_duplicate_mesh_axis( +func.func @all_slice_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0, 0] slice_axis = 0 : tensor<?xf32> -> tensor<?xf32> return %0 : tensor<?xf32> @@ -359,12 +359,12 @@ func.func @all_slice_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.all_slice %arg0 on @mesh0 + %0 = shard.all_slice %arg0 on @grid0 slice_axis = 0 : tensor<?xf32> -> tensor<2xf32> return %0 : tensor<2xf32> @@ -372,12 +372,12 @@ func.func @all_slice_invalid_dynamic_dimension( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<3xf32> -> tensor<2xf32> return %0 : tensor<2xf32> @@ -385,12 +385,12 @@ func.func @all_slice_invalid_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_slice_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf32> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4xf32> -> tensor<?xf32> return %0 : tensor<?xf32> @@ -398,10 +398,10 @@ func.func @all_slice_invalid_operand_static_dimension_size( // ----- -func.func @all_to_all_invalid_mesh_symbol( +func.func @all_to_all_invalid_grid_symbol( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.all_to_all %arg0 on @this_grid_symbol_does_not_exist split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -409,12 +409,12 @@ func.func @all_to_all_invalid_mesh_symbol( // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) -func.func @all_to_all_duplicate_mesh_axis( +func.func @all_to_all_duplicate_grid_axis( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 0] split_axis = 0 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -422,12 +422,12 @@ func.func @all_to_all_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = ?x1) +shard.grid @grid0(shape = ?x1) func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -435,12 +435,12 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de // ----- -mesh.mesh @mesh0(shape = 1x1) +shard.grid @grid0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1] split_axis = 0 concat_axis = 1 : tensor<?x6xi8> -> tensor<3x?xi8> return %0 : tensor<3x?xi8> @@ -448,12 +448,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna // ----- -mesh.mesh @mesh0(shape = 1x1) +shard.grid @grid0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1] split_axis = 0 concat_axis = 1 : tensor<3x?xi8> -> tensor<?x3xi8> return %0 : tensor<?x3xi8> @@ -461,12 +461,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x2xi8> -> tensor<1x7xi8> return %0 : tensor<1x7xi8> @@ -474,12 +474,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 1 : tensor<3x2xi8> -> tensor<2x6xi8> return %0 : tensor<2x6xi8> @@ -487,12 +487,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @broadcast_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -500,12 +500,12 @@ func.func @broadcast_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @broadcast_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -513,12 +513,12 @@ func.func @broadcast_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @broadcast_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.broadcast' op failed to verify that all of {input, result} have same element type}} + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -526,84 +526,84 @@ func.func @broadcast_different_input_and_result_type( // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_wrong_return_element_type( %arg0 : tensor<1xf32>) -> tensor<1xi8> { - // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0] + // expected-error@+1 {{'shard.gather' op failed to verify that all of {input, result} have same element type}} + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0] : (tensor<1xf32>) -> tensor<1xi8> return %0 : tensor<1xi8> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0] : (tensor<3x4xf32>) -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1x2) +shard.grid @grid0(shape = 1x2) func.func @gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 root = [0] : (tensor<3x4xf32>) -> tensor<3x5xf32> return %0 : tensor<3x5xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}} - %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = [] + %0 = shard.gather %arg0 on @grid0 gather_axis = 0 root = [] : (tensor<?xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 root = [0] : (tensor<3xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 1) +shard.grid @grid0(shape = 1) func.func @gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 root = [0] : (tensor<3xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> } // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @gather_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<6xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [3] : (tensor<2xi8>) -> tensor<6xi8> return %0 : tensor<6xi8> @@ -611,12 +611,12 @@ func.func @gather_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @gather_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 + %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -624,12 +624,12 @@ func.func @gather_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @receive_source_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -637,12 +637,12 @@ func.func @receive_source_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @receive_source_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -650,12 +650,12 @@ func.func @receive_source_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @receive_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.recv' op failed to verify that all of {input, result} have same element type}} + %0 = shard.recv %arg0 on @grid0 grid_axes = [0] source = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -663,12 +663,12 @@ func.func @receive_different_input_and_result_type( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @reduce_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -676,12 +676,12 @@ func.func @reduce_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @reduce_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -689,12 +689,12 @@ func.func @reduce_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @reduce_different_input_and_result_shape( %arg0 : tensor<2xi8>) -> tensor<3xi16> { - // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}} - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.reduce' op failed to verify that all of {input, result} have same shape}} + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0] root = [2] : (tensor<2xi8>) -> tensor<3xi16> return %0 : tensor<3xi16> @@ -702,60 +702,60 @@ func.func @reduce_different_input_and_result_shape( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @reduce_scatter_duplicate_mesh_axis( +func.func @reduce_scatter_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf64> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0 + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0 : tensor<?xf32> -> tensor<?xf64> return %0 : tensor<?xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0 : tensor<?xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 : tensor<3xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf64> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 : tensor<4xf32> -> tensor<?xf64> return %0 : tensor<?xf64> } // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) -func.func @scatter_duplicate_mesh_axis( +func.func @scatter_duplicate_grid_axis( %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0 root = [0, 0] : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -763,12 +763,12 @@ func.func @scatter_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = mesh.scatter %arg0 on @mesh0 + %0 = shard.scatter %arg0 on @grid0 scatter_axis = 0 root = [] : (tensor<?xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -776,12 +776,12 @@ func.func @scatter_invalid_dynamic_dimension( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<3xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> @@ -789,12 +789,12 @@ func.func @scatter_invalid_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3) +shard.grid @grid0(shape = 3) func.func @scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf32> { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<4xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -802,12 +802,12 @@ func.func @scatter_invalid_operand_static_dimension_size( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @scatter_root_dimension_out_of_bounds( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [3] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> @@ -815,12 +815,12 @@ func.func @scatter_root_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @scatter_root_wrong_number_dimensions( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [2, 2] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> @@ -828,12 +828,12 @@ func.func @scatter_root_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @send_destination_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [3] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -841,12 +841,12 @@ func.func @send_destination_dimension_out_of_bounds( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @send_destination_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [2, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -854,12 +854,12 @@ func.func @send_destination_wrong_number_dimensions( // ----- -mesh.mesh @mesh0(shape = 3x?) +shard.grid @grid0(shape = 3x?) func.func @send_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}} - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.send' op failed to verify that all of {input, result} have same element type}} + %0 = shard.send %arg0 on @grid0 grid_axes = [0] destination = [2] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -867,10 +867,10 @@ func.func @send_different_input_and_result_type( // ----- -func.func @shift_invalid_mesh_symbol( +func.func @shift_invalid_grid_symbol( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist + // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}} + %0 = shard.shift %arg0 on @this_grid_symbol_does_not_exist shift_axis = 0 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -878,12 +878,12 @@ func.func @shift_invalid_mesh_symbol( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @shift_invalid_mesh_axis( +func.func @shift_invalid_grid_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2] + // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [2] shift_axis = 2 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -891,12 +891,12 @@ func.func @shift_invalid_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) -func.func @shift_duplicate_mesh_axis( +func.func @shift_duplicate_grid_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0] + // expected-error@+1 {{Grid axes contains duplicate elements.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 1, 0] shift_axis = 0 offset = -2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> @@ -904,12 +904,12 @@ func.func @shift_duplicate_mesh_axis( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @shift_invalid_tensor_dimension_size( %arg0 : tensor<4xi8>) -> tensor<5xi8> { - // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{'shard.shift' op requires the same shape for all operands and results}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0] shift_axis = 0 offset = 2 : tensor<4xi8> -> tensor<5xi8> return %0 : tensor<5xi8> @@ -917,12 +917,12 @@ func.func @shift_invalid_tensor_dimension_size( // ----- -mesh.mesh @mesh0(shape = 2x4) +shard.grid @grid0(shape = 2x4) func.func @shift_invalid_shift_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { - // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}} - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0] + // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping grid axes.}} + %0 = shard.shift %arg0 on @grid0 grid_axes = [0] shift_axis = 1 offset = 2 : tensor<4xi8> -> tensor<4xi8> return %0 : tensor<4xi8> diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir index c354de5..5265dad 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Shard/ops.mlir @@ -1,176 +1,176 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// CHECK: mesh.mesh @mesh0 -mesh.mesh @mesh0(shape = 2x2x4) +// CHECK: shard.grid @grid0 +shard.grid @grid0(shape = 2x2x4) -// CHECK: mesh.mesh @mesh1(shape = 4x?) -mesh.mesh @mesh1(shape = 4x?) +// CHECK: shard.grid @grid1(shape = 4x?) +shard.grid @grid1(shape = 4x?) -// CHECK: mesh.mesh @mesh2(shape = ?x4) -mesh.mesh @mesh2(shape = ?x4) +// CHECK: shard.grid @grid2(shape = ?x4) +shard.grid @grid2(shape = ?x4) -// CHECK: mesh.mesh @mesh3(shape = ?x?) -mesh.mesh @mesh3(shape = ?x?) +// CHECK: shard.grid @grid3(shape = ?x?) +shard.grid @grid3(shape = ?x?) -mesh.mesh @mesh4(shape = 3) +shard.grid @grid4(shape = 3) -// CHECK: mesh.mesh @mesh5(shape = ?) -mesh.mesh @mesh5(shape = ?) +// CHECK: shard.grid @grid5(shape = ?) +shard.grid @grid5(shape = ?) -// CHECK-LABEL: func @mesh_shard_op_fully_replicated +// CHECK-LABEL: func @grid_shard_op_fully_replicated // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +func.func @grid_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_1st_dim +// CHECK-LABEL: func @grid_shard_op_1st_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding +func.func @grid_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_2nd_dim +// CHECK-LABEL: func @grid_shard_op_2nd_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding - %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8xf32> +func.func @grid_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid1 split_axes = {{\[\[}}], [0]] : !shard.sharding + %s = shard.sharding @grid1 split_axes = [[], [0]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8xf32> return %0 : tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim -func.func @mesh_shard_op_1st_and_3rd_dim( +// CHECK-LABEL: func @grid_shard_op_1st_and_3rd_dim +func.func @grid_shard_op_1st_and_3rd_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32> %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding - %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding - // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32> - %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32> + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid3 split_axes = {{\[\[}}0], [], [1]] : !shard.sharding + %s = shard.sharding @grid3 split_axes = [[0], [], [1]] : !shard.sharding + // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32> + %0 = shard.shard %arg0 to %s : tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } -// CHECK-LABEL: func @mesh_shard_op_two_users +// CHECK-LABEL: func @grid_shard_op_two_users // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> +func.func @grid_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding - %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding - %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32> - // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding - %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding - %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32> - // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding - %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding - %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32> + // CHECK-NEXT: %[[V0:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding + %s0 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<4x8xf32> + // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}1]] : !shard.sharding + %s1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<4x8xf32> + // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}2]] : !shard.sharding + %s2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding + %2 = shard.shard %0 to %s2 annotate_for_users : tensor<4x8xf32> return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32> } -// CHECK-LABEL: func @mesh_shard_halo_sizes -func.func @mesh_shard_halo_sizes() -> () { +// CHECK-LABEL: func @grid_shard_halo_sizes +func.func @grid_shard_halo_sizes() -> () { // CHECK: %[[C3:.*]] = arith.constant 3 : i64 %c3 = arith.constant 3 : i64 - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding - %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding - %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !shard.sharding + %sharding1 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [1, 4] : !shard.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !shard.sharding + %sharding2 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [4, %c3] : !shard.sharding return } -// CHECK-LABEL: func @mesh_shard_dims_sizes -func.func @mesh_shard_dims_sizes() -> () { +// CHECK-LABEL: func @grid_shard_dims_sizes +func.func @grid_shard_dims_sizes() -> () { // CHECK: %[[C3:.*]] = arith.constant 3 : i64 %c3 = arith.constant 3 : i64 - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding - %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding - // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding - %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding + %sharding1 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding + // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !shard.sharding + %sharding2 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !shard.sharding return } -// CHECK-LABEL: func @mesh_shard_shape -func.func @mesh_shard_shape() { +// CHECK-LABEL: func @grid_shard_shape +func.func @grid_shard_shape() { // CHECK: %[[C3:.*]] = arith.constant 3 : index %c3 = arith.constant 3 : index - // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding - %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding - // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]] + // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding + %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding + // CHECK-NEXT: shard.shard_shape dims = [8, %[[C3]] // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]] // CHECK-SAME: ] : index, index - %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index - // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index - %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index + %shp:2 = shard.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index + // CHECK-NEXT: shard.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index + %shp1:2 = shard.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index return } -// CHECK-LABEL: func @mesh_get_sharding +// CHECK-LABEL: func @grid_get_sharding // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> -func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding { - // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding - %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding - return %0 : !mesh.sharding +func.func @grid_get_sharding(%arg0 : tensor<4x8xf32>) -> !shard.sharding { + // CHECK-NEXT: shard.get_sharding %[[ARG]] : tensor<4x8xf32> -> !shard.sharding + %0 = shard.get_sharding %arg0 : tensor<4x8xf32> -> !shard.sharding + return %0 : !shard.sharding } -// CHECK-LABEL: func @mesh_shape -func.func @mesh_shape() -> (index, index) { - // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index - %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index +// CHECK-LABEL: func @grid_shape +func.func @grid_shape() -> (index, index) { + // CHECK: %[[RES:.*]]:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index + %0:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } -// CHECK-LABEL: func @mesh_shape_default_axes -func.func @mesh_shape_default_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index - %0:3 = mesh.mesh_shape @mesh0 : index, index, index +// CHECK-LABEL: func @grid_shape_default_axes +func.func @grid_shape_default_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index + %0:3 = shard.grid_shape @grid0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } -// CHECK-LABEL: func @mesh_shape_empty_axes -func.func @mesh_shape_empty_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index - %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index +// CHECK-LABEL: func @grid_shape_empty_axes +func.func @grid_shape_empty_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index + %0:3 = shard.grid_shape @grid0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_multi_index func.func @process_multi_index() -> (index, index) { - // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index - %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index + // CHECK: %[[RES:.*]]:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index + %0:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } // CHECK-LABEL: func @process_multi_index_default_axes func.func @process_multi_index_default_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - %0:3 = mesh.process_multi_index on @mesh0 : index, index, index + // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index + %0:3 = shard.process_multi_index on @grid0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_multi_index_empty_axes func.func @process_multi_index_empty_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index - %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index + // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index + %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { - // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index - %0 = mesh.process_linear_index on @mesh0 : index + // CHECK: %[[RES:.*]] = shard.process_linear_index on @grid0 : index + %0 = shard.process_linear_index on @grid0 : index // CHECK: return %[[RES]] : index return %0 : index } @@ -179,9 +179,9 @@ func.func @process_linear_index() -> index { func.func @all_reduce( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> { - // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max + // CHECK-NEXT: shard.all_reduce %[[ARG]] on @grid0 grid_axes = [1, 0] reduction = max // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64> - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1, 0] reduction = max : tensor<3x4xf32> -> tensor<3x4xf64> return %0 : tensor<3x4xf64> } @@ -190,9 +190,9 @@ func.func @all_reduce( func.func @all_gather( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32> - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x16xf32> return %0 : tensor<3x16xf32> } @@ -201,20 +201,20 @@ func.func @all_gather( func.func @all_gather_dynamic_dims_in_tensor( // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32> %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32> - %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<?x?xf32> -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } -// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh -func.func @all_gather_dynamic_dims_in_mesh( +// CHECK-LABEL: func @all_gather_dynamic_dims_in_grid +func.func @all_gather_dynamic_dims_in_grid( // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32> %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> { - // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1 + // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid3 grid_axes = [1] gather_axis = 1 // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32> - %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1 + %0 = shard.all_gather %arg0 on @grid3 grid_axes = [1] gather_axis = 1 : tensor<5x6xf32> -> tensor<5x?xf32> return %0 : tensor<5x?xf32> } @@ -223,10 +223,10 @@ func.func @all_gather_dynamic_dims_in_mesh( func.func @all_slice_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { - // CHECK-NEXT: mesh.all_slice %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1 + // CHECK-NEXT: shard.all_slice %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] slice_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32> - %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1 + %0 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf32> return %0 : tensor<3x1xf32> } @@ -235,10 +235,10 @@ func.func @all_slice_static_dimensions( func.func @all_slice_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32> %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // CHECK-NEXT: mesh.all_slice %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + // CHECK-NEXT: shard.all_slice %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] slice_axis = 0 // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32> - %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0 + %0 = shard.all_slice %arg0 on @grid3 grid_axes = [0, 1] slice_axis = 0 : tensor<?xf32> -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -247,10 +247,10 @@ func.func @all_slice_dynamic_dimensions( func.func @all_to_all( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -260,10 +260,10 @@ func.func @all_to_all( func.func @all_to_all_dynamic_dims_in_result( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x?xi8> return %0 : tensor<3x?xi8> @@ -273,10 +273,10 @@ func.func @all_to_all_dynamic_dims_in_result( func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size( // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8> %arg0 : tensor<3xi8>) -> tensor<3xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: @grid4 split_axis = 0 concat_axis = 0 // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8> - %0 = mesh.all_to_all %arg0 on @mesh4 + %0 = shard.all_to_all %arg0 on @grid4 split_axis = 0 concat_axis = 0 : tensor<3xi8> -> tensor<3xi8> return %0 : tensor<3xi8> @@ -286,10 +286,10 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size( func.func @all_to_all_non_divisible_split_axis_size( // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8> %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> { - // CHECK-NEXT: mesh.all_to_all %[[ARG]] - // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1 + // CHECK-NEXT: shard.all_to_all %[[ARG]] + // CHECK-SAME: @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1 // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8> - %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1] + %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1 : tensor<2x3xi8> -> tensor<?x12xi8> return %0 : tensor<?x12xi8> @@ -299,11 +299,11 @@ func.func @all_to_all_non_divisible_split_axis_size( func.func @broadcast_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.broadcast %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.broadcast %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8> - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<3x6xi8>) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -316,11 +316,11 @@ func.func @broadcast_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<3x6xi8> { - // CHECK-NEXT: mesh.broadcast %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.broadcast %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8> - %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2] root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> @@ -330,12 +330,12 @@ func.func @broadcast_dynamic_root( func.func @gather_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> { - // CHECK-NEXT: mesh.gather %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.gather %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8> - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2] gather_axis = 0 root = [0, 1] : (tensor<3x6xi8>) -> tensor<24x6xi8> @@ -349,12 +349,12 @@ func.func @gather_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<24x6xi8> { - // CHECK-NEXT: mesh.gather %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.gather %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8> - %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2] gather_axis = 0 root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<24x6xi8> @@ -365,11 +365,11 @@ func.func @gather_dynamic_root( func.func @receive_static_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.recv %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: source = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] source = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -382,11 +382,11 @@ func.func @receive_dynamic_source( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.recv %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: source = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] source = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -396,9 +396,9 @@ func.func @receive_dynamic_source( func.func @receive_no_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.recv %[[ARG]] + // CHECK-NEXT: shard.recv %[[ARG]] // CHECK-NOT: source - %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } @@ -407,11 +407,11 @@ func.func @receive_no_source( func.func @reduce_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.reduce %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -424,11 +424,11 @@ func.func @reduce_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.reduce %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -438,11 +438,11 @@ func.func @reduce_dynamic_root( func.func @reduce_different_return_element_type( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi16> { - // CHECK-NEXT: mesh.reduce %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.reduce %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16> - %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> @@ -452,10 +452,10 @@ func.func @reduce_different_return_element_type( func.func @reduce_scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> { - // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1 + // CHECK-NEXT: shard.reduce_scatter %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64> - %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2] + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2] reduction = max scatter_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf64> return %0 : tensor<3x1xf64> @@ -465,10 +465,10 @@ func.func @reduce_scatter_static_dimensions( func.func @reduce_scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32> %arg0 : tensor<?xf32>) -> tensor<?xf64> { - // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 + // CHECK-NEXT: shard.reduce_scatter %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0 // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64> - %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0 : tensor<?xf32> -> tensor<?xf64> return %0 : tensor<?xf64> } @@ -477,11 +477,11 @@ func.func @reduce_scatter_dynamic_dimensions( func.func @scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { - // CHECK-NEXT: mesh.scatter %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [2] + // CHECK-NEXT: shard.scatter %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [2] // CHECK-SAME: scatter_axis = 1 root = [1] // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32> - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [2] scatter_axis = 1 root = [1] : (tensor<3x4xf32>) -> tensor<3x1xf32> return %0 : tensor<3x1xf32> @@ -491,11 +491,11 @@ func.func @scatter_static_dimensions( func.func @scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32> %arg0 : tensor<?xf32>) -> tensor<?xf32> { - // CHECK-NEXT: mesh.scatter %[[ARG]] - // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] + // CHECK-NEXT: shard.scatter %[[ARG]] + // CHECK-SAME: on @grid3 grid_axes = [0, 1] // CHECK-SAME: scatter_axis = 0 root = [1, 2] // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32> - %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1] + %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0 root = [1, 2] : (tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> @@ -508,12 +508,12 @@ func.func @scatter_dynamic_root( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<1xi8> { - // CHECK-NEXT: mesh.scatter %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.scatter %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: scatter_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8> - %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2] scatter_axis = 0 root = [1, %arg1] : (tensor<8xi8>, index) -> tensor<1xi8> @@ -524,11 +524,11 @@ func.func @scatter_dynamic_root( func.func @send_static_destination( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.send %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.send %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: destination = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2] destination = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -541,11 +541,11 @@ func.func @send_dynamic_destination( // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { - // CHECK-NEXT: mesh.send %[[ARG0]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.send %[[ARG0]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: destination = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> - %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2] destination = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> @@ -555,11 +555,11 @@ func.func @send_dynamic_destination( func.func @shift( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { - // CHECK-NEXT: mesh.shift %[[ARG]] - // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] + // CHECK-NEXT: shard.shift %[[ARG]] + // CHECK-SAME: on @grid0 grid_axes = [0, 2] // CHECK-SAME: shift_axis = 2 offset = -2 rotate // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8> - %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2] + %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 2] shift_axis = 2 offset = -2 rotate : tensor<2xi8> -> tensor<2xi8> return %0 : tensor<2xi8> @@ -570,16 +570,16 @@ func.func @update_halo( // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8> %arg0 : memref<12x12xi8>) { // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64 - // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0 + // CHECK-NEXT: %[[UH1:.*]] = shard.update_halo %[[ARG]] on @grid0 // CHECK-SAME: split_axes = {{\[\[}}0]] // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> %c2 = arith.constant 2 : i64 - %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] + %uh1 = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, %c2] : memref<12x12xi8> - // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0 + // CHECK-NEXT: %[[UH2:.*]] = shard.update_halo %[[UH1]] on @grid0 // CHECK-SAME: split_axes = {{\[\[}}0], [1]] // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> - %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]] + %uh2 = shard.update_halo %uh1 on @grid0 split_axes = [[0], [1]] halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> return } diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir new file mode 100644 index 0000000..c2572cc --- /dev/null +++ b/mlir/test/Dialect/Shard/partition.mlir @@ -0,0 +1,317 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_1d(shape = 2) + +// CHECK-LABEL: func @return_sharding +func.func @return_sharding( + // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> (tensor<1xf32>, !shard.sharding) { +) -> (tensor<2xf32>, !shard.sharding) { + %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32> + // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding + %r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding + // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding + return %sharded, %r : tensor<2xf32>, !shard.sharding +} + +// CHECK-LABEL: func @full_replication +func.func @full_replication( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<2xi8> { +) -> tensor<2xi8> { + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: return %[[ARG]] : tensor<2xi8> + return %1 : tensor<2xi8> +} + +// CHECK-LABEL: func @sharding_triplet +func.func @sharding_triplet( + // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32> + %arg0: tensor<2xf32> +// CHECK-SAME: ) -> tensor<2xf32> { +) -> tensor<2xf32> { + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32> + %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32> + %ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %sharded_0 = shard.shard %sharded to %ssharded_0 annotate_for_users : tensor<2xf32> + %ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %sharded_1 = shard.shard %sharded_0 to %ssharded_1 : tensor<2xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<2xf32> + return %sharded_1 : tensor<2xf32> +} + + +// CHECK-LABEL: func @move_split_axis +func.func @move_split_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8> + %arg0: tensor<2x2xi8> +// CHECK-SAME: -> tensor<2x1xi8> { +) -> tensor<2x2xi8> { + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[ARG]] on @grid_1d + // CHECK-SAME: grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2x2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2x2xi8> + // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8> + return %1 : tensor<2x2xi8> +} + +// CHECK-LABEL: func @non_tensor_value +func.func @non_tensor_value( + // CHECK-SAME: %[[ARG:.*]]: i8 + %arg0: i8 +// CHECK-SAME: -> i8 { +) -> i8 { + // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8 + %0 = arith.addi %arg0, %arg0 : i8 + // CHECK: return %[[RES]] : i8 + return %0 : i8 +} + +// CHECK-LABEL: func @unary_elementwise +func.func @unary_elementwise( + // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %4 : tensor<2xi8> +} + +// full replication -> shard axis -> abs -> shard axis -> full replication +// CHECK-LABEL: func @unary_elementwise_with_resharding +func.func @unary_elementwise_with_resharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<2xi8> { +) -> tensor<2xi8> { + // CHECK: %[[SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RES:.*]] = shard.all_gather %[[ABS]] on @grid_1d + // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<2xi8> + return %4 : tensor<2xi8> +} + +// CHECK-LABEL: func @binary_elementwise +func.func @binary_elementwise( + // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>, + %arg0: tensor<2xi8>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8> + %arg1: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %sarg0_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2xi8> + %sop_arg0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %op_arg0 = shard.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8> + %sarg1_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %arg1_sharded = shard.shard %arg1 to %sarg1_sharded : tensor<2xi8> + %sop_arg1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %op_arg1 = shard.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8> + // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> + %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8> + %sop_res_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %op_res_sharded = shard.shard %op_res to %sop_res_sharded : tensor<2xi8> + %sres = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %res = shard.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %res : tensor<2xi8> +} + +// reshard +// abs +// reshard +// abs +// reshard +// CHECK-LABEL: func @multiple_chained_ops +func.func @multiple_chained_ops( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> + %arg0: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + // CHECK: %[[RESHARD1:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xi8> + %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8> + %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RESHARD2:.*]] = shard.all_gather %[[ABS1]] on @grid_1d + // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8> + %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<2xi8> + %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8> + // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> + %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> + // CHECK: %[[RESHARD3:.*]] = shard.all_slice %[[ABS2]] on @grid_1d grid_axes = [0] slice_axis = 0 : + // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> + %s6 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %6 = shard.shard %5 to %s6 : tensor<2xi8> + %s7 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %7 = shard.shard %6 to %s7 annotate_for_users : tensor<2xi8> + // CHECK: return %[[RESHARD3]] : tensor<1xi8> + return %7 : tensor<2xi8> +} + +// CHECK-LABEL: func @incomplete_sharding +func.func @incomplete_sharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32> + %arg0: tensor<8x16xf32> +// CHECK-SAME: -> tensor<4x16xf32> { +) -> tensor<8x16xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> + // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32> + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %2 = shard.shard %1 to %s2 : tensor<8x16xf32> + // CHECK: return %[[RES]] : tensor<4x16xf32> + return %2 : tensor<8x16xf32> +} + +shard.grid @grid_1d_4(shape = 4) + +// CHECK-LABEL: func @ew_chain_with_halo +func.func @ew_chain_with_halo( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32> + %arg0: tensor<8x16xf32>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg1: tensor<1xf32>, + // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg2: tensor<1xf32>) + // CHECK-SAME: -> tensor<5x16xf32> + -> tensor<8x16xf32> { + %ssharded = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded = shard.shard %arg0 to %ssharded annotate_for_users : tensor<8x16xf32> + // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> + %0 = tosa.tanh %sharded : (tensor<8x16xf32>) -> tensor<8x16xf32> + %ssharded_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_0 = shard.shard %0 to %ssharded_0 : tensor<8x16xf32> + %ssharded_1 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_1 = shard.shard %sharded_0 to %ssharded_1 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32> + %1 = tosa.abs %sharded_1 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %ssharded_2 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_2 = shard.shard %1 to %ssharded_2 : tensor<8x16xf32> + %ssharded_4 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_4 = shard.shard %sharded_2 to %ssharded_4 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32> + %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding + %zero_point_1 = shard.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32> + %zero_point_2 = shard.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32> + %2 = tosa.negate %sharded_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> + %ssharded_5 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_5 = shard.shard %2 to %ssharded_5 : tensor<8x16xf32> + %ssharded_6 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding + %sharded_6 = shard.shard %sharded_5 to %ssharded_6 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32> + return %sharded_6 : tensor<8x16xf32> +} + +// CHECK-LABEL: func @test_shard_update_halo +// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64> +func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding + // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64> + // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64> + // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64> + %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64> + %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !shard.sharding + %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64> + %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> + // CHECK: return %[[UH]] : tensor<304x1200xi64> + return %sharded_3 : tensor<1200x1200xi64> +} + +shard.grid @grid4x4(shape = 4x4) +// CHECK-LABEL: func @test_shard_update_halo2d +// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64> +func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> { + %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] : !shard.sharding + // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64> + // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64> + // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64> + %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64> + %sharding_0 = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !shard.sharding + %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64> + %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64> + // CHECK: return %[[UH]] : tensor<303x307xi64> + return %sharded_3 : tensor<1200x1200xi64> +} + +shard.grid @grid(shape = 2) +// CHECK-LABEL: func.func @test_reduce_0d( +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> +func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) { + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> + %4 = tensor.empty() : tensor<i32> + %sharding_out = shard.sharding @grid split_axes = [[]] : !shard.sharding + %sharded_out = shard.shard %4 to %sharding_out : tensor<i32> + %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> + // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) + %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1] + (%in: i32, %init: i32) { + %6 = arith.addi %in, %init : i32 + linalg.yield %6 : i32 + } + // CHECK: %[[all_reduce:.*]] = shard.all_reduce %[[reduced]] on @grid grid_axes = [0] : tensor<i32> -> tensor<i32> + %sharded_red = shard.shard %reduced to %sharding_out : tensor<i32> + %sharded_ret = shard.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32> + // CHECK: return %[[all_reduce]] : tensor<i32> + return %sharded_ret : tensor<i32> +} + +// CHECK-LABEL: func.func @test_reduce_1d( +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32> +func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { + %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding + %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32> + %4 = tensor.empty() : tensor<6xi32> + %sharded_out = shard.shard %4 to %sharding : tensor<6xi32> + %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32> + // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>) + %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %6 = arith.addi %in, %init : i32 + linalg.yield %6 : i32 + } + // CHECK-NOT: shard.all_reduce + %sharded_red = shard.shard %reduced to %sharding : tensor<6xi32> + %sharded_ret = shard.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32> + // CHECK: return %[[reduced]] : tensor<3xi32> + return %sharded_ret : tensor<6xi32> +} diff --git a/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir new file mode 100644 index 0000000..33c7a8f --- /dev/null +++ b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -test-grid-process-multi-index-op-lowering %s | FileCheck %s + +shard.grid @grid2d(shape = ?x?) + +// CHECK-LABEL: func.func @multi_index_2d_grid +func.func @multi_index_2d_grid() -> (index, index) { + // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index + // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index + // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index + %0:2 = shard.process_multi_index on @grid2d : index, index + // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index + return %0#0, %0#1 : index, index +} + +// CHECK-LABEL: func.func @multi_index_2d_grid_single_inner_axis +func.func @multi_index_2d_grid_single_inner_axis() -> index { + // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index + // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index + // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index + %0 = shard.process_multi_index on @grid2d axes = [0] : index + // CHECK: return %[[MULTI_IDX]]#0 : index + return %0 : index +} diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir new file mode 100644 index 0000000..ff9e840 --- /dev/null +++ b/mlir/test/Dialect/Shard/resharding-partition.mlir @@ -0,0 +1,168 @@ +// RUN: mlir-opt -test-grid-resharding-partition %s | FileCheck %s + +shard.grid @grid_1d(shape = 2) +shard.grid @grid_1d_dynamic(shape = ?) + +// CHECK-LABEL: func @same_source_and_target_sharding +func.func @same_source_and_target_sharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> + %arg0: tensor<2xf32> +) -> tensor<2xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xf32> + %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xf32> + // CHECK: return %[[ARG]] + return %1 : tensor<2xf32> +} + +// CHECK-LABEL: func @identical_source_and_target_sharding +func.func @identical_source_and_target_sharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32> + %arg0: tensor<2xf32> +) -> tensor<2xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<2xf32> + %1 = shard.shard %0 to %s0 annotate_for_users : tensor<2xf32> + // CHECK: return %[[ARG]] + return %1 : tensor<2xf32> +} + +// CHECK-LABEL: func @split_replicated_tensor_axis +func.func @split_replicated_tensor_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32> + %arg0: tensor<3x14xf32> +) -> tensor<3x14xf32> { + // CHECK: %[[ALL_SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 1 + // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32> + // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<3x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<3x14xf32> + // CHECK: return %[[RESULT]] : tensor<3x14xf32> + return %1 : tensor<3x14xf32> +} + +// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic +func.func @split_replicated_tensor_axis_dynamic( + // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32> + %arg0: tensor<?x3x?xf32> +) -> tensor<?x3x?xf32> { + // CHECK: %[[RESULT:.*]] = shard.all_slice %[[ARG]] on @grid_1d_dynamic grid_axes = [0] slice_axis = 0 + // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32> + %s0 = shard.sharding @grid_1d_dynamic split_axes = [[], [], []] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<?x3x?xf32> + %s1 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32> + // CHECK: return %[[RESULT]] : tensor<?x3x?xf32> + return %1 : tensor<?x3x?xf32> +} + +// CHECK-LABEL: func @move_split_axis +func.func @move_split_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> + // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[RES]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} + +// CHECK-LABEL: func @move_split_axis_dynamic_grid +func.func @move_split_axis_dynamic_grid( + // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32> + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32> + // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32> + %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32> + %s1 = shard.sharding @grid_1d_dynamic split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[RES]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} + +// CHECK-LABEL: func @move_split_dynamic_axis +func.func @move_split_dynamic_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32> + %arg0: tensor<?x14xf32> +) -> tensor<?x14xf32> { + // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[ARG]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32> + // CHECK: return %[[RES]] : tensor<?x14xf32> + return %1 : tensor<?x14xf32> +} + +// CHECK-LABEL: func @unshard_static_axis +func.func @unshard_static_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32> + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} + +// CHECK-LABEL: func @unshard_static_last_axis +func.func @unshard_static_last_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32> + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} + +// CHECK-LABEL: func @unshard_dynamic_axis +func.func @unshard_dynamic_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32> + %arg0: tensor<?x14xf32> +) -> tensor<?x14xf32> { + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32> + %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32> + %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32> + return %1 : tensor<?x14xf32> +} + +// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis +func.func @unshard_static_axis_on_dynamic_grid_axis( +// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32> + // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32> + // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32> + %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32> + %s1 = shard.sharding @grid_1d_dynamic split_axes = [[]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[RES]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} diff --git a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir index b5eb98d..b5eb98d 100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir +++ b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir diff --git a/mlir/test/Dialect/Shard/sharding-propagation.mlir b/mlir/test/Dialect/Shard/sharding-propagation.mlir new file mode 100644 index 0000000..34aaf05 --- /dev/null +++ b/mlir/test/Dialect/Shard/sharding-propagation.mlir @@ -0,0 +1,301 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s + +shard.grid @grid_2(shape = 2) +shard.grid @grid_1d(shape = ?) +shard.grid @grid_2d(shape = 2x4) +shard.grid @grid_3d(shape = ?x?x?) + +// CHECK-LABEL: func.func @element_wise_empty_sharding_info +func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: tosa.sigmoid + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: return + return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_def +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V2]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_use +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: return %[[V2]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_graph_output +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @element_wise_on_graph_input +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @arrow_structure +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]] + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S1]] : tensor<8x16xf32> + %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V4:.*]] = shard.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]] + // CHECK-NEXT: %[[V6:.*]] = shard.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> + %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[ZP1:.*]] = shard.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[ZP2:.*]] = shard.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]] + // CHECK-NEXT: %[[V8:.*]] = shard.shard %[[V7]] to %[[S1]] : tensor<8x16xf32> + %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> + %s3 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %3 = shard.shard %2 to %s3 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V6]], %[[V8]] + return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32> +} + +// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> +func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> + // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] + %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32> + // CHECK-NEXT: return %[[V3]] + return %1 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n +// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> +func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [], [1]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32> + // CHECK-NEXT: return [[vsharded_5]] + return %1 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k +// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32> +func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding + %s0 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32> + %arg0_s = shard.shard %arg0 to %s0 : tensor<2x16x8xf32> + // CHECK: [[vsharded_0:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32> + // CHECK: [[vsharding_1:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding + // CHECK: [[vsharded_2:%.*]] = shard.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_3:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_4:%.*]] = shard.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + // CHECK: [[vsharding_5:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32> + %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK: return [[vsharded_6]] + return %0 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> +func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1], [0]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> + %s0 = shard.sharding @grid_2d split_axes = [[], [1], [0]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32> + // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding + // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> + %s1 = shard.sharding @grid_2d split_axes = [[], [0]] : !shard.sharding + %1 = shard.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32> + // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] + %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> + // CHECK-NEXT: return %[[V3]] + return %2 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @resolve_conflicting_annotations +func.func @resolve_conflicting_annotations( + // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>, + %arg0: tensor<2x3xf32>, + // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>, + %arg1: tensor<3x2xf32>, + // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32> + %out_dps: tensor<2x2xf32> +// CHECK-SAME: ) -> tensor<2x2xf32> { +) -> tensor<2x2xf32> { + // CHECK: %[[SIN1_SHARDED1:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}0]] : !shard.sharding + // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = shard.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32> + // CHECK: %[[SIN2_SHARDED:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = shard.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32> + // CHECK-NEXT: %[[IN2_SHARDED:.*]] = shard.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32> + // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = shard.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32> + %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding + %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32> + // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>) + // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32> + %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>) + outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: %[[RES:.*]] = shard.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32> + %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding + %res_sharded = shard.shard %res to %sres_sharded : tensor<2x2xf32> + // CHECK: return %[[RES]] : tensor<2x2xf32> + return %res_sharded : tensor<2x2xf32> +} + +// https://arxiv.org/abs/2211.05102 Figure 2(a) +// The sharding propagation results in unnecessary reshards, +// an optimization pass should be able to remove them. +// CHECK-LABEL: func.func @mlp_1d_weight_stationary +// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> +func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { + %s0 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + %sharded0 = shard.shard %arg0 to %s0 : tensor<2x4x8xf32> + %sharded1 = shard.shard %arg1 to %s0 : tensor<2x8x32xf32> + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> + // CHECK: [[vsharded_0:%.*]] = shard.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32> + // CHECK: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32> + // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32> + %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + %sharding = shard.sharding @grid_1d split_axes = [[], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded_9:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32> + %sharded2 = shard.shard %arg2 to %sharding : tensor<2x32x8xf32> + // CHECK: [[vsharded_10:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> + // CHECK: [[v2:%.*]] = tosa.matmul + %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK: [[vsharded_12:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> + %s4 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + %4 = shard.shard %3 to %s4 : tensor<2x4x8xf32> + // CHECK: return [[vsharded_12]] + return %4 : tensor<2x4x8xf32> +} + +// https://arxiv.org/abs/2211.05102 Figure 2(b) +// The sharding propagation results in unnecessary reshards, +// an optimization pass should be able to remove them. +// CHECK-LABEL: func.func @mlp_2d_weight_stationary +// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32> +func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { + // CHECK: [[vsharding:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding + %s0 = shard.sharding @grid_3d split_axes = [[], [], [0, 1, 2]] : !shard.sharding + // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32> + %arg0_s = shard.shard %arg0 to %s0 : tensor<2x4x8xf32> + // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [0], [1, 2]] : !shard.sharding + %s1 = shard.sharding @grid_3d split_axes = [[], [0], [1, 2]] : !shard.sharding + // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32> + %arg1_s = shard.shard %arg1 to %s1 : tensor<2x8x32xf32> + // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32> + // CHECK: [[vsharded_4:%.*]] = shard.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32> + // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32> + // CHECK: [[v0:%.*]] = tosa.matmul + %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32> + %2 = shard.shard %1 to %s0 : tensor<2x4x32xf32> + // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[v1:%.*]] = tosa.sigmoid + // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32> + %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32> + // CHECK: [[vsharding_9:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [1, 2], [0]] : !shard.sharding + %s2 = shard.sharding @grid_3d split_axes = [[], [1, 2], [0]] : !shard.sharding + // CHECK: [[vsharded_10:%.*]] = shard.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32> + %arg2_s = shard.shard %arg2 to %s2 : tensor<2x32x8xf32> + // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32> + // CHECK: [[vsharded_12:%.*]] = shard.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32> + // CHECK: [[v2:%.*]] = tosa.matmul + %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK: [[vsharded_13:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32> + %5 = shard.shard %4 to %s0 : tensor<2x4x8xf32> + // CHECK: [[vsharded_14:%.*]] = shard.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32> + %6 = shard.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32> + // CHECK: return [[vsharded_14]] + return %6 : tensor<2x4x8xf32> +} + +// CHECK-LABEL: func.func @elementwise_duplicated_chain +// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> +func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding + // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]] + %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32> + // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]] + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[V5:.*]] = shard.shard %[[V4]] to %[[S0]] : tensor<8x16xf32> + %s0 = shard.sharding @grid_2d split_axes = [[]] : !shard.sharding + %2 = shard.shard %1 to %s0 : tensor<8x16xf32> + // CHECK-NEXT: return %[[V5]] + return %2 : tensor<8x16xf32> +} diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Shard/simplifications.mlir index e955f4c..33cd490 100644 --- a/mlir/test/Dialect/Mesh/simplifications.mlir +++ b/mlir/test/Dialect/Shard/simplifications.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s +// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s -mesh.mesh @mesh0(shape = 4x2) -mesh.mesh @mesh1(shape = 4) +shard.grid @grid0(shape = 4x2) +shard.grid @grid1(shape = 4) // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to // `all_reduce(x + y)`. @@ -11,13 +11,13 @@ func.func @all_reduce_arith_addf_endomorphism( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]] %2 = arith.addf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xf32> } @@ -28,13 +28,13 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]] %2 = arith.addf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]] return %2, %2 : tensor<5xf32>, tensor<5xf32> } @@ -46,11 +46,11 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -58,17 +58,17 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result return %0, %2 : tensor<5xf32>, tensor<5xf32> } -// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh -func.func @all_reduce_arith_addf_no_endomorphism_different_mesh( +// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid +func.func @all_reduce_arith_addf_no_endomorphism_different_grid( // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32> %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1 - %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1 + %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -76,17 +76,17 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_mesh( return %2 : tensor<5xf32> } -// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes -func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes( +// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes +func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes( // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32> %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -100,11 +100,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<5xf32> -> tensor<5xf32> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf32> @@ -118,11 +118,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_elemen %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf64> { - // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf64> - // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0] - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] + // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0] + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf64> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] %2 = arith.addf %0, %1 : tensor<5xf64> @@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism( %arg0: tensor<5xf32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32> %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min : tensor<5xf32> -> tensor<5xf32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min : tensor<5xf32> -> tensor<5xf32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]] %2 = arith.minimumf %0, %1 : tensor<5xf32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xf32> } @@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism( %arg0: tensor<5xi32>, // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32> %arg1: tensor<5xi32>) -> tensor<5xi32> { - %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min : tensor<5xi32> -> tensor<5xi32> - %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min + %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min : tensor<5xi32> -> tensor<5xi32> // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]] %2 = arith.minsi %0, %1 : tensor<5xi32> - // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min + // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xi32> } diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir index 16efa73..9da2dea 100644 --- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir +++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir @@ -1,22 +1,92 @@ // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL -func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>, +func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x28x10xf32> { + %empty = tensor.empty() : tensor<28x28x15xf32> + %unpack = linalg.unpack %arg0 + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32> + return %extracted_slice : tensor<28x28x10xf32> +} +// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x28x10xf32> +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] +// CHECK-SAME: into %[[DEST_SLICE]] +// CHECK: return %[[UNPACK]] + +// ----- + +// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2. + +func.func @fold_extract_slice_into_unpack_slicing_dim_1(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x17x15xf32> { + %empty = tensor.empty() : tensor<28x28x15xf32> + %unpack = linalg.unpack %arg0 + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32> + return %extracted_slice : tensor<28x17x15xf32> +} +// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x17x15xf32> +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] +// CHECK-SAME: into %[[DEST_SLICE]] +// CHECK: return %[[UNPACK]] + +// ----- + +// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2. + +func.func @no_fold_extract_slice_into_unpack_artificial_padding(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x16x15xf32> { + %empty = tensor.empty() : tensor<28x28x15xf32> + %unpack = linalg.unpack %arg0 + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32> + return %extracted_slice : tensor<28x16x15xf32> +} +// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding +// CHECK: linalg.unpack +// CHECK: tensor.extract_slice + +// ----- + +func.func @no_fold_extract_slice_into_unpack_dynamic( + %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index +) -> tensor<28x28x?xf32> { + %unpack = linalg.unpack %src + outer_dims_perm = [0, 1, 2] + inner_dims_pos = [1, 2] + inner_tiles = [16, 16] + into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32> + %extracted_slice = tensor.extract_slice %unpack + [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32> + return %extracted_slice : tensor<28x28x?xf32> +} +// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic +// CHECK: linalg.unpack +// CHECK: tensor.extract_slice + +// ----- + +func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> { %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1 : tensor<?x?x8x4xf32> -> tensor<?x?xf32> %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> return %1 : tensor<?x?xf32> } -// CHECK: func @fold_unpack_slice( -// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index -// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32> -// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] -// CHECK-SAME: into %[[INIT]] -// CHECK: return %[[UNPACK]] +// CHECK-LABEL: func @nofold_dynamic_unpack_slice( +// CHECK: linalg.unpack +// CHECK: tensor.extract_slice // ----- @@ -59,48 +129,62 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : // ----- -func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> { - %c0 = arith.constant 0 : index +func.func @fold_pad_pack(%src: tensor<9x16xf32>) -> tensor<2x1x8x32xf32> { %cst = arith.constant 0.000000e+00 : f32 - %padded = tensor.pad %src low[0, 0] high[15, 0] { + %padded = tensor.pad %src low[0, 0] high[7, 0] { ^bb0(%arg0: index, %arg1: index): tensor.yield %cst : f32 - } : tensor<16641x16xf32> to tensor<16656x16xf32> - %empty = tensor.empty() : tensor<2082x1x8x32xf32> + } : tensor<9x16xf32> to tensor<16x16xf32> + %empty = tensor.empty() : tensor<2x1x8x32xf32> %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty - : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32> - return %pack : tensor<2082x1x8x32xf32> + : tensor<16x16xf32> -> tensor<2x1x8x32xf32> + return %pack : tensor<2x1x8x32xf32> } -// CHECK-LABEL: func.func @pad_pack +// CHECK-LABEL: func.func @fold_pad_pack // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32> +// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x1x8x32xf32> // CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]] // ----- -func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> { - %c0 = arith.constant 0 : index +func.func @nofold_pad_pack_artificial_padding(%src: tensor<9x16xf32>) -> tensor<3x1x8x32xf32> { %cst = arith.constant 0.000000e+00 : f32 - %padded = tensor.pad %src nofold low[0, 0] high[15, 0] { + %padded = tensor.pad %src low[0, 0] high[8, 0] { ^bb0(%arg0: index, %arg1: index): tensor.yield %cst : f32 - } : tensor<16641x16xf32> to tensor<16656x16xf32> + } : tensor<9x16xf32> to tensor<17x16xf32> + %empty = tensor.empty() : tensor<3x1x8x32xf32> + %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty + : tensor<17x16xf32> -> tensor<3x1x8x32xf32> + return %pack : tensor<3x1x8x32xf32> +} +// CHECK-LABLE: func.func @nofold_pad_pack_artificial_padding( +// CHECK: tensor.pad +// CHECK: linalg.pack + +// ----- + +func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %src nofold low[0, 0] high[7, 0] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f32 + } : tensor<16649x16xf32> to tensor<16656x16xf32> %empty = tensor.empty() : tensor<2082x1x8x32xf32> %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32> return %pack : tensor<2082x1x8x32xf32> } -// CHECK-LABEL: func.func @nofold_pad_pack +// CHECK-LABEL: func.func @nofold_pad_pack_with_nofold_attribute( // CHECK: tensor.pad // CHECK: linalg.pack // ----- func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> { - %c0 = arith.constant 0 : index %cst0 = arith.constant 0.000000e+00 : f32 %cst1 = arith.constant 1.000000e+00 : f32 %padded = tensor.pad %src low[0, 0] high[15, 0] { diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir deleted file mode 100644 index 8598d81..0000000 --- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: mlir-opt \ -// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \ -// RUN: %s | FileCheck %s - -mesh.mesh @mesh_1d_4(shape = 4) - -// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets -func.func @tensor_empty_static_sharded_dims_offsets() -> () { - %b = tensor.empty() : tensor<8x16xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<8x16xf32> - // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index - // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]] - // CHECK-SAME: ] : index, index - // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets -// CHECK-SAME: %[[A0:.*]]: index -func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () { - %b = tensor.empty(%arg0) : tensor<8x?xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<8x?xf32> - // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding - // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index - // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]] - // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]] - // CHECK-SAME: ] : index, index - // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes -func.func @tensor_empty_same_static_dims_sizes() -> () { - %b = tensor.empty() : tensor<16x16xf32> - %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding - %sharded= mesh.shard %b to %sharding : tensor<16x16xf32> - // CHECK-NEXT: tensor.empty() : tensor<4x16xf32> - - return -} - -// CHECK-LABEL: func @tensor_empty_0d -func.func @tensor_empty_0d() -> () { - tensor.empty() : tensor<f32> - // CHECK-NEXT: tensor.empty() : tensor<f32> - return -} diff --git a/mlir/test/Dialect/Tensor/shard-partition.mlir b/mlir/test/Dialect/Tensor/shard-partition.mlir new file mode 100644 index 0000000..5918ee1 --- /dev/null +++ b/mlir/test/Dialect/Tensor/shard-partition.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \ +// RUN: %s | FileCheck %s + +shard.grid @grid_1d_4(shape = 4) + +// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets +func.func @tensor_empty_static_sharded_dims_offsets() -> () { + %b = tensor.empty() : tensor<8x16xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<8x16xf32> + // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index + // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]] + // CHECK-SAME: ] : index, index + // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32> + + return +} + +// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets +// CHECK-SAME: %[[A0:.*]]: index +func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () { + %b = tensor.empty(%arg0) : tensor<8x?xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<8x?xf32> + // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding + // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index + // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, %[[A0]] + // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]] + // CHECK-SAME: ] : index, index + // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32> + + return +} + +// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes +func.func @tensor_empty_same_static_dims_sizes() -> () { + %b = tensor.empty() : tensor<16x16xf32> + %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !shard.sharding + %sharded= shard.shard %b to %sharding : tensor<16x16xf32> + // CHECK-NEXT: tensor.empty() : tensor<4x16xf32> + + return +} + +// CHECK-LABEL: func @tensor_empty_0d +func.func @tensor_empty_0d() -> () { + tensor.empty() : tensor<f32> + // CHECK-NEXT: tensor.empty() : tensor<f32> + return +} diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 0176fc2..6398161 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // CHECK: tosa.cond_if profiles: [ ] // CHECK: tosa.cond_if extensions: [ [controlflow] ] - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 11c8d54..5150ee3 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { // ----- +// CHECK-LABEL: @clamp_boolean_is_noop +func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: @clamp_boolean_dynamic_is_noop +func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1> + return %0 : tensor<?xi1> +} + +// ----- + // CHECK-LABEL: @clamp_int8_is_noop func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> { // CHECK: return %arg0 @@ -1349,3 +1369,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> { %1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32> return %1 : tensor<i32> } + +// ----- + +// CHECK-LABEL: @test_fold_i32_to_i1_cast +// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1> +// CHECK: return %[[OUT]] : tensor<i1> +func.func @test_fold_i32_to_i1_cast() -> tensor<i1> { + %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32> + %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1> + return %1 : tensor<i1> +} diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir new file mode 100644 index 0000000..06312c7 --- /dev/null +++ b/mlir/test/Dialect/Tosa/controlflow.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -split-input-file %s | FileCheck %s + +// ----- + +func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index eb25011..fad1bec 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1: func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<f32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index ed74714..3bccb32 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) { %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> %1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32> - // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}} + // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}} %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32> return } @@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() { func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> { %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6> %cst = "tosa.const"() { values = dense<-0.0> : tensor<f8E4M3FN> } : () -> tensor<f8E4M3FN> - // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}} + // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}} %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<f8E4M3FN>) -> tensor<13x21x3xf8E5M2> return %0 : tensor<13x21x3xf8E5M2> } @@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: // CHECK-LABEL: test_mul_non_scalar_shift_2d func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } @@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso // CHECK-LABEL: test_mul_non_scalar_shift_1d func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } @@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> return %0 : tensor<2x52x3xf32> } + +// ----- + +func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { + // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} + %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> + return %0 : tensor<1x12x11xf32> +} + +// ----- + +func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) { + // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}} + %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) + return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16> +} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 5630c33..3154f54 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 // ----- func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0dddf26..bf9ed8a 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens // ----- -func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { +func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> { // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> - return %0 : tensor<1x1x1x1x13x21x3xf32> + %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> + return %0 : tensor<1x1x1x1x13x21x3xi32> } // ----- @@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: // ----- func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { - %1 = tosa.cond_if %arg3 -> (tensor<f32>) { - %2 = tosa.cond_if %arg2 -> (tensor<f32>) { - %3 = tosa.cond_if %arg3 -> (tensor<f32>) { - %4 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> { + %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { + %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} - %5 = tosa.cond_if %arg3 -> (tensor<f32>) { + %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %res : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index ef51197e..30361a8 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir index 38ac8d8..e957bdd 100644 --- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir +++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir @@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { // CHECK-LABEL: test_regions // CHECK: %arg0: tensor<i8>, %arg1: tensor<i8> func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> { - // CHECK: tosa.cond_if %arg2 -> (tensor<i8>) + // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8> %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>): // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index d0f4027..7b8fc24 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -344,6 +344,30 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten // ----- +// CHECK-LABEL: @test_accepts_unranked_scalar_tensor +func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> { + // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32> + %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32> + // CHECK: %[[SHAPE:.*]] = tosa.const_shape + %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32> + %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_unranked_scalar_i8_tensor +func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> { + // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8> + %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8> + // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16> @@ -1153,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32> // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { tosa.yield %a : tensor<f32> } else { tosa.yield %b : tensor<f32> @@ -1167,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<?xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) { + // CHECK: -> tensor<?xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> { tosa.yield %arg0 : tensor<2xf32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1181,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { + // CHECK: -> tensor<*xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1195,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index b305236..2a937b0 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar // ----- +func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso // ----- +// CHECK-LABEL: cond_if_cond_type +func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected ':'}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}} + %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + tosa.yield %arg0 : tensor<f32> + } else { + tosa.yield %arg1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected non-function type}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) { %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32> // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}} diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1461c30..56996b5 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> { // ----- -// CHECK-LABEL: shape_cast_constant +// CHECK-LABEL: shape_cast_splat_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32> // CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> -func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { +func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32> %cst_1 = arith.constant dense<1> : vector<12x2xi32> %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32> @@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// Test of shape_cast's fold method: +// shape_cast(constant) -> constant. +// +// CHECK-LABEL: @shape_cast_dense_int_constant +// CHECK: %[[CST:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]> +// CHECK: return %[[CST]] : vector<2x3xi8> +func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> { + %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8> + %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8> + return %0 : vector<2x3xi8> +} + +// ----- + +// Test of shape_cast fold's method: +// (shape_cast(const_x), const_x) -> (const_x_folded, const_x) +// +// CHECK-LABEL: @shape_cast_dense_float_constant +// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32> +// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32> +// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32> +func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){ + %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32> + %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32> + return %0, %cst : vector<2xf32>, vector<1x2xf32> +} + +// ----- + // CHECK-LABEL: shape_cast_poison // CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> @@ -2562,118 +2592,6 @@ func.func @insert_2d_splat_constant() // ----- -// CHECK-LABEL: func @insert_element_fold -// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> -// CHECK: return %[[V]] -func.func @insert_element_fold() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @insert_element_invalid_fold -func.func @insert_element_invalid_fold() -> vector<1xf32> { - // Out-of-bound index here. - %c26 = arith.constant 26 : index - %cst_2 = arith.constant 1.60215309E+9 : f32 - %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32> -// CHECK: vector.insertelement - %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32> - return %46 : vector<1xf32> -} - - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold1 -// CHECK: vector.insertelement -func.func @insert_poison_fold1() -> vector<4xi32> { - %v = ub.poison : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold2 -// CHECK: vector.insertelement -func.func @insert_poison_fold2() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = ub.poison : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold3 -// CHECK: vector.insertelement -func.func @insert_poison_fold3() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = ub.poison : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @extract_element_fold -// CHECK: %[[C:.+]] = arith.constant 5 : i32 -// CHECK: return %[[C]] -func.func @extract_element_fold() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// CHECK-LABEL: func @extract_element_splat_fold -// CHECK-SAME: (%[[ARG:.+]]: i32) -// CHECK: return %[[ARG]] -func.func @extract_element_splat_fold(%a : i32) -> i32 { - %v = vector.splat %a : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold1 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold1() -> i32 { - %v = ub.poison : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold2 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold2() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = ub.poison : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // CHECK-LABEL: func @reduce_one_element_vector_extract // CHECK-SAME: (%[[V:.+]]: vector<1xf32>) // CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32> @@ -2933,18 +2851,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // ----- -// CHECK-LABEL: func.func @fold_extractelement_of_broadcast( -// CHECK-SAME: %[[f:.*]]: f32 -// CHECK: return %[[f]] -func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { - %0 = vector.broadcast %f : f32 to vector<15xf32> - %c5 = arith.constant 5 : index - %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 { // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32> diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 0263193..b2f16bb 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> { func.return %2 : vector<4x4xindex> } +// CHECK-LABEL: func @vector_transpose +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index} +func.func @vector_transpose() -> vector<2x4xindex> { + %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex> + %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %2 = test.reflect_bounds %1 : vector<2x4xindex> + func.return %2 : vector<2x4xindex> +} + // CHECK-LABEL: func @vector_extract // CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} func.func @vector_extract() -> index { @@ -60,16 +69,6 @@ func.func @vector_extract() -> index { func.return %2 : index } -// CHECK-LABEL: func @vector_extractelement -// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} -func.func @vector_extractelement() -> index { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> - %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> - %2 = test.reflect_bounds %1 : index - func.return %2 : index -} - // CHECK-LABEL: func @vector_add // CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index} func.func @vector_add() -> vector<4xindex> { @@ -90,17 +89,6 @@ func.func @vector_insert() -> vector<4xindex> { func.return %3 : vector<4xindex> } -// CHECK-LABEL: func @vector_insertelement -// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} -func.func @vector_insertelement() -> vector<4xindex> { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> - %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index - %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> - %3 = test.reflect_bounds %2 : vector<4xindex> - func.return %3 : vector<4xindex> -} - // CHECK-LABEL: func @test_loaded_vector_extract // No bounds // CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 @@ -120,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ca837d3..c21de56 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- -func.func @extract_element(%arg0: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %1 = vector.extractelement %arg0[%c : i32] : vector<f32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %1 = vector.extractelement %arg0[] : vector<4xf32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> -} - -// ----- - func.func @extract_vector_type(%arg0: index) { // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}} %1 = vector.extract %arg0[] : index from index @@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @insert_element(%arg0: f32, %arg1: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32> -} - -// ----- - -func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}} - %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>) -} - -// ----- - func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 6a56116..625ffc1 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4 return %1 : vector<4xf32> } -// CHECK-LABEL: @extract_element_0d -func.func @extract_element_0d(%a: vector<f32>) -> f32 { - // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> - %1 = vector.extractelement %a[] : vector<f32> - return %1 : f32 -} - -// CHECK-LABEL: @extract_element -func.func @extract_element(%a: vector<16xf32>) -> f32 { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.extractelement %a[%c : i32] : vector<16xf32> - return %1 : f32 -} - // CHECK-LABEL: @extract_const_idx func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { @@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { return %0 : f32 } -// CHECK-LABEL: @insert_element_0d -func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> - %1 = vector.insertelement %a, %b[] : vector<f32> - return %1 : vector<f32> -} - -// CHECK-LABEL: @insert_element -func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32> - return %1 : vector<16xf32> -} - // CHECK-LABEL: @insert_const_idx func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 8e167a5..d5e3443 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @broadcast_vec1d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { // CHECK-LABEL: func @broadcast_vec2d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> // CHECK: return %[[T0]] : vector<2x3xf32> func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { // CHECK-LABEL: func @broadcast_vec3d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32> // CHECK: return %[[T0]] : vector<2x3x4xf32> func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3 // CHECK-LABEL: func @broadcast_stretch // CHECK-SAME: %[[A:.*0]]: vector<1xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32> // CHECK: return %[[T1]] : vector<4xf32> func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> +// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955..5a8125e 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -5,11 +5,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T7]] : vector<2x3xf32> @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> // CHECK: return %[[T7]] : vector<2x3xi32> @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>, // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32, // CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32> // CHECK-LABEL: func @axpy_int( // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: return %[[T1]] : vector<16xi32> func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32, // CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> // CHECK: return %[[T2]] : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index b826cdc..ef881ba 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> { // ----- -// The source and the result for arith.cmp have different types - not supported - -// CHECK-LABEL: func.func @negative_source_and_result_mismatch -// CHECK: %[[BROADCAST:.+]] = vector.broadcast -// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]] -// CHECK: return %[[RETURN]] -func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> { +// The source and the result for arith.cmp have different types + +// CHECK-LABEL: func.func @source_and_result_mismatch( +// CHECK-SAME: %[[ARG0:.+]]: f32) +// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] +// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1> +// CHECK: return %[[BROADCAST]] +func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> { %0 = vector.broadcast %arg0 : f32 to vector<1xf32> %1 = arith.cmpf uno, %0, %0 : vector<1xf32> return %1 : vector<1xi1> @@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { return %1 : vector<1xf32> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.subi %cst, %0 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32> + %2 = arith.mulf %0, %cst : vector<3x4xf32> + return %2 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> + +func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16> + %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16> + %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32> + return %1 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32 +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32> + %cst = arith.constant dense<3> : vector<1x4xi32> + %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32> + return %2 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32> +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<3> : vector<3x4xi32> + %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32> + return %2 : vector<3x4xf32> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderCastOpsOnBroadcast] // diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir index 5dd65ea..44601a4 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -68,6 +68,24 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>) // ----- +// Ensure that cases with mismatched target and source shape ranks +// do not lead to a crash. +// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns` +// is currently hard-coded to [2, 2]. + +// CHECK-LABEL: func @negative_transfer_write +// CHECK-NOT: vector.extract_strided_slice +// CHECK: vector.transfer_write +// CHECK: return +func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<6x34x62xi8> + vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> + return +} + +// ----- + // CHECK-LABEL: func @transfer_readwrite_unroll // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 0160bfe..dff3ffa 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) { } // ----- +func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) { + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex> + return +} + +// ----- +func.func @load_gather_offset_sg(%src: memref<?xf16>) { + %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<8xi1> + // expected-error@+1 {{Mask should match value except the chunk size dim}} + %2 = xegpu.load %src[%offsets], %mask + : memref<?xf16>, vector<4xindex>, vector<8xi1> + -> vector<4x2xf16> + return +} + +// ----- +func.func @load_gather_offset_wi(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32> + return +} + +// ----- +func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{value elements must match chunk size}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @load_gather_offset_wi_2(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16> + return +} + +// ----- +func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32> + return +} + +// ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) { %0 = arith.constant dense<1>: vector<4xi1> %1 = arith.constant dense<2.9>: vector<4x2xf32> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 3ebb1b969a..6be2371 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) { + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4x2xf16> + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) { gpu.return } +// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) { +gpu.func @prefetch_offset(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex> + xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex> + gpu.return +} // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_update_tdesc(%src: ui64) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index d67bdb4..628a485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -2,122 +2,117 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.load_nd %{{.*}} - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-12: -> vector<2x2xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.load_nd %{{.*}} + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME-COUNT-4: -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} + // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT : xegpu.store_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @update_nd(%src: memref<24x32xf32>){ - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16] - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @update_nd(%src: memref<256x128xf32>){ + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>> // CHECK-NOT: xegpu.update_nd_offset %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas - // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>) - gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) + gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32> + // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16> + -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16> + -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x256xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} - // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} + // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: broadcast - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32> + gpu.func @broadcast(%src: memref<128x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32> + -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK-COUNT-3: vector.broadcast {{.*}} - // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32> + : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<128x1xf32> + // CHECK-COUNT-2: vector.broadcast {{.*}} + // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<128x1xf32> to vector<128x64xf32> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index d511224..d4b0037 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -4,201 +4,181 @@ //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> gpu.module @test_1_1_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id - // CHECK: %[[C12:.*]] = arith.constant 12 : index - // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[C32_0:.*]] = arith.constant 32 : index + // CHECK: %[[C4_1:.*]] = arith.constant 4 : index // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]] // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]] - // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]] - // CHECK: %[[C24:.*]] = arith.constant 24 : index - // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]] + // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]] + // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]] - // CHECK: %[[C32:.*]] = arith.constant 32 : index - // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]] - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]] - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK: %[[C256:.*]] = arith.constant 256 : index + // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]] + // CHECK: %[[C0_2:.*]] = arith.constant 0 : index + // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]] + // CHECK: %[[C0_3:.*]] = arith.constant 0 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]] + // CHECK: %[[C0_4:.*]] = arith.constant 0 : index + // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]] + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: gpu.return - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] - // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -gpu.func @update_nd(%src: memref<24x32xf32>){ - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> +// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> +gpu.func @update_nd(%src: memref<256x128xf32>){ + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: dpas_no_sg_data -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: xegpu.prefetch_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas_with_no_create_nd_desc - gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) { - // CHECK-NOT: vector<12x12xf32> + gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) { + // CHECK-NOT: vector<32x32xf32> %dpas = xegpu.dpas %a, %b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: broadcast_dim1 - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast_dim1(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32> + gpu.func @broadcast_dim1(%src: memref<256x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32> + -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32> - %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<256x1xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32> + %broadcast = vector.broadcast %load + {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<256x1xf32> to vector<256x32xf32> gpu.return } // CHECK-LABEL: broadcast_dim0 - // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32> - gpu.func @broadcast_dim0(%src: memref<1x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32> - -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32> + gpu.func @broadcast_dim0(%src: memref<1x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32> + -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> - -> vector<1x32xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>} - // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32> + : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<1x128xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32> %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>} - : vector<1x32xf32> to vector<12x32xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<1x128xf32> to vector<32x128xf32> gpu.return } diff --git a/mlir/test/Examples/transform/Ch3/ops.mlir b/mlir/test/Examples/transform/Ch3/ops.mlir index b2d47cc..707a09f 100644 --- a/mlir/test/Examples/transform/Ch3/ops.mlir +++ b/mlir/test/Examples/transform/Ch3/ops.mlir @@ -30,9 +30,30 @@ module attributes {transform.with_named_sequence} { // ----- func.func private @orig() +func.func private @updated() // CHECK-LABEL: func @test2 func.func @test2() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.my.call_op_interface + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.my.call_op_interface + transform.my.change_call_target %call, "updated" : !transform.my.call_op_interface + transform.yield + } +} + +// ----- + +func.func private @orig() + +// CHECK-LABEL: func @test3 +func.func @test3() { // CHECK: "my.mm4" call @orig() : () -> () return diff --git a/mlir/test/Examples/transform/Ch3/sequence.mlir b/mlir/test/Examples/transform/Ch3/sequence.mlir index 4d28518..877b006 100644 --- a/mlir/test/Examples/transform/Ch3/sequence.mlir +++ b/mlir/test/Examples/transform/Ch3/sequence.mlir @@ -101,11 +101,12 @@ module attributes {transform.with_named_sequence} { %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} - : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) - - // Rewrite the call target. - transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call"> - + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Cast to our new type. + %casted = transform.cast %call : !transform.any_op to !transform.my.call_op_interface + // Using our new operation. + transform.my.change_call_target %casted, "microkernel" : !transform.my.call_op_interface + transform.yield } } diff --git a/mlir/test/IR/diagnostic-nosplit.mlir b/mlir/test/IR/diagnostic-nosplit.mlir new file mode 100644 index 0000000..ecfb9c6 --- /dev/null +++ b/mlir/test/IR/diagnostic-nosplit.mlir @@ -0,0 +1,13 @@ +// RUN: not mlir-opt %s -o - --split-input-file 2>&1 | FileCheck %s +// This test verifies that diagnostic handler doesn't emit splits. + + +// ----- + + + +func.func @constant_out_of_range() { + // CHECK: mlir:11:8: error: 'arith.constant' + %x = "arith.constant"() {value = 100} : () -> i1 + return +} diff --git a/mlir/test/IR/test-pattern-logging-listener.mlir b/mlir/test/IR/test-pattern-logging-listener.mlir index c521110..d3d42e3 100644 --- a/mlir/test/IR/test-pattern-logging-listener.mlir +++ b/mlir/test/IR/test-pattern-logging-listener.mlir @@ -8,15 +8,15 @@ // {anonymous_namespace} vs `anonymous_namespace` (and maybe others?) on the // various platforms. -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationInserted | test.new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationReplaced (with values) | test.replace_with_new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationErased | test.replace_with_new_op func.func @replace_with_new_op() -> i32 { %a = "test.replace_with_new_op"() : () -> (i32) diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir index b571d94..5389691 100644 --- a/mlir/test/IR/top-level.mlir +++ b/mlir/test/IR/top-level.mlir @@ -6,10 +6,10 @@ func.func private @foo() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +// expected-error@-2 {{source must contain a single top-level operation, found: 2}} func.func private @bar() func.func private @baz() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 0}} +// expected-error@-2 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir index 05e6782..a7bb039 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir @@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> { %zero = arith.constant 0 : i32 - %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> + %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32> %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> - %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32> + %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32> // Pack matrices - %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32> + %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32> %B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32> - %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32> + %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32> // MMT4D - %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32> + %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32> // Unpack output %C_out_empty = tensor.empty() : tensor<7x13xi32> - %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32> + %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32> return %C_out_unpack : tensor<7x13xi32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir index 6e2a82b..6ec1031 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir @@ -4,14 +4,14 @@ // RUN: FileCheck %s func.func @extract_element_0d(%a: vector<f32>) { - %1 = vector.extractelement %a[] : vector<f32> + %1 = vector.extract %a[] : f32 from vector<f32> // CHECK: 42 vector.print %1: f32 return } func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) { - %1 = vector.insertelement %a, %b[] : vector<f32> + %1 = vector.insert %a, %b[] : f32 into vector<f32> return %1: vector<f32> } @@ -58,9 +58,9 @@ func.func @broadcast_0d(%a: f32) { func.func @bitcast_0d() { %0 = arith.constant 42 : i32 %1 = arith.constant dense<0> : vector<i32> - %2 = vector.insertelement %0, %1[] : vector<i32> + %2 = vector.insert %0, %1[] : i32 into vector<i32> %3 = vector.bitcast %2 : vector<i32> to vector<f32> - %4 = vector.extractelement %3[] : vector<f32> + %4 = vector.extract %3[] : f32 from vector<f32> %5 = arith.bitcast %4 : f32 to i32 // CHECK: 42 vector.print %5: i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index b69a200..eb99886 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -72,7 +72,7 @@ func.func @za0_d_f64() -> i32 { %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64> %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) { - %t = vector.extractelement %row[%offset : index] : vector<[2]xf64> + %t = vector.extract %row[%offset] : f64 from vector<[2]xf64> %inner_add_reduce_next = arith.addf %inner_iter, %t : f64 scf.yield %inner_add_reduce_next : f64 } @@ -102,7 +102,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -125,7 +125,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index 697fb90..ad8e321 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -36,7 +36,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -64,7 +64,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir index 53a7282..aff272c2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir @@ -11,8 +11,8 @@ func.func @entry() -> i32 { %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %r = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32> + %1 = vector.extract %r[%i0] : f32 from vector<8xf32> + %2 = vector.extract %r[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir index bf1caaa..1c56990 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir @@ -196,13 +196,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { %v_A = vector.transfer_read %m_A[%a], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { %v_C = vector.transfer_read %m_C[%b], %index_padding : memref<?xi64>, vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { @@ -273,10 +273,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, %v_C = vector.transfer_read %m_C[%b1], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { @@ -370,8 +370,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64 -> f64 %r2 = arith.addf %r1, %subresult : f64 - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64 %cond_a_i64 = arith.extui %cond_a : i1 to i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir index e9a66cc..1683fa5 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -49,7 +48,7 @@ func.func @entry() { memref.store %z, %A[%i] : memref<?xf32> %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir index 2dc00df..826da53 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -53,7 +52,7 @@ func.func @entry() { iter_args(%v_iter = %v) -> (vector<16xf32>) { %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir index 54b6e69..22b5eef 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir @@ -21,8 +21,7 @@ func.func @printmem8(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c8 step %c1 iter_args(%m_iter = %m) -> (vector<8xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<8xf32> scf.yield %m_new : vector<8xf32> } vector.print %mem : vector<8xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir index 2393bd1..639eed4 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir @@ -200,7 +200,7 @@ func.func @entry() { // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector. - // Generates a loop with vector.insertelement. + // Generates a loop with vector.insert. call @transfer_read_1d_broadcast(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> () // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir index e665653..731bd5a 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %interleave = vector.interleave %lhs1, %rhs1 : vector<2xi32> -> vector<4xi32> - %res0 = vector.extractelement %interleave[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %interleave[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %interleave[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %interleave[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %interleave[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %interleave[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %interleave[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %interleave[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir index dc53fe3..c1b7dba 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %shuffle = vector.shuffle %lhs1, %rhs1[2, 1, 3, 3] : vector<2xi32>, vector<2xi32> - %res0 = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %shuffle[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %shuffle[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %shuffle[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %shuffle[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %shuffle[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %shuffle[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index cdbca72..7888462 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -595,16 +595,17 @@ module attributes {transform.with_named_sequence} { // ----- -// It is valid to fuse the pack op with padding semantics if the tiled -// dimensions do not need padding. +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> } } %1 = tensor.empty() : tensor<22x2x3x16xf32> @@ -621,109 +622,39 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] -// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] -// CHECK-SAME: into %[[TILED_PACK_DEST]] -// CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// It is valid to fuse the pack if the dimension is not tiled even when it needs -// extra padding. - -func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<33x2x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32> - return %pack : tensor<33x2x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32> -// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM:.*]] = linalg.exp -// CHECK-SAME: ins(%[[ELEM_SRC]] -// CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// If the dimension is tiled and it needs extra padding, do not fuse the pack -// op. - -func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<23x32x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> - return %pack : tensor<23x32x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // ----- diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 24380b5..a419d75 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -570,10 +570,10 @@ define void @trap_intrinsics() { ; CHECK-LABEL: llvm.func @memcpy_test define void @memcpy_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () - call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false) ret void @@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) { ; CHECK-LABEL: llvm.func @memmove_test define void @memmove_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false) ret void } ; CHECK-LABEL: llvm.func @memset_test define void @memset_test(i32 %0, ptr %1, i8 %2) { - ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () - call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () + call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> () call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false) ret void diff --git a/mlir/test/Target/LLVMIR/Import/module-asm.ll b/mlir/test/Target/LLVMIR/Import/module-asm.ll new file mode 100644 index 0000000..38f6ea4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/module-asm.ll @@ -0,0 +1,5 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s +; CHECK: llvm.module_asm = ["foo", "bar"] + +module asm "foo" +module asm "bar" diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir index 7fd5f26..5ed6244 100644 --- a/mlir/test/Target/LLVMIR/invalid-module.mlir +++ b/mlir/test/Target/LLVMIR/invalid-module.mlir @@ -1,6 +1,16 @@ -// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s +// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module -split-input-file %s // expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}} llvm.func @foo() { llvm.return } + +// ----- + +// expected-error@below {{expected an array attribute for a module level asm}} +module attributes {llvm.module_asm = "foo"} {} + +// ----- + +// expected-error@below {{expected a string attribute for each entry of a module level asm}} +module attributes {llvm.module_asm = [42]} {} diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 44074ce..eb3510c 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() { // CHECK-LABEL: @memcpy_test llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true - "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () + // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true + "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () // CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + + // Verify that trailing empty argument attribute dictionaries can be omitted. + // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memmove_test llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memset_test llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) { %i1 = llvm.mlir.constant(false) : i1 - // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true - "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () + // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true + "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () // CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> () llvm.return diff --git a/mlir/test/Target/LLVMIR/module-asm.mlir b/mlir/test/Target/LLVMIR/module-asm.mlir new file mode 100644 index 0000000..2afb37c --- /dev/null +++ b/mlir/test/Target/LLVMIR/module-asm.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.module_asm = ["foo", "bar"]} {} + +// CHECK: module asm "foo" +// CHECK: module asm "bar" diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8c4f0aa..85478cc 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> llvm.return } + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a041..5c2cfa4 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {nvvm.kernel} { diff --git a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir b/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir deleted file mode 100644 index d889ef4..0000000 --- a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir +++ /dev/null @@ -1,121 +0,0 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s - -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { - omp.private {type = private} @_QFEi_private_i32 : i32 loc(#loc1) - omp.declare_reduction @add_reduction_i32 : i32 init { - ^bb0(%arg0: i32 loc("test.f90":8:7)): - %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } combiner { - ^bb0(%arg0: i32 loc("test.f90":8:7), %arg1: i32 loc("test.f90":8:7)): - %0 = llvm.add %arg0, %arg1 : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } loc(#loc2) - llvm.func @_QQmain() { - %0 = llvm.mlir.constant(1 : i64) : i64 loc(#loc4) - %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr<5> loc(#loc4) - %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr loc(#loc4) - %3 = llvm.mlir.constant(1 : i64) : i64 loc(#loc1) - %4 = llvm.alloca %3 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr<5> loc(#loc1) - %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr loc(#loc1) - %6 = llvm.mlir.constant(8191 : index) : i64 loc(#loc5) - %7 = llvm.mlir.constant(0 : index) : i64 loc(#loc5) - %8 = llvm.mlir.constant(1 : index) : i64 loc(#loc5) - %9 = llvm.mlir.constant(0 : i32) : i32 loc(#loc5) - %10 = llvm.mlir.constant(8192 : index) : i64 loc(#loc5) - %11 = llvm.mlir.addressof @_QFEarr : !llvm.ptr<1> loc(#loc6) - %12 = llvm.addrspacecast %11 : !llvm.ptr<1> to !llvm.ptr loc(#loc6) - llvm.store %9, %2 : i32, !llvm.ptr loc(#loc7) - %15 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"} loc(#loc4) - %16 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"} loc(#loc7) - %17 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) extent(%10 : i64) stride(%8 : i64) start_idx(%8 : i64) loc(#loc7) - %18 = omp.map.info var_ptr(%12 : !llvm.ptr, !llvm.array<8192 x i32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%17) -> !llvm.ptr {name = "arr"} loc(#loc7) - omp.target map_entries(%15 -> %arg0, %16 -> %arg1, %18 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { - %19 = llvm.mlir.constant(8192 : i32) : i32 loc(#loc5) - %20 = llvm.mlir.constant(1 : i32) : i32 loc(#loc5) - %21 = llvm.mlir.constant(8192 : index) : i64 loc(#loc6) - omp.teams reduction(@add_reduction_i32 %arg0 -> %arg3 : !llvm.ptr) { - omp.parallel private(@_QFEi_private_i32 %arg1 -> %arg4 : !llvm.ptr) { - omp.distribute { - omp.wsloop reduction(@add_reduction_i32 %arg3 -> %arg5 : !llvm.ptr) { - omp.loop_nest (%arg6) : i32 = (%20) to (%19) inclusive step (%20) { - llvm.store %arg6, %arg4 : i32, !llvm.ptr loc(#loc2) - %22 = llvm.load %arg5 : !llvm.ptr -> i32 loc(#loc8) - %23 = llvm.load %arg4 : !llvm.ptr -> i32 loc(#loc8) - %34 = llvm.add %22, %23 : i32 loc(#loc8) - llvm.store %34, %arg5 : i32, !llvm.ptr loc(#loc8) - omp.yield loc(#loc2) - } loc(#loc2) - } {omp.composite} loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc13) - llvm.return loc(#loc9) - } loc(#loc12) - llvm.mlir.global internal @_QFEarr() {addr_space = 1 : i32} : !llvm.array<8192 x i32> { - %0 = llvm.mlir.zero : !llvm.array<8192 x i32> loc(#loc6) - llvm.return %0 : !llvm.array<8192 x i32> loc(#loc6) - } loc(#loc6) -} loc(#loc) - -#loc = loc("test.f90":4:18) -#loc1 = loc("test.f90":4:18) -#loc2 = loc("test.f90":8:7) -#loc3 = loc("test.f90":1:7) -#loc4 = loc("test.f90":3:18) -#loc5 = loc(unknown) -#loc6 = loc("test.f90":5:18) -#loc7 = loc("test.f90":6:7) -#loc8 = loc("test.f90":10:7) -#loc9 = loc("test.f90":16:7) - -#di_file = #llvm.di_file<"target7.f90" in ""> -#di_null_type = #llvm.di_null_type -#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, - sourceLanguage = DW_LANG_Fortran95, file = #di_file, producer = "flang", - isOptimized = false, emissionKind = LineTablesOnly> -#di_subroutine_type = #llvm.di_subroutine_type< - callingConvention = DW_CC_program, types = #di_null_type> -#di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, - compileUnit = #di_compile_unit, scope = #di_file, name = "main", - file = #di_file, subprogramFlags = "Definition|MainSubprogram", - type = #di_subroutine_type> -#di_subprogram1 = #llvm.di_subprogram<compileUnit = #di_compile_unit, - name = "target", file = #di_file, subprogramFlags = "Definition", - type = #di_subroutine_type> - - -#loc12 = loc(fused<#di_subprogram>[#loc3]) -#loc13 = loc(fused<#di_subprogram1>[#loc2]) - -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @"__omp_offloading_{{.*}}__QQmain_l8_omp$reduction$reduction_func.1" -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func.2 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func.3 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_reduce_func -// CHECK-NOT: !dbg -// CHECK: } diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir new file mode 100644 index 0000000..a3dd0b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) + llvm.func @prefetch(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(1 : i64) : i64 + // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0) + {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, + no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64, + xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} + : (!llvm.ptr<1>, i64) -> () + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} + diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 76d34c2..1695d2a 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -1,6 +1,7 @@ // RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ mlir-translate -no-implicit-module --split-input-file -serialize-spirv %s | spirv-val %} -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, CooperativeMatrixKHR], [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix]> { // CHECK-LABEL: @bool_const spirv.func @bool_const() -> () "None" { // CHECK: spirv.Constant true @@ -305,6 +306,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> } + + // CHECK-LABEL: @arm_tensor_of_i32 + spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @splat_arm_tensor_of_i32 + spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @arm_tensor_of_f32 + spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + + // CHECK-LABEL: @splat_arm_tensor_of_f32 + spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + + spirv.EntryPoint "GLCompute" @bool_const } // ----- diff --git a/mlir/test/Target/SPIRV/lit.local.cfg b/mlir/test/Target/SPIRV/lit.local.cfg new file mode 100644 index 0000000..6d44394 --- /dev/null +++ b/mlir/test/Target/SPIRV/lit.local.cfg @@ -0,0 +1,4 @@ +if config.spirv_tools_tests: + config.available_features.add("spirv-tools") + config.substitutions.append(("spirv-as", os.path.join(config.llvm_tools_dir, "spirv-as"))) + config.substitutions.append(("spirv-val", os.path.join(config.llvm_tools_dir, "spirv-val"))) diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir index b200871..05cbddc 100644 --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %15 = spirv.IsNan %arg0 : f32 // CHECK: spirv.IsInf %16 = spirv.IsInf %arg1 : f32 + // CHECK: spirv.IsFinite + %17 = spirv.IsFinite %arg0 : f32 spirv.Return } } diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir index 6b50c39..786d07a2 100644 --- a/mlir/test/Target/SPIRV/memory-ops.mlir +++ b/mlir/test/Target/SPIRV/memory-ops.mlir @@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // ----- spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { - spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])> + spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> %2 = spirv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> spirv.Store "StorageBuffer" %4, %2 : f32 spirv.Return } - spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])> + spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> %2 = spirv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> spirv.Store "StorageBuffer" %4, %2 : i32 spirv.Return } diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir index 0db0c0b..4984ee7 100644 --- a/mlir/test/Target/SPIRV/struct.mlir +++ b/mlir/test/Target/SPIRV/struct.mlir @@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input> spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer> - spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer> + spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer> - spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer> + spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer> - spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> + spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer> - spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> + spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer> - spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> + spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer> - spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> + spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> // CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer> spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer> @@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> + // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> + spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer> - spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform> + // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform> + // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> + spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>, // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output> diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir index b9044fe..8889b80 100644 --- a/mlir/test/Target/SPIRV/undef.mlir +++ b/mlir/test/Target/SPIRV/undef.mlir @@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>> %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>> %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>> - // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer> - %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer> + // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer> + %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer> %8 = spirv.Constant 0 : i32 - %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer> + %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer> spirv.Return } } diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 53fbb8a..d6fa442 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>> %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>> @@ -12,9 +12,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>> %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>> %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> @@ -24,12 +24,12 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -38,13 +38,13 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_128 = arith.constant 128 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], %[[C384]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -53,9 +53,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { +// CHECK-SAME: %[[input:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> + // CHECK: {{.*}} = memref.subview %[[input]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>> @@ -64,9 +64,9 @@ func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { +// CHECK-SAME: %[[input:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> + // CHECK: {{.*}} = memref.subview %[[input]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>> %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>> %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>> @@ -76,13 +76,13 @@ func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_64 = arith.constant 64 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], %[[C384]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> @@ -91,13 +91,39 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> } + +// ----- + +// CHECK-LABEL: func.func @single_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE_1:.*]]: index) -> memref<8x?xf32> { +func.func @single_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<8x?xf32>{ + %subview = memref.subview %input[0, 0][8, %size0][1, 1] : memref<256x?xf32> to memref<8x?xf32> + %subview_1 = memref.subview %subview[0, 0][8, %size1][1, 1] : memref<8x?xf32> to memref<8x?xf32> + // CHECK: %{{.*}} = memref.subview %[[input]][0, 0] [8, %[[SIZE_1]]] [1, 1] : memref<256x?xf32> to memref<8x?xf32> + return %subview_1 : memref<8x?xf32> +} + +// ----- + +// CHECK-LABEL: func.func @all_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE1:.*]]: index) -> memref<?x?xf32> { +func.func @all_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<?x?xf32>{ + %subview = memref.subview %input[0, 0][%size0, %size0][1, 1] : memref<256x?xf32> to memref<?x?xf32> + %subview_1 = memref.subview %subview[0, 0][%size1, %size1][1, 1] : memref<?x?xf32> to memref<?x?xf32> + // CHECK: {{.*}} = memref.subview %[[input]][0, 0] {{\[}}%[[SIZE1]], %[[SIZE1]]] [1, 1] : memref<256x?xf32> to memref<?x?xf32> + return %subview_1 : memref<?x?xf32> +} diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 3af95db..9ded6a3 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -548,3 +548,26 @@ func.func @test_atomic_yield(%I: memref<10xf32>, %idx : index) { func.return } +// ----- + +// CHECK-LABEL: module @return_void_with_unused_argument +module @return_void_with_unused_argument { + // CHECK-LABEL: func.func private @fn_return_void_with_unused_argument + // CHECK-SAME: (%[[ARG0_FN:.*]]: i32) + func.func private @fn_return_void_with_unused_argument(%arg0: i32, %arg1: memref<4xi32>) -> () { + %sum = arith.addi %arg0, %arg0 : i32 + %c0 = arith.constant 0 : index + %buf = memref.alloc() : memref<1xi32> + memref.store %sum, %buf[%c0] : memref<1xi32> + return + } + // CHECK-LABEL: func.func @main + // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32) + // CHECK: call @fn_return_void_with_unused_argument(%[[ARG0_MAIN]]) : (i32) -> () + func.func @main(%arg0: i32) -> memref<4xi32> { + %unused = memref.alloc() : memref<4xi32> + call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> () + return %unused : memref<4xi32> + } +} + diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index db8bd0f..9bffe92 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() { "test.signature_conversion_no_converter"() ({ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}} ^bb0(%arg0: f32): - "test.type_consumer"(%arg0) : (f32) -> () // expected-note@below{{see existing live user here}} + "test.type_consumer"(%arg0) : (f32) -> () "test.return"(%arg0) : (f32) -> () }) : () -> () return diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir index 19a1310..5b07055 100644 --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s // expected-remark@-2 {{op 'builtin.module' is legalizable}} // expected-remark@+1 {{op 'func.func' is legalizable}} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 5f1148c..dcd0172 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: func @multi_level_mapping func.func @multi_level_mapping() { diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index 8a01a0a..016052c 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -69,25 +69,25 @@ struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, - immAttr, rvl); + res = vcix::BinaryImmOp::create(rewriter, loc, legalType, opcodeAttr, vec, + immAttr, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, - extracted, immAttr, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryImmOp::create( + rewriter, loc, legalType, opcodeAttr, extracted, immAttr, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -112,25 +112,25 @@ struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - vec, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + vec, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, extracted, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, extracted, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -152,28 +152,28 @@ struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { Location loc = op.getLoc(); Value vec = op.getOperand(); Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); Value rvl = nullptr; if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - zero, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + zero, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, zero, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, zero, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -195,30 +195,30 @@ struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { Value vec = op.getOperand(); Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); Value rvl = nullptr; - Value zeroInt = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value zeroInt = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - zeroInt, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + zeroInt, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, zeroInt, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, zeroInt, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index ed5d06d..3569a73 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -145,7 +145,7 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, if (reifiedScalable->map.getNumInputs() == 1) { // The only possible input to the bound is vscale. vscaleOperand.push_back(std::make_pair( - rewriter.create<vector::VectorScaleOp>(loc), std::nullopt)); + vector::VectorScaleOp::create(rewriter, loc), std::nullopt)); } reified = affine::materializeComputedBound( rewriter, loc, reifiedScalable->map, vscaleOperand); @@ -169,8 +169,9 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, rewriter.replaceOp(op, val); return WalkResult::skip(); } - Value constOp = rewriter.create<arith::ConstantIndexOp>( - op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); + Value constOp = arith::ConstantIndexOp::create( + rewriter, op->getLoc(), + cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); rewriter.replaceOp(op, constOp); return WalkResult::skip(); }); diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp index 738d4ee59..a792d08 100644 --- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp +++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp @@ -60,7 +60,7 @@ struct TestEmulateWideIntPass // casts (and vice versa) and using it insted of `llvm.bitcast`. auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs); + auto cast = LLVM::BitcastOp::create(builder, loc, type, inputs); return cast->getResult(0); }; typeConverter.addSourceMaterialization(addBitcast); diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index 226e0bb..2ee3222 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses + TestOneShotModuleBufferize.cpp TestTensorCopyInsertion.cpp TestTensorLikeAndBufferLike.cpp diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp new file mode 100644 index 0000000..1e2d4a7 --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -0,0 +1,57 @@ +//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneShotModuleBufferizePass + : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) + + TestOneShotModuleBufferizePass() = default; + TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<bufferization::BufferizationDialect>(); + } + StringRef getArgument() const final { + return "test-one-shot-module-bufferize"; + } + StringRef getDescription() const final { + return "Pass to test One Shot Module Bufferization"; + } + + void runOnOperation() override { + + llvm::errs() << "Running TestOneShotModuleBufferize on: " + << getOperation()->getName() << "\n"; + bufferization::OneShotBufferizationOptions opt; + + opt.bufferizeFunctionBoundaries = true; + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, + bufferizationState))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestOneShotModuleBufferizePass() { + PassRegistration<TestOneShotModuleBufferizePass>(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index eb2f74e..3b7bd9b 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -10,7 +10,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVM) add_subdirectory(Math) add_subdirectory(MemRef) -add_subdirectory(Mesh) +add_subdirectory(Shard) add_subdirectory(NVGPU) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index d0b62e7..c67bcd9 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -48,8 +48,8 @@ static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder, } for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { Type elementType = tupleType.getType(i); - Value element = builder.create<test::GetTupleElementOp>( - loc, elementType, tuple, builder.getI32IntegerAttr(i)); + Value element = test::GetTupleElementOp::create( + builder, loc, elementType, tuple, builder.getI32IntegerAttr(i)); decompose(element); } }; @@ -94,7 +94,7 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, } // Assemble the tuple from the elements. - return builder.create<test::MakeTupleOp>(loc, resultType, elements); + return test::MakeTupleOp::create(builder, loc, resultType, elements); } /// A pass for testing call graph type decomposition. diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 9eade75..9a394d2 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -56,7 +56,7 @@ struct TestSCFForUtilsPass SmallVector<Value> newYieldValues; for (auto yieldVal : oldYieldValues) { newYieldValues.push_back( - b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); + arith::AddFOp::create(b, loc, yieldVal, yieldVal)); } return newYieldValues; }; @@ -160,13 +160,13 @@ struct TestSCFPipeliningPass Value pred) { Location loc = op->getLoc(); auto ifOp = - rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); + scf::IfOp::create(rewriter, loc, op->getResultTypes(), pred, true); // True branch. rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), ifOp.getThenRegion().front().begin()); rewriter.setInsertionPointAfter(op); if (op->getNumResults() > 0) - rewriter.create<scf::YieldOp>(loc, op->getResults()); + scf::YieldOp::create(rewriter, loc, op->getResults()); // False branch. rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); SmallVector<Value> elseYieldOperands; @@ -181,12 +181,12 @@ struct TestSCFPipeliningPass } else { // Default to assuming constant numeric values. for (Type type : op->getResultTypes()) { - elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(type))); + elseYieldOperands.push_back(arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(type))); } } if (op->getNumResults() > 0) - rewriter.create<scf::YieldOp>(loc, elseYieldOperands); + scf::YieldOp::create(rewriter, loc, elseYieldOperands); return ifOp.getOperation(); } diff --git a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp index d3113c0..d3f7f0e6 100644 --- a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp +++ b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp @@ -50,23 +50,23 @@ struct TestSCFWhileOpBuilderPass // Create a WhileOp with the same operands and result types. TypeRange resultTypes = whileOp->getResultTypes(); ValueRange operands = whileOp->getOperands(); - builder.create<WhileOp>( - loc, resultTypes, operands, /*beforeBuilder=*/ + WhileOp::create( + builder, loc, resultTypes, operands, /*beforeBuilder=*/ [&](OpBuilder &b, Location loc, ValueRange args) { // Just cast the before args into the right types for condition. ImplicitLocOpBuilder builder(loc, b); auto castOp = - builder.create<UnrealizedConversionCastOp>(resultTypes, args); - auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1); - builder.create<ConditionOp>(cmp, castOp->getResults()); + UnrealizedConversionCastOp::create(builder, resultTypes, args); + auto cmp = ConstantIntOp::create(builder, /*value=*/1, /*width=*/1); + ConditionOp::create(builder, cmp, castOp->getResults()); }, /*afterBuilder=*/ [&](OpBuilder &b, Location loc, ValueRange args) { // Just cast the after args into the right types for yield. ImplicitLocOpBuilder builder(loc, b); - auto castOp = builder.create<UnrealizedConversionCastOp>( - operands.getTypes(), args); - builder.create<YieldOp>(castOp->getResults()); + auto castOp = UnrealizedConversionCastOp::create( + builder, operands.getTypes(), args); + YieldOp::create(builder, castOp->getResults()); }); }); } diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt index 7bd0493..f91c547 100644 --- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt @@ -1,14 +1,14 @@ # Exclude tests from libMLIR.so -add_mlir_library(MLIRMeshTest +add_mlir_library(MLIRShardTest TestOpLowering.cpp - TestReshardingSpmdization.cpp + TestReshardingPartition.cpp TestSimplifications.cpp EXCLUDE_FROM_LIBMLIR ) -mlir_target_link_libraries(MLIRMeshTest PUBLIC - MLIRMeshDialect - MLIRMeshTransforms +mlir_target_link_libraries(MLIRShardTest PUBLIC + MLIRShardDialect + MLIRShardTransforms MLIRPass MLIRRewrite MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp index dbae93b..43f3b3f 100644 --- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp +++ b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" @@ -24,17 +24,17 @@ struct TestAllSliceOpLoweringPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); + shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { - mesh::registerAllSliceOpLoweringDialects(registry); + shard::registerAllSliceOpLoweringDialects(registry); } StringRef getArgument() const final { - return "test-mesh-all-slice-op-lowering"; + return "test-grid-all-slice-op-lowering"; } StringRef getDescription() const final { return "Test lowering of all-slice."; @@ -48,21 +48,21 @@ struct TestMultiIndexOpLoweringPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, - symbolTableCollection); + shard::populateProcessMultiIndexOpLoweringPatterns(patterns, + symbolTableCollection); LogicalResult status = applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { - mesh::registerProcessMultiIndexOpLoweringDialects(registry); + shard::registerProcessMultiIndexOpLoweringDialects(registry); } StringRef getArgument() const final { - return "test-mesh-process-multi-index-op-lowering"; + return "test-grid-process-multi-index-op-lowering"; } StringRef getDescription() const final { - return "Test lowering of mesh.process_multi_index op."; + return "Test lowering of shard.process_multi_index op."; } }; diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp index 102e64d..23fdad1 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Spmdization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Partition.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -22,11 +22,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { -struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { +struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> { using OpRewritePattern<ShardOp>::OpRewritePattern; LogicalResult matchAndRewrite(ShardOp op, @@ -36,18 +36,18 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { } SymbolTableCollection symbolTable; - mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( - op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr()); + shard::GridOp grid = symbolTable.lookupNearestSymbolFrom<shard::GridOp>( + op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr()); bool foundUser = false; for (auto user : op->getUsers()) { if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) { if (targetShardOp.getAnnotateForUsers() && - mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( + grid == symbolTable.lookupNearestSymbolFrom<shard::GridOp>( targetShardOp, cast<ShardingOp>( targetShardOp.getSharding().getDefiningOp()) - .getMeshAttr())) { + .getGridAttr())) { foundUser = true; break; } @@ -61,26 +61,25 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { for (auto user : op->getUsers()) { auto targetShardOp = llvm::dyn_cast<ShardOp>(user); if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || - symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( + symbolTable.lookupNearestSymbolFrom<shard::GridOp>( targetShardOp, cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()) - .getMeshAttr()) != mesh) { + .getGridAttr()) != grid) { continue; } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); ShapedType sourceShardShape = - shardShapedType(op.getResult().getType(), mesh, op.getSharding()); + shardShapedType(op.getResult().getType(), grid, op.getSharding()); TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>( - builder - .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc()) + UnrealizedConversionCastOp::create(builder, sourceShardShape, + op.getSrc()) ->getResult(0)); TypedValue<ShapedType> targetShard = - reshard(builder, mesh, op, targetShardOp, sourceShard); + reshard(builder, grid, op, targetShardOp, sourceShard); Value newTargetUnsharded = - builder - .create<UnrealizedConversionCastOp>( - targetShardOp.getResult().getType(), targetShard) + UnrealizedConversionCastOp::create( + builder, targetShardOp.getResult().getType(), targetShard) ->getResult(0); rewriter.replaceAllUsesWith(targetShardOp.getResult(), newTargetUnsharded); @@ -90,13 +89,13 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { } }; -struct TestMeshReshardingPass - : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass) +struct TestReshardingPass + : public PassWrapper<TestReshardingPass, OperationPass<ModuleOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReshardingPass) void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.insert<TestMeshReshardingRewritePattern>(&getContext()); + patterns.insert<TestReshardingRewritePattern>(&getContext()); if (failed(applyPatternsGreedily(getOperation().getOperation(), std::move(patterns)))) { return signalPassFailure(); @@ -107,18 +106,18 @@ struct TestMeshReshardingPass registry.insert<BuiltinDialect>(); } StringRef getArgument() const final { - return "test-mesh-resharding-spmdization"; + return "test-grid-resharding-partition"; } StringRef getDescription() const final { - return "Test Mesh dialect resharding spmdization."; + return "Test Shard dialect resharding partition."; } }; } // namespace namespace mlir { namespace test { -void registerTestMeshReshardingSpmdizationPass() { - PassRegistration<TestMeshReshardingPass>(); +void registerTestReshardingPartitionPass() { + PassRegistration<TestReshardingPass>(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp index 01e196d..2885215 100644 --- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp +++ b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -16,23 +16,23 @@ using namespace mlir; namespace { -struct TestMeshSimplificationsPass - : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass) +struct TestShardSimplificationsPass + : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass) void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert<arith::ArithDialect, mesh::MeshDialect>(); + registry.insert<arith::ArithDialect, shard::ShardDialect>(); } - StringRef getArgument() const final { return "test-mesh-simplifications"; } - StringRef getDescription() const final { return "Test mesh simplifications"; } + StringRef getArgument() const final { return "test-grid-simplifications"; } + StringRef getDescription() const final { return "Test grid simplifications"; } }; } // namespace -void TestMeshSimplificationsPass::runOnOperation() { +void TestShardSimplificationsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); SymbolTableCollection symbolTableCollection; - mesh::populateSimplificationPatterns(patterns, symbolTableCollection); + shard::populateSimplificationPatterns(patterns, symbolTableCollection); [[maybe_unused]] LogicalResult status = applyPatternsGreedily(getOperation(), std::move(patterns)); assert(succeeded(status) && "Rewrite patters application did not converge."); @@ -40,8 +40,8 @@ void TestMeshSimplificationsPass::runOnOperation() { namespace mlir { namespace test { -void registerTestMeshSimplificationsPass() { - PassRegistration<TestMeshSimplificationsPass>(); +void registerTestShardSimplificationsPass() { + PassRegistration<TestShardSimplificationsPass>(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 0e191c3..687473e 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -192,8 +192,8 @@ struct RewriteExtractSliceFromCollapseShapeBase // Create the destination tensor using the above values. Type elementType = op.getSourceType().getElementType(); SmallVector<OpFoldResult> outputShape = reifiedShapes[0]; - Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape, - elementType); + Value dest = tensor::EmptyOp::create(rewriter, op->getLoc(), outputShape, + elementType); // Calculate the parameters for the tile loop nest. FailureOr<tensor::ExtractSliceFromCollapseHelper> params = @@ -215,8 +215,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor PatternRewriter &rewriter) const override { Location loc = op.getLoc(); const unsigned numTiledDims = helper.getIterationSpaceSizes().size(); - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); SmallVector<Value> lbs(numTiledDims, zero); SmallVector<Value> steps(numTiledDims, one); @@ -228,8 +228,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); // Insert the slice into the destination. - return {nestedBuilder.create<tensor::InsertSliceOp>( - loc, tile, iterArgs[0], insertParams)}; + return {tensor::InsertSliceOp::create(nestedBuilder, loc, tile, + iterArgs[0], insertParams)}; }); rewriter.replaceOp(op, nest.results); @@ -245,8 +245,9 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach tensor::ExtractSliceFromCollapseHelper &helper, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto forallOp = rewriter.create<scf::ForallOp>( - loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), + auto forallOp = scf::ForallOp::create( + rewriter, loc, + /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), /*outputs=*/dest, /*mapping=*/std::nullopt, [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { @@ -261,10 +262,10 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach auto [tile, insertParams] = helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); // Insert the slice into the destination. - auto term = nestedBuilder.create<scf::InParallelOp>(loc); + auto term = scf::InParallelOp::create(nestedBuilder, loc); nestedBuilder.setInsertionPointToStart(term.getBody()); - nestedBuilder.create<tensor::ParallelInsertSliceOp>( - loc, tile, outputArgs[0], insertParams); + tensor::ParallelInsertSliceOp::create(nestedBuilder, loc, tile, + outputArgs[0], insertParams); }); rewriter.replaceOp(op, forallOp->getResult(0)); return success(); @@ -355,8 +356,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { MLIRContext *context = rootOp->getContext(); OpBuilder builder(context); OwningOpRef<transform::NamedSequenceOp> transformOp = - builder.create<transform::NamedSequenceOp>( - rootOp->getLoc(), + transform::NamedSequenceOp::create( + builder, rootOp->getLoc(), /*sym_name=*/"test_sequence", /*function_type=*/ TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})), diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 382da59..5685004 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> { let mnemonic = "copy_count"; let parameters = (ins TestParamCopyCount:$copy_count); let assemblyFormat = "`<` $copy_count `>`"; + let genVerifyDecl = 1; } def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> { diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b31e90f..5890913 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -214,6 +214,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { } //===----------------------------------------------------------------------===// +// TestCopyCountAttr Implementation +//===----------------------------------------------------------------------===// + +LogicalResult TestCopyCountAttr::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/, + CopyCount /*copy_count*/) { + return success(); +} + +//===----------------------------------------------------------------------===// // CopyCountAttr Implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 1bbf2cc..a4c615b 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -346,7 +346,7 @@ TestDialect::~TestDialect() { Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<TestOpConstant>(loc, type, value); + return TestOpConstant::create(builder, loc, type, value); } void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 01ae245..1235a5f 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -354,7 +354,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { !(input.getType().isSignlessInteger(16) || input.getType().isSignlessInteger(32))) return nullptr; - return builder.create<TestCastOp>(conversionLoc, resultType, input); + return TestCastOp::create(builder, conversionLoc, resultType, input); } Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, @@ -362,16 +362,16 @@ struct TestInlinerInterface : public DialectInlinerInterface { DictionaryAttr argumentAttrs) const final { if (!argumentAttrs.contains("test.handle_argument")) return argument; - return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(), - argument); + return TestTypeChangerOp::create(builder, call->getLoc(), + argument.getType(), argument); } Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const final { if (!resultAttrs.contains("test.handle_result")) return result; - return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(), - result); + return TestTypeChangerOp::create(builder, call->getLoc(), result.getType(), + result); } void processInlinedCallBlocks( diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp index dc6413b..b98f6ce 100644 --- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp @@ -43,11 +43,11 @@ static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst, if (failed(addr)) return failure(); // Create the LoadOp - Value loadOp = builder.create<LLVM::LoadOp>( - moduleImport.translateLoc(inst->getDebugLoc()), + Value loadOp = LLVM::LoadOp::create( + builder, moduleImport.translateLoc(inst->getDebugLoc()), moduleImport.convertType(inst->getType()), *addr); - moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>( - loadOp.getLoc(), loadOp.getType(), loadOp, loadOp); + moduleImport.mapValue(inst) = SameOperandElementTypeOp::create( + builder, loadOp.getLoc(), loadOp.getType(), loadOp, loadOp); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 3ab4ef2..53055fe 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -18,6 +18,32 @@ using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// +// OverridenSymbolVisibilityOp +//===----------------------------------------------------------------------===// + +SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() { + return SymbolTable::Visibility::Private; +} + +static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) { + switch (visibility) { + case SymbolTable::Visibility::Private: + return "private"; + case SymbolTable::Visibility::Nested: + return "nested"; + case SymbolTable::Visibility::Public: + return "public"; + } +} + +void OverriddenSymbolVisibilityOp::setVisibility( + SymbolTable::Visibility visibility) { + + emitOpError("cannot change visibility of symbol to ") + << getVisibilityString(visibility); +} + +//===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// @@ -286,9 +312,9 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { return builder.createOrFold<tensor::DimOp>(loc, operand, dim); })); - shapes.push_back(builder.create<tensor::FromElementsOp>( - getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), - currShape)); + shapes.push_back(tensor::FromElementsOp::create( + builder, getLoc(), + RankedTensorType::get({rank}, builder.getIndexType()), currShape)); } return success(); } @@ -1302,8 +1328,8 @@ llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create<TestOpConstant>(getLoc(), slot.elemType, - builder.getI32IntegerAttr(42)); + return TestOpConstant::create(builder, getLoc(), slot.elemType, + builder.getI32IntegerAttr(42)); } void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, @@ -1335,7 +1361,7 @@ createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(oldOp); auto replacement = - builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); + TestMultiSlotAlloca::create(builder, oldOp->getLoc(), newTypes); for (auto [oldResult, newResult] : llvm::zip_equal(remainingValues, replacement.getResults())) oldResult.replaceAllUsesWith(newResult); @@ -1384,7 +1410,7 @@ DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( for (Attribute usedIndex : usedIndices) { Type elemType = slot.subelementTypes.lookup(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); - auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); + auto subAlloca = TestMultiSlotAlloca::create(builder, getLoc(), elemPtr); newAllocators.push_back(subAlloca); slotMap.try_emplace<MemorySlot>(usedIndex, {subAlloca.getResult(0), elemType}); @@ -1412,8 +1438,8 @@ TestMultiSlotAlloca::handleDestructuringComplete( const auto bufferizedOutType = test::TestMemrefType::get( getContext(), outType.getShape(), outType.getElementType(), nullptr); // replace op with memref analogy - auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>( - getLoc(), bufferizedOutType, *buffer); + auto dummyMemrefOp = test::TestDummyMemrefOp::create( + rewriter, getLoc(), bufferizedOutType, *buffer); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), dummyMemrefOp.getResult()); @@ -1434,7 +1460,7 @@ TestMultiSlotAlloca::handleDestructuringComplete( // replace op with memref analogy auto createMemrefOp = - rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType); + test::TestCreateMemrefOp::create(rewriter, getLoc(), *bufferizedOutType); mlir::bufferization::replaceOpWithBufferizedValues( rewriter, getOperation(), createMemrefOp.getResult()); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ab3f847..2eaad55 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -119,12 +119,28 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ + DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, +]> { + let summary = "operation overridden symbol visibility accessors"; + let arguments = (ins StrAttr:$sym_name); +} + def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; let regions = (region SizedRegion<1>:$region); } +def SymbolScopeIsolatedOp + : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable, + SingleBlockImplicitTerminator< + "TerminatorOp">]> { + let summary = + "operation which defines a new symbol table that is IsolatedFromAbove"; + let regions = (region SizedRegion<1>:$region); +} + def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; @@ -2035,7 +2051,7 @@ def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> { OpBuilder::InsertionGuard g($_builder); Block *body = $_builder.createBlock(bodyRegion); $_builder.setInsertionPointToEnd(body); - $_builder.create<IllegalOpTerminator>($_state.location); + IllegalOpTerminator::create($_builder,$_state.location); }]>]; } def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">; @@ -2738,7 +2754,7 @@ def TestLinalgConvOp : static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, mlir::ArrayRef<mlir::NamedAttribute> attrs, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) { - b.create<mlir::linalg::YieldOp>(block.getArguments().back()); + mlir::linalg::YieldOp::create(b,block.getArguments().back()); } static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, @@ -2801,7 +2817,7 @@ def TestLinalgFillOp : static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, mlir::ArrayRef<mlir::NamedAttribute> attrs, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) { - b.create<mlir::linalg::YieldOp>(block.getArguments().back()); + mlir::linalg::YieldOp::create(b,block.getArguments().back()); } static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp index 6d4e5e3..cc131ad 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp @@ -313,7 +313,7 @@ ParseResult WrappingRegionOp::parse(OpAsmParser &parser, SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); - builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); + TestReturnOp::create(builder, wrappedOp->getLoc(), returnOperands); // Get the results type for the wrapping op from the terminator operands. Operation &returnOp = body.back().back(); @@ -397,7 +397,7 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); // Insert a return statement in the block returning the inner-op's result. - builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); + TestReturnOp::create(builder, innerOp->getLoc(), innerOp->getResults()); // Populate the op operation-state with result-type and location. result.addTypes(opFntype.getResults()); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index b4aeccf..eda618f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -33,14 +33,14 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { } static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { - rewriter.create<OpI>(loc, input); + OpI::create(rewriter, loc, input); } static void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) { // Turn the no result op to a one-result op. - rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), - op.getOperand()); + OpSymbolBindingB::create(rewriter, op.getLoc(), op.getOperand().getType(), + op.getOperand()); } static bool getFirstI32Result(Operation *op, Value &value) { @@ -120,8 +120,8 @@ public: return failure(); rewriter.setInsertionPointToStart(op->getBlock()); - auto constOp = rewriter.create<arith::ConstantOp>( - op.getLoc(), rewriter.getBoolAttr(true)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getBoolAttr(true)); rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), Value(constOp)); return success(); @@ -139,8 +139,7 @@ public: LogicalResult matchAndRewrite(TestCommutative2Op op, PatternRewriter &rewriter) const override { - auto operand = - dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); + auto operand = op->getOperand(0).getDefiningOp<TestCommutative2Op>(); if (!operand) return failure(); Attribute constInput; @@ -845,8 +844,8 @@ struct TestRegionRewriteUndo : public RewritePattern { rewriter.getUnknownLoc()); // Add an explicitly illegal operation to ensure the conversion fails. - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); - rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getIntegerType(32)); + TestValidOp::create(rewriter, op->getLoc(), ArrayRef<Value>()); // Drop this operation. rewriter.eraseOp(op); @@ -865,7 +864,7 @@ struct TestCreateBlock : public RewritePattern { Type i32Type = rewriter.getIntegerType(32); Location loc = op->getLoc(); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); - rewriter.create<TerminatorOp>(loc); + TerminatorOp::create(rewriter, loc); rewriter.eraseOp(op); return success(); } @@ -884,8 +883,8 @@ struct TestCreateIllegalBlock : public RewritePattern { Location loc = op->getLoc(); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); // Create an illegal op to ensure the conversion fails. - rewriter.create<ILLegalOpF>(loc, i32Type); - rewriter.create<TerminatorOp>(loc); + ILLegalOpF::create(rewriter, loc, i32Type); + TerminatorOp::create(rewriter, loc); rewriter.eraseOp(op); return success(); } @@ -940,7 +939,7 @@ struct TestUndoBlockErase : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { Block *secondBlock = &*std::next(op->getRegion(0).begin()); rewriter.setInsertionPointToStart(secondBlock); - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type()); rewriter.eraseBlock(secondBlock); rewriter.modifyOpInPlace(op, [] {}); return success(); @@ -1008,9 +1007,8 @@ struct TestPassthroughInvalidOp : public ConversionPattern { // This is a 1:N replacement. Insert a test.cast op. (That's what the // argument materialization used to do.) flattened.push_back( - rewriter - .create<TestCastOp>(op->getLoc(), - op->getOperand(it.index()).getType(), range) + TestCastOp::create(rewriter, op->getLoc(), + op->getOperand(it.index()).getType(), range) .getResult()); } rewriter.replaceOpWithNewOp<TestValidOp>(op, TypeRange(), flattened, @@ -1115,8 +1113,8 @@ struct TestNonRootReplacement : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { auto resultType = *op->result_type_begin(); - auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); - auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); + auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType); + auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType); rewriter.replaceOp(illegalOp, legalOp); rewriter.replaceOp(op, illegalOp); @@ -1182,7 +1180,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { LogicalResult matchAndRewrite(ILLegalOpG op, PatternRewriter &rewriter) const final { IntegerAttr attr = rewriter.getI32IntegerAttr(0); - Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); + Value val = arith::ConstantOp::create(rewriter, op->getLoc(), attr); rewriter.replaceOpWithNewOp<LegalOpC>(op, val); return success(); }; @@ -1355,7 +1353,7 @@ struct TestTypeConverter : public TypeConverter { /// 1->N type mappings. static Value materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); } }; @@ -1363,6 +1361,10 @@ struct TestLegalizePatternDriver : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) + TestLegalizePatternDriver() = default; + TestLegalizePatternDriver(const TestLegalizePatternDriver &other) + : PassWrapper(other) {} + StringRef getArgument() const final { return "test-legalize-patterns"; } StringRef getDescription() const final { return "Run test dialect legalization patterns"; @@ -1370,8 +1372,6 @@ struct TestLegalizePatternDriver /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; - TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} - void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<func::FuncDialect, test::TestDialect>(); } @@ -1500,24 +1500,19 @@ struct TestLegalizePatternDriver op->emitRemark() << "op '" << op->getName() << "' is legalizable"; } - /// The mode of conversion to use. - ConversionMode mode; + Option<ConversionMode> mode{ + *this, "test-legalize-mode", + llvm::cl::desc("The legalization mode to use with the test driver"), + llvm::cl::init(ConversionMode::Partial), + llvm::cl::values( + clEnumValN(ConversionMode::Analysis, "analysis", + "Perform an analysis conversion"), + clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), + clEnumValN(ConversionMode::Partial, "partial", + "Perform a partial conversion"))}; }; } // namespace -static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> - legalizerConversionMode( - "test-legalize-mode", - llvm::cl::desc("The legalization mode to use with the test driver"), - llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), - llvm::cl::values( - clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, - "analysis", "Perform an analysis conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", - "Perform a full conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, - "partial", "Perform a partial conversion"))); - //===----------------------------------------------------------------------===// // ConversionPatternRewriter::getRemappedValue testing. This method is used // to get the remapped value of an original value that was replaced using @@ -1917,15 +1912,15 @@ struct TestTypeConversionDriver // Allow casting from F64 back to F32. if (!resultType.isF16() && inputs.size() == 1 && inputs[0].getType().isF64()) - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); // Allow producing an i32 or i64 from nothing. if ((resultType.isInteger(32) || resultType.isInteger(64)) && inputs.empty()) - return builder.create<TestTypeProducerOp>(loc, resultType); + return TestTypeProducerOp::create(builder, loc, resultType); // Allow producing an i64 from an integer. if (isa<IntegerType>(resultType) && inputs.size() == 1 && isa<IntegerType>(inputs[0].getType())) - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); // Otherwise, fail. return nullptr; }); @@ -2008,7 +2003,7 @@ struct TestTargetMaterializationWithNoUses }); converter.addTargetMaterialization( [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - return builder.create<TestCastOp>(loc, type, inputs).getResult(); + return TestCastOp::create(builder, loc, type, inputs).getResult(); }); ConversionTarget target(getContext()); @@ -2059,7 +2054,7 @@ struct TestUndoBlocksMerge : public ConversionPattern { Operation *branchOp = firstBlock.getTerminator(); Block *secondBlock = &*(std::next(op->getRegion(0).begin())); rewriter.setInsertionPointToStart(secondBlock); - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type()); auto succOperands = branchOp->getOperands(); SmallVector<Value, 2> replacements(succOperands); rewriter.eraseOp(branchOp); @@ -2203,9 +2198,7 @@ void registerPatternsTestPass() { PassRegistration<TestStrictPatternDriver>(); PassRegistration<TestWalkPatternDriver>(); - PassRegistration<TestLegalizePatternDriver>([] { - return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); - }); + PassRegistration<TestLegalizePatternDriver>(); PassRegistration<TestRemappedValue>(); diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp index 103817d..7831b27 100644 --- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp @@ -68,8 +68,8 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation( if (createSymbol) { OpBuilder builder(op->getRegion(0)); - builder.create<test::SymbolOp>( - op->getLoc(), + test::SymbolOp::create( + builder, op->getLoc(), StringAttr::get(op->getContext(), "sym_from_attr"), /*sym_visibility=*/nullptr); } diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp index bda614a..9550e4c 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp @@ -47,9 +47,9 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> { op, op->getResultTypes().front()); rewriter.setInsertionPointAfter(bar); - rewriter.create<test_irdl_to_cpp::HashOp>( - bar.getLoc(), rewriter.getIntegerType(32), adaptor.getLhs(), - adaptor.getRhs()); + test_irdl_to_cpp::HashOp::create(rewriter, bar.getLoc(), + rewriter.getIntegerType(32), + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index 3389a1c..6457487 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -87,9 +87,9 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op, return failure(); auto newConstOp = - rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems); - auto newNegateOp = rewriter.create<tosa::NegateOp>( - op->getLoc(), dstQConstType, newConstOp.getResult()); + tosa::ConstOp::create(rewriter, op->getLoc(), dstQConstType, inputElems); + auto newNegateOp = tosa::NegateOp::create( + rewriter, op->getLoc(), dstQConstType, newConstOp.getResult()); rewriter.replaceOp(op, {newNegateOp.getResult()}); return success(); @@ -145,8 +145,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, auto newTosaConv2DOpType = RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); - auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>( - op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), + auto newTosaConv2DOp = tosa::Conv2DOp::create( + rewriter, op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); @@ -178,8 +178,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, newTosaConv2DOp.getResult().getType().isUnsignedInteger(); bool outputUnsigned = outputType.isUnsignedInteger(); - auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>( - op->getLoc(), outputType, newTosaConv2DOp.getResult(), + auto newTosaRescaleOp = tosa::RescaleOp::create( + rewriter, op->getLoc(), outputType, newTosaConv2DOp.getResult(), getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}), getConstTensorInt<int8_t>(rewriter, op->getLoc(), {static_cast<int8_t>(shift)}), diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index cdf44c2..97fc699 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -796,8 +796,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( transform::TransformState &state) { // Provide some IR that does not verify. rewriter.setInsertionPointToStart(&target->getRegion(0).front()); - rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), - ValueRange(), /*failToVerify=*/true); + TestDummyPayloadOp::create(rewriter, target->getLoc(), TypeRange(), + ValueRange(), /*failToVerify=*/true); return DiagnosedSilenceableFailure::success(); } @@ -877,7 +877,8 @@ public: Location loc) -> Value { if (inputs.size() != 1) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addSourceMaterialization(unrealizedCastConverter); diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a7285ab..f89c944 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -546,8 +546,8 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, auto ip = builder.saveInsertionPoint(); builder.setInsertionPoint(moduleOp); - auto global = builder.create<memref::GlobalOp>( - loc, + auto global = memref::GlobalOp::create( + builder, loc, /*sym_name=*/symbolName, /*sym_visibility=*/builder.getStringAttr("private"), /*type=*/memrefType, @@ -560,19 +560,18 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, global->moveBefore(&moduleOp.front()); builder.restoreInsertionPoint(ip); - return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); + return memref::GetGlobalOp::create(builder, loc, memrefType, symbolName); } static Value warpReduction(Location loc, OpBuilder &builder, Value input, CombiningKind kind, uint32_t size) { // First reduce on a single thread to get per lane reduction value. - Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); + Value laneVal = vector::ReductionOp::create(builder, loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { - Value shuffled = builder - .create<gpu::ShuffleOp>(loc, laneVal, i, - /*width=*/size, - /*mode=*/gpu::ShuffleMode::XOR) + Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } @@ -647,12 +646,11 @@ struct TestVectorDistribution "unsupported shuffle type"); Type i32Type = builder.getIntegerType(32); Value srcIdxI32 = - builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); - Value warpSzI32 = builder.create<arith::ConstantOp>( - loc, builder.getIntegerAttr(i32Type, warpSz)); - Value result = builder - .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, - gpu::ShuffleMode::IDX) + arith::IndexCastOp::create(builder, loc, i32Type, srcIdx); + Value warpSzI32 = arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(i32Type, warpSz)); + Value result = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32, + warpSzI32, gpu::ShuffleMode::IDX) .getResult(0); return result; }; @@ -680,7 +678,7 @@ struct TestVectorDistribution options.warpAllocationFn = allocateGlobalSharedMemory; options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, gpu::WarpExecuteOnLane0Op warpOp) { - builder.create<gpu::BarrierOp>(loc); + gpu::BarrierOp::create(builder, loc); }; // Test on one pattern in isolation. if (warpOpToSCF) { diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index f71fcf7..c6245b6 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -20,8 +20,6 @@ using namespace mlir::xegpu; namespace { #define DEBUG_TYPE "test-xegpu-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") struct TestXeGPUUnrollingPatterns : public PassWrapper<TestXeGPUUnrollingPatterns, diff --git a/mlir/test/lib/IR/TestPrintInvalid.cpp b/mlir/test/lib/IR/TestPrintInvalid.cpp index 8697918..25d1b19 100644 --- a/mlir/test/lib/IR/TestPrintInvalid.cpp +++ b/mlir/test/lib/IR/TestPrintInvalid.cpp @@ -34,13 +34,14 @@ struct TestPrintInvalidPass void runOnOperation() override { Location loc = getOperation().getLoc(); OpBuilder builder(getOperation().getBodyRegion()); - auto funcOp = builder.create<func::FuncOp>( - loc, "test", FunctionType::get(getOperation().getContext(), {}, {})); + auto funcOp = func::FuncOp::create( + builder, loc, "test", + FunctionType::get(getOperation().getContext(), {}, {})); funcOp.addEntryBlock(); // The created function is invalid because there is no return op. llvm::outs() << "Invalid operation:\n" << funcOp << "\n"; builder.setInsertionPointToEnd(&funcOp.getBody().front()); - builder.create<func::ReturnOp>(loc); + func::ReturnOp::create(builder, loc); // Now this function is valid. llvm::outs() << "Valid operation:\n" << funcOp << "\n"; funcOp.erase(); diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp index 92fd6de..5a5ac45 100644 --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -30,8 +30,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op, OpBuilder builder(parentFuncOp); Location loc = op->getLoc(); std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); - func::FuncOp clonedFuncOp = builder.create<func::FuncOp>( - loc, clonedFuncOpName, parentFuncOp.getFunctionType()); + func::FuncOp clonedFuncOp = func::FuncOp::create( + builder, loc, clonedFuncOpName, parentFuncOp.getFunctionType()); IRMapping mapper; builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); for (const auto &arg : enumerate(parentFuncOp.getArguments())) @@ -46,7 +46,7 @@ static LogicalResult createBackwardSliceFunction(Operation *op, (void)result; for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); - builder.create<func::ReturnOp>(loc); + func::ReturnOp::create(builder, loc); return success(); } diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 7afe210..25c8e53 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -217,8 +217,8 @@ struct TestInvalidParentPass void runOnOperation() final { FunctionOpInterface op = getOperation(); OpBuilder b(op.getFunctionBody()); - b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func", - ValueRange()); + test::TestCallOp::create(b, op.getLoc(), TypeRange(), "some_unknown_func", + ValueRange()); } }; diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp index 8278937..dc0538e 100644 --- a/mlir/test/lib/Transforms/TestDialectConversion.cpp +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -45,7 +45,7 @@ struct PDLLTypeConverter : public TypeConverter { /// Hook for materializing a conversion. static Value materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); } }; diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp index c518f3f..2888c3c 100644 --- a/mlir/test/lib/Transforms/TestInliningCallback.cpp +++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp @@ -53,8 +53,8 @@ struct InlinerCallback mlir::Operation &call = inlineBlock->back(); builder.setInsertionPointAfter(&call); - auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>( - call.getLoc(), call.getResultTypes()); + auto executeRegionOp = mlir::scf::ExecuteRegionOp::create( + builder, call.getLoc(), call.getResultTypes()); mlir::Region ®ion = executeRegionOp.getRegion(); // Move the inlined blocks into the region @@ -70,8 +70,8 @@ struct InlinerCallback if (test::TestReturnOp returnOp = llvm::dyn_cast<test::TestReturnOp>(&op)) { mlir::OpBuilder returnBuilder(returnOp); - returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(), - returnOp.getOperands()); + mlir::scf::YieldOp::create(returnBuilder, returnOp.getLoc(), + returnOp.getOperands()); returnOp.erase(); } } @@ -79,8 +79,8 @@ struct InlinerCallback // Add test.return after scf.execute_region builder.setInsertionPointAfter(executeRegionOp); - builder.create<test::TestReturnOp>(executeRegionOp.getLoc(), - executeRegionOp.getResults()); + test::TestReturnOp::create(builder, executeRegionOp.getLoc(), + executeRegionOp.getResults()); } void runOnOperation() override { diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp index 4e0213c..c1fb706 100644 --- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp +++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp @@ -28,7 +28,7 @@ makeIsolatedFromAboveImpl(RewriterBase &rewriter, SmallVector<Value> operands = regionOp.getOperands(); operands.append(capturedValues); auto isolatedRegionOp = - rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands); + test::IsolatedOneRegionOp::create(rewriter, regionOp.getLoc(), operands); rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(), isolatedRegionOp.getRegion().begin()); rewriter.eraseOp(regionOp); diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index 9a5632b..ff5838d 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -74,8 +74,8 @@ transform::TestMakeComposedFoldedAffineApply::applyToOne( if (auto v = dyn_cast<Value>(ofr)) { result = v; } else { - result = rewriter.create<arith::ConstantIndexOp>( - loc, getConstantIntValue(ofr).value()); + result = arith::ConstantIndexOp::create(rewriter, loc, + getConstantIntValue(ofr).value()); } results.push_back(result.getDefiningOp()); rewriter.replaceOp(affineApplyOp, result); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 233fef8..feaf5fb 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -343,7 +343,6 @@ if config.enable_assertions: else: config.available_features.add("noasserts") - def have_host_jit_feature_support(feature_name): mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 132aabe..b1185e1 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -5,6 +5,7 @@ import sys config.target_triple = "@LLVM_TARGET_TRIPLE@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@") +config.spirv_tools_tests = @LLVM_INCLUDE_SPIRV_TOOLS_TESTS@ config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@")) config.python_executable = "@Python3_EXECUTABLE@" @@ -41,7 +42,7 @@ config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@ # This is a workaround for the fact that LIT's: # %if <cond> -# requires <cond> to be in the set of available features. +# requires <cond> to be in the set of available features. # TODO: Update LIT's TestRunner so that this is not required. if config.mlir_run_arm_sve_tests: config.available_features.add("mlir_arm_sve_tests") diff --git a/mlir/test/mlir-runner/simple.mlir b/mlir/test/mlir-runner/simple.mlir index 1a03b99..21dabdd 100644 --- a/mlir/test/mlir-runner/simple.mlir +++ b/mlir/test/mlir-runner/simple.mlir @@ -15,10 +15,10 @@ // RUN: ls %t.o // RUN: rm %t.o -// RUN: mlir-runner %s -dump-object-file -object-filename=%T/test.o \ +// RUN: mlir-runner %s -dump-object-file -object-filename=%t.o \ // RUN: %if target={{s390x-.*}} %{ -argext-abi-check=false %} | FileCheck %s -// RUN: ls %T/test.o -// RUN: rm %T/test.o +// RUN: ls %t.o +// RUN: rm %t.o // Declarations of C library functions. llvm.func @logbf(f32) -> f32 diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index d47411d..a809611 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate<CompoundAAttrStorage>()) // DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); +// DEF: CompoundAAttr CompoundAAttr::getChecked( +// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner +// DEF-SAME: ) +// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); + // DEF: ::mlir::Type CompoundAAttr::getInner() const { // DEF-NEXT: return getImpl()->inner; } diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td index 7cd24aa..af09ee7 100644 --- a/mlir/test/mlir-tblgen/op-properties-predicates.td +++ b/mlir/test/mlir-tblgen/op-properties-predicates.td @@ -70,6 +70,12 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> { // CHECK-NEXT: if (!(((!prop.has_value())) || ((::llvm::all_of((*(prop)), [](const int64_t& baseStore) -> bool { return [](int64_t baseIface) -> bool { return ((baseIface >= 0)); }(baseStore); })) && (!(((*(prop)).empty())))))) // CHECK: failed to satisfy constraint: optional non-empty array of non-negative int64_ +// CHECK-LABEL: ::llvm::LogicalResult OpWithPredicatesAdaptor::verify +// Note: comprehensive emission of verifiers is tested in verifyINvariantsImpl() below +// CHECK: int64_t tblgen_scalar = this->getScalar(); +// CHECK: if (!((tblgen_scalar >= 0))) +// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t"); + // CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl() // Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused // CHECK: [[maybe_unused:\[\[maybe_unused\]\]]] int64_t tblgen_scalar = this->getScalar(); diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td index 40af548..23ab24e 100644 --- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -44,7 +44,7 @@ def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; // CHECK: test::AOp::Properties tblgen_props; // CHECK: tblgen_values.push_back((*x.getODSResults(0).begin())); // CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y); -// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props); +// CHECK: tblgen_AOp_0 = test::AOp::create(rewriter, odsLoc, tblgen_types, tblgen_values, tblgen_props); // Note: These use strings to pick up a non-trivial storage/interface type // difference. diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td index 0a94746..9bb6103 100644 --- a/mlir/test/mlir-tblgen/rewriter-indexing.td +++ b/mlir/test/mlir-tblgen/rewriter-indexing.td @@ -55,7 +55,7 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)), // We expect ODSOperand 0 here, the attribute before the operand in BOp // definition shouldn't shift the counter. // CHECK: op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); -// CHECK: rewriter.create<test::BOp>((*a.getODSResults(0).begin()).getLoc() +// CHECK: test::BOp::create(rewriter, (*a.getODSResults(0).begin()).getLoc() def test3 : Pat<(BOp $attr, (AOp:$a $input)), (BOp $attr, (AOp $input), (location $a))>; diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index ef1d835..66f7ec8 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -31,6 +31,7 @@ def testGetDenseElementsUnsupported(): # CHECK: unimplemented array format conversion from format: print(e) + # CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided @run def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided(): @@ -41,8 +42,9 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided(): # realistic example would be a NumPy extension type like the bfloat16 # type from the ml_dtypes package, which isn't a dependency of this # test. - attr = DenseElementsAttr.get(array.view(np.datetime64), - type=IntegerType.get_signless(64)) + attr = DenseElementsAttr.get( + array.view(np.datetime64), type=IntegerType.get_signless(64) + ) # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> print(attr) # CHECK: {{\[}}[1 2 3] @@ -135,6 +137,7 @@ def testGetDenseElementsFromListMixedTypes(): # Splats. ################################################################################ + # CHECK-LABEL: TEST: testGetDenseElementsSplatInt @run def testGetDenseElementsSplatInt(): @@ -617,3 +620,18 @@ def testGetDenseResourceElementsAttr(): # CHECK: BACKING MEMORY DELETED # CHECK: EXIT FUNCTION print("EXIT FUNCTION") + + +# CHECK-LABEL: TEST: testDanglingResource +print("TEST: testDanglingResource") +# see https://github.com/llvm/llvm-project/pull/149414, https://github.com/llvm/llvm-project/pull/150137, https://github.com/llvm/llvm-project/pull/150561 +# This error occurs only when there is an alive context with a DenseResourceElementsAttr +# in the end of the program, so we put it here without an encapsulating function. +ctx = Context() + +with ctx, Location.unknown(): + DenseResourceElementsAttr.get_from_buffer( + memoryview(np.array([1, 2, 3])), + "some_resource", + RankedTensorType.get((3,), IntegerType.get_signed(32)), + ) |