diff options
Diffstat (limited to 'mlir/test')
18 files changed, 520 insertions, 115 deletions
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir new file mode 100644 index 0000000..bcbdef0 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @wmma_k4 +func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) { + // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1 + amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32> + func.return +} + +// CHECK-LABEL: @wmma_k32 +func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>, + %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) { + // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2 + amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) + amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + + // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) + amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16> + + func.return +} + +// CHECK-LABEL: @wmma_k64 +func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>, + %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) { + // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}} + amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32> + + // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16> + + func.return +} + +// CHECK-LABEL: @wmma_k128 +func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, + %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) { + // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16> + + func.return +} diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index dec62f9..7a82236 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -211,11 +211,25 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> { } // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32 // CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32 -// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32 +// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] : f32 +// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] : f32 +// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32 // CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32 -// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32 +// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] : f32 +// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32 +// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] : f32 +// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32 +// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32 +// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] : f32 +// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] : f32 +// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32 +// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32> // CHECK: return %[[RESULT]] : complex<f32> @@ -832,11 +846,25 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> { } // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32 // CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32 -// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32 // CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32 -// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] fastmath<nnan,contract> : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32 +// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] fastmath<nnan,contract> : f32 +// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32 +// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32> // CHECK: return %[[RESULT]] : complex<f32> diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 5784764..4c6f62a 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector // ----- -func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} - %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> - func.return %0 : vector<8xi32> -} - -// ----- - func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> @@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec // ----- func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}} %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> } // ----- -// Missinng `resetOffset` +func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +// Missing `resetOffset` func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> { // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}} %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index a330967..09134cb 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -586,6 +586,41 @@ func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> func.return %0 : vector<8xi32> } +// CHECK-LABEL: func @wmma_f32_16x16x4_f32 +func.func @wmma_f32_16x16x4_f32(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x4 + %0 = amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f32_16x16x64_f8 +func.func @wmma_f32_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f32_16x16x64_bf8 +func.func @wmma_f32_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f16_16x16x64_bf8 +func.func @wmma_f16_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf16>) -> vector<8xf16> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> + func.return %0 : vector<8xf16> +} + +// CHECK-LABEL: func @wmma_f16_16x16x64_f8 +func.func @wmma_f16_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf16>) -> vector<8xf16> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> + func.return %0 : vector<8xf16> +} + // CHECK-LABEL: func @swizzle_bitmode func.func @swizzle_bitmode(%arg0 : f32) -> f32 { // CHECK: amdgpu.swizzle_bitmode diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir index 40a57b9..e8bb0c0 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir @@ -156,3 +156,24 @@ func.func @manual_deallocation(%c: i1, %f: f32, %idx: index) -> f32 { // CHECK: cf.assert %[[true]], "expected that the block does not have ownership" // CHECK: memref.dealloc %[[manual_alloc]] // CHECK: bufferization.dealloc (%[[managed_alloc]] : memref<5xf32>) if (%[[true]]) + +// ----- + +// CHECK-LABEL: func.func private @properly_creates_deallocations_in_execute_region( +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: scf.execute_region no_inline { +// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x63x378x16xui8> +// CHECK: bufferization.dealloc (%[[alloc]] : memref<1x63x378x16xui8>) if (%[[true]]) + +func.func private @properly_creates_deallocations_in_execute_region(%arg1: memref<1x16x252x380xui8> ) -> (memref<1x250x378x16xui8> ) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x250x378x16xui8> + scf.execute_region no_inline { + %subview = memref.subview %arg1[0, 0, 0, 0] [1, 16, 65, 380] [1, 1, 1, 1] : memref<1x16x252x380xui8> to memref<1x16x65x380xui8, strided<[1532160, 95760, 380, 1]>> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<1x63x378x16xui8> + test.buffer_based in(%subview: memref<1x16x65x380xui8, strided<[1532160, 95760, 380, 1]>>) out(%alloc_3: memref<1x63x378x16xui8>) + %subview_7 = memref.subview %alloc[0, 0, 0, 0] [1, 63, 378, 16] [1, 1, 1, 1] : memref<1x250x378x16xui8> to memref<1x63x378x16xui8, strided<[1512000, 6048, 16, 1]>> + test.copy(%alloc_3, %subview_7) : (memref<1x63x378x16xui8>, memref<1x63x378x16xui8, strided<[1512000, 6048, 16, 1]>>) + scf.yield + } + return %alloc : memref<1x250x378x16xui8> +} diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index d5f834b..8db1ebb 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -381,15 +381,19 @@ func.func private @execute_region_test(%t1 : tensor<?xf32>) // ----- // CHECK-LABEL: func @no_inline_execute_region_not_canonicalized -func.func @no_inline_execute_region_not_canonicalized() { - %c = arith.constant 42 : i32 - // CHECK: scf.execute_region - // CHECK-SAME: no_inline - %v = scf.execute_region -> i32 no_inline { - scf.yield %c : i32 +module { + func.func private @foo()->() + func.func @no_inline_execute_region_not_canonicalized() { + %c = arith.constant 42 : i32 + // CHECK: scf.execute_region + // CHECK-SAME: no_inline + %v = scf.execute_region -> i32 no_inline { + func.call @foo():()->() + scf.yield %c : i32 + } + // CHECK: return + return } - // CHECK: return - return } // ----- diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index d270ee8..e703600 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL @rocdl.tensor.load.to.lds +llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.store.from.lds +llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.load.to.lds.d2 +llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.store.from.lds.d2 +llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i64, diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir index f9b81df..d0aec68 100644 --- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir @@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) { // ----- +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func @memref_collapse( +// CHECK-SAME: %[[sz0:.*]]: index +// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index +// CHECK: %[[dim:.*]] = memref.dim %{{.*}}, %[[c2]] : memref<3x4x?x2xf32> +// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]] +// CHECK: return %[[c12]], %[[mul]] +func.func @memref_collapse(%sz0: index) -> (index, index) { + %0 = memref.alloc(%sz0) : memref<3x4x?x2xf32> + %1 = memref.collapse_shape %0 [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32> + %2 = "test.reify_bound"(%1) {dim = 0} : (memref<12x?xf32>) -> (index) + %3 = "test.reify_bound"(%1) {dim = 1} : (memref<12x?xf32>) -> (index) + return %2, %3 : index, index +} + +// ----- + // CHECK-LABEL: func @memref_get_global( // CHECK: %[[c4:.*]] = arith.constant 4 : index // CHECK: return %[[c4]] diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 77d18da..042ee25 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -2243,3 +2243,76 @@ func.func @test_firstprivate_map(%arg0: memref<10xf32>) { // CHECK-NEXT: acc.yield // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +func.func @test_kernel_environment(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Create data clause operands for the kernel environment + %copyin = acc.copyin varPtr(%arg0 : memref<1024xf32>) -> memref<1024xf32> + %create = acc.create varPtr(%arg1 : memref<1024xf32>) -> memref<1024xf32> + + // Kernel environment wraps gpu.launch and captures data mapping + acc.kernel_environment dataOperands(%copyin, %create : memref<1024xf32>, memref<1024xf32>) { + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) { + // Kernel body uses the mapped data + %val = memref.load %copyin[%tx] : memref<1024xf32> + %result = arith.mulf %val, %val : f32 + memref.store %result, %create[%tx] : memref<1024xf32> + gpu.terminator + } + } + + // Copy results back to host and deallocate device memory + acc.copyout accPtr(%create : memref<1024xf32>) to varPtr(%arg1 : memref<1024xf32>) + acc.delete accPtr(%copyin : memref<1024xf32>) + + return +} + +// CHECK-LABEL: func @test_kernel_environment +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32> +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32> +// CHECK: acc.kernel_environment dataOperands(%[[COPYIN]], %[[CREATE]] : memref<1024xf32>, memref<1024xf32>) { +// CHECK: gpu.launch +// CHECK: memref.load %[[COPYIN]] +// CHECK: memref.store %{{.*}}, %[[CREATE]] +// CHECK: } +// CHECK: } +// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) to varPtr(%{{.*}} : memref<1024xf32>) +// CHECK: acc.delete accPtr(%[[COPYIN]] : memref<1024xf32>) + +// ----- + +func.func @test_kernel_environment_with_async(%arg0: memref<1024xf32>) { + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %async_val = arith.constant 1 : i32 + + %create = acc.create varPtr(%arg0 : memref<1024xf32>) async(%async_val : i32) -> memref<1024xf32> + + // Kernel environment with async clause + acc.kernel_environment dataOperands(%create : memref<1024xf32>) async(%async_val : i32) { + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) { + %f0 = arith.constant 0.0 : f32 + memref.store %f0, %create[%tx] : memref<1024xf32> + gpu.terminator + } + } + + acc.copyout accPtr(%create : memref<1024xf32>) async(%async_val : i32) to varPtr(%arg0 : memref<1024xf32>) + + return +} + +// CHECK-LABEL: func @test_kernel_environment_with_async +// CHECK: %[[ASYNC:.*]] = arith.constant 1 : i32 +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) async(%[[ASYNC]] : i32) -> memref<1024xf32> +// CHECK: acc.kernel_environment dataOperands(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) +// CHECK: gpu.launch +// CHECK: memref.store %{{.*}}, %[[CREATE]] +// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) to varPtr(%{{.*}} : memref<1024xf32>) diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 37fc86b..3f481ad 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -373,7 +373,7 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) { func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32) { - // expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}} + // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}} %x, %y = scf.if %arg0 -> (f32, f32) { %0 = arith.addf %arg1, %arg1 : f32 scf.yield %0 : f32 @@ -544,7 +544,7 @@ func.func @while_invalid_terminator() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}} + // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}} scf.while : () -> () { scf.condition(%true) } do { @@ -557,7 +557,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}} + // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}} %0 = scf.while : () -> (i1) { scf.condition(%true) %true : i1 } do { @@ -570,7 +570,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_result_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}} + // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}} scf.while : () -> () { scf.condition(%true) %true : i1 } do { diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 1bcef0a..ea587e9 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -49,6 +49,11 @@ func.func @conj(%arg: complex<f32>) -> complex<f32> { func.return %conj : complex<f32> } +func.func @exp(%arg: complex<f32>) -> complex<f32> { + %exp = complex.exp %arg : complex<f32> + func.return %exp : complex<f32> +} + // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...] func.func @test_binary(%input: tensor<?xcomplex<f32>>, %func: (complex<f32>, complex<f32>) -> complex<f32>) { @@ -353,5 +358,32 @@ func.func @entry() { call @test_element_f64(%abs_test_cast, %abs_func) : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> () + // complex.exp test + %exp_test = arith.constant dense<[ + (1.0, 2.0), + // CHECK: -1.1312 + // CHECK-NEXT: 2.4717 + + // The first case to consider is overflow of exp(real_part). If computed + // directly, this yields inf * 0 = NaN, which is incorrect. + (500.0, 0.0), + // CHECK-NEXT: inf + // CHECK-NOT: nan + // CHECK-NEXT: 0 + + // In this case, the overflow of exp(real_part) is compensated when + // sin(imag_part) is close to zero, yielding a finite imaginary part. + (90.0238094, 5.900613e-39) + // CHECK-NEXT: inf + // CHECK-NOT: inf + // CHECK-NEXT: 7.3746 + ]> : tensor<3xcomplex<f32>> + %exp_test_cast = tensor.cast %exp_test + : tensor<3xcomplex<f32>> to tensor<?xcomplex<f32>> + + %exp_func = func.constant @exp : (complex<f32>) -> complex<f32> + call @test_unary(%exp_test_cast, %exp_func) + : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> () + func.return } diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir index 473ac05..94b6628 100644 --- a/mlir/test/Target/LLVMIR/ptr.mlir +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -284,8 +284,8 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> { // CHECK-LABEL: define i64 @ptr_diff_scalar // CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64 // CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] // CHECK-NEXT: ret i64 %[[DIFF]] // CHECK-NEXT: } @@ -296,8 +296,8 @@ llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr. // CHECK-LABEL: define i32 @ptr_diff_scalar_i32 // CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64 // CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] // CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32 // CHECK-NEXT: ret i32 %[[TRUNC]] @@ -309,8 +309,8 @@ llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: ! // CHECK-LABEL: define <4 x i64> @ptr_diff_vector // CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64> -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64> +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr <4 x ptr> %[[PTRS1]] to <4 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr <4 x ptr> %[[PTRS2]] to <4 x i64> // CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]] // CHECK-NEXT: ret <4 x i64> %[[DIFF]] // CHECK-NEXT: } @@ -321,8 +321,8 @@ llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, % // CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32 // CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64> -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64> +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr <8 x ptr> %[[PTRS1]] to <8 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr <8 x ptr> %[[PTRS2]] to <8 x i64> // CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]] // CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32> // CHECK-NEXT: ret <8 x i32> %[[TRUNC]] @@ -344,8 +344,8 @@ llvm.func @ptr_diff_with_constants() -> i64 { // CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw // CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64 // CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]] // CHECK-NEXT: ret i64 %[[DIFF]] // CHECK-NEXT: } @@ -356,8 +356,8 @@ llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr // CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw // CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64 // CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]] // CHECK-NEXT: ret i64 %[[DIFF]] // CHECK-NEXT: } @@ -368,8 +368,8 @@ llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr // CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw // CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { -// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 -// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64 // CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]] // CHECK-NEXT: ret i64 %[[DIFF]] // CHECK-NEXT: } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 30126f6..8a848221 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: rocdl.tensor.load.to.lds +llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0) + rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.store.from.lds +llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0) + rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.load.to.lds.d2 +llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0) + rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.store.from.lds.d2 +llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0) + rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i64, diff --git a/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir new file mode 100644 index 0000000..62d15de --- /dev/null +++ b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip --verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { + spirv.func @cache_controls() "None" { + // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function> + %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function> + // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function> + %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function> + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}} + %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function> + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} + %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function> + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} + %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function> + spirv.Return + } +} diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir index 90ba690e..712fd17 100644 --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -1,27 +1,32 @@ -// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s +// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: location = 0 : i32 spirv.GlobalVariable @var {location = 0 : i32} : !spirv.ptr<vector<4xf32>, Input> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: no_perspective spirv.GlobalVariable @var {no_perspective} : !spirv.ptr<vector<4xf32>, Input> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: flat spirv.GlobalVariable @var {flat} : !spirv.ptr<si32, Input> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> { // CHECK: aliased // CHECK: aliased spirv.GlobalVariable @var1 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer> @@ -30,28 +35,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> { // CHECK: non_readable spirv.GlobalVariable @var bind(0, 0) {non_readable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> { // CHECK: non_writable spirv.GlobalVariable @var bind(0, 0) {non_writable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> { // CHECK: restrict spirv.GlobalVariable @var bind(0, 0) {restrict} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer> } // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: relaxed_precision spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output> } @@ -84,7 +89,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> { +spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], [SPV_KHR_no_integer_wrap_decoration]> { spirv.func @iadd_decorations(%arg: i32) -> i32 "None" { // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32 @@ -94,7 +99,7 @@ spirv.func @iadd_decorations(%arg: i32) -> i32 "None" { // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> { +spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], []> { spirv.func @fadd_decorations(%arg: f32) -> f32 "None" { // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32 @@ -104,7 +109,7 @@ spirv.func @fadd_decorations(%arg: f32) -> f32 "None" { // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { spirv.func @fmul_decorations(%arg: f32) -> f32 "None" { // CHECK: spirv.FMul %{{.*}}, %{{.*}} {no_contraction} %0 = spirv.FMul %arg, %arg {no_contraction} : f32 @@ -114,7 +119,7 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" { // ----- -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> { +spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Float16], []> { spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" { // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16 %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16 @@ -124,51 +129,7 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" { // ----- -// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { - -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { - spirv.func @cache_controls() "None" { - // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function> - %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function> - // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function> - %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function> - spirv.Return - } -} - -// ----- - -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { - spirv.func @cache_controls_invalid_type() "None" { - // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}} - %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function> - spirv.Return - } -} - -// ----- - -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { - spirv.func @cache_controls_invalid_type() "None" { - // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} - %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function> - spirv.Return - } -} - -// ----- - -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> { - spirv.func @cache_controls_invalid_type() "None" { - // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} - %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function> - spirv.Return - } -} - -// ----- - -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: spirv.func @relaxed_precision_arg({{%.*}}: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) "None" attributes {relaxed_precision} { spirv.func @relaxed_precision_arg(%arg0: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) -> () "None" attributes {relaxed_precision} { spirv.Return diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index eb0d980..7a7a583 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -66,7 +66,7 @@ public: void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, + RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) override; @@ -240,7 +240,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer( void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { + RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) { LDBG() << "visitRegionBranchControlFlowTransfer: " << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region"); diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b211e24..4d4ec02 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, + successor.getSuccessor()) && "invalid region index"); return getOperands(); } @@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // We always branch to the join region. if (!point.isParent()) { - if (point != getJoinRegion()) + if (point.getTerminatorPredecessorOrNull()->getParentRegion() != + &getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } void AnyCondOp::getRegionInvocationBounds( @@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions( if (point.isParent()) return; - regions.emplace_back((*this)->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody()); return MutableOperandRange(getInitMutable()); } @@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { //===----------------------------------------------------------------------===// MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { - if (point.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { + if (successor.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions( if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// @@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 05a33cf..a3430ba 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, - DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]> + DeclareOpInterfaceMethods<RegionBranchOpInterface> ]> { let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions); |
