Diffstat (limited to 'mlir/test')
344 files changed, 21909 insertions, 2521 deletions
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir index 3748be7..768f1cf 100644 --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -184,6 +184,18 @@ func.func private @private0(%0 : i32) -> i32 { // CHECK-NEXT: result #0: live // CHECK-LABEL: test_tag: y: // CHECK-NEXT: result #0: not live +// CHECK-LABEL: test_tag: for: +// CHECK-NEXT: operand #0: live +// CHECK-NEXT: operand #1: live +// CHECK-NEXT: operand #2: live +// CHECK-NEXT: operand #3: live +// CHECK-NEXT: operand #4: not live +// CHECK-NEXT: result #0: live +// CHECK-NEXT: result #1: not live +// CHECK-NEXT: region: #0: +// CHECK-NEXT: argument: #0: live +// CHECK-NEXT: argument: #1: not live +// CHECK-NEXT: argument: #2: not live func.func @test_7_type_3(%arg0: memref<i32>) { %c0 = arith.constant {tag = "zero"} 0 : index %c10 = arith.constant {tag = "ten"} 10 : index @@ -194,7 +206,7 @@ func.func @test_7_type_3(%arg0: memref<i32>) { %1 = arith.addi %x, %x : i32 %2 = func.call @private0(%1) : (i32) -> i32 scf.yield %2, %arg3 : i32, i32 - } + } {tag = "for"} memref.store %0#0, %arg0[] : memref<i32> return } diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c index 4751288..4df232f 100644 --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -69,7 +69,7 @@ void testSimpleExecution(void) { mlirRegisterAllLLVMTranslations(ctx); MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Execution engine creation failed"); exit(2); @@ -125,7 +125,7 @@ void testOmpCreation(void) { // against the OpenMP library. 
MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Engine creation failed with OpenMP"); exit(2); diff --git a/mlir/test/CAPI/global_constructors.c b/mlir/test/CAPI/global_constructors.c index bd2fe14..9aacaf2 100644 --- a/mlir/test/CAPI/global_constructors.c +++ b/mlir/test/CAPI/global_constructors.c @@ -79,7 +79,7 @@ void testGlobalCtorJitCallback(void) { // Create execution engine with initialization disabled MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Execution engine creation failed"); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 2fd3df6d..432b887 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -456,3 +456,4 @@ func.func @sched_barrier() { amdgpu.sched_barrier allow = <valu|all_vmem> func.return } + diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir new file mode 100644 index 0000000..a94e17a --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir @@ -0,0 +1,445 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \ +// RUN: | FileCheck %s + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp4 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, 
vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_bf8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> 
+ // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_bf6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], 
%[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (16, 2) for f8.}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op failed to verify that all of {source, res} have same shape}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op operand #0 must be}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16> + return %ret0: vector<16xf16> +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op result #0 must be vector}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64> + return %ret0: vector<16xf64> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 
3 +#amdgpu_fat_buffer_addrspace = 7 + +func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xf32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op failed to verify that all of {global, lds} have same element type}} + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xf32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32> + return %0 : !amdgpu.tdm_base<i32> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi7, #gpu_global_addrspace>, %smem: memref<8xi7,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i7>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op element type must be 1, 2, 4, or 8 bytes long but type was 7 bits long.}} + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi7, #gpu_global_addrspace>, memref<8xi7, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i7> + return %0 : !amdgpu.tdm_base<i7> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +// CHECK-LABEL: func @make_dma_base +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>) +func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) { + // CHECK-DAG: %[[INT:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1> + // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3> + + // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1> + // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3> + + // CHECK-DAG: %[[MEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[MEM_BASE_PTR]][%[[INT]]] + // CHECK-DAG: %[[SMEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[SMEM_BASE_PTR]][%[[INT]]] + + // CHECK-DAG: %[[MEM_INT:.+]] = llvm.ptrtoint %[[MEM_BASE_OFFSET]] : !llvm.ptr<1> to i64 + // CHECK-DAG: %[[SMEM_INT:.+]] = llvm.ptrtoint %[[SMEM_BASE_OFFSET]] : !llvm.ptr<3> to i32 + + // CHECK: %[[MEM_INT_LOW:.+]] = llvm.trunc %[[MEM_INT]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) + // CHECK: %[[SHIFTED_MEM_INT:.+]] = llvm.lshr %[[MEM_INT]], %[[SHIFT]] + // CHECK: %[[MEM_INT_HIGH:.+]] = llvm.trunc %[[SHIFTED_MEM_INT]] : i64 to i32 + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32) + // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]] + + // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32) + // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32 + + // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32> + // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32] + // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32] + // CHECK: %[[V4I32_0_3:.+]] = llvm.insertelement 
%[[MEM_INT_LOW]], %[[V4I32_0_2]][%[[C2]] : i32] + // CHECK: %[[V4I32_0_4:.+]] = llvm.insertelement %[[MEM_INT_HIGH_TYPE]], %[[V4I32_0_3]][%[[C3]] : i32] + + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32> + + func.return %0 : !amdgpu.tdm_base<i32> +} + +// ----- + +// CHECK-LABEL: func @make_dma_descriptor +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]] + // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]] + + // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]] + // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]] + + // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32) + + // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]] + // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32 + + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64) + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]] + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]] + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32 
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]] + // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32] + // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32] + // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32] + // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32] + // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32] + // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32] + + // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +// CHECK-LABEL: func @make_dma_descriptor_atomic_barrier +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[BARRIER:.+]]: {{.*}}, %[[IDX:.+]]: index) +func.func @make_dma_descriptor_atomic_barrier(%base: !amdgpu.tdm_base<i32>, %barrier : memref<8xi32, #gpu_lds_addrspace>, %idx: index) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[BARRIER_MEMREF_DESC:.+]] = builtin.unrealized_conversion_cast %[[BARRIER]] + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR0_0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + + // CHECK-DAG: %[[ATOMIC_BARRIER_ENABLE_OFFSET:.+]] = llvm.mlir.constant(18 : i32) + // CHECK: %[[ATOMIC_BARRIER_ENABLE_FIELD:.+]] = llvm.shl %[[C1]], %[[ATOMIC_BARRIER_ENABLE_OFFSET]] + // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[ATOMIC_BARRIER_ENABLE_FIELD]] + + // CHECK: %[[ATOMIC_BARRIER_ALIGNED_PTR:.+]] = llvm.extractvalue %[[BARRIER_MEMREF_DESC]][1] + // CHECK: %[[ATOMIC_BARRIER_ADDR:.+]] = llvm.getelementptr %[[ATOMIC_BARRIER_ALIGNED_PTR]][%[[INDEX]] + // CHECK: %[[ATOMIC_BARRIER_I32:.+]] = llvm.ptrtoint %[[ATOMIC_BARRIER_ADDR]] : !llvm.ptr<3> to i32 + // CHECK: %[[ATOMIC_BARRIER_NO_3_LSB:.+]] = llvm.lshr %[[ATOMIC_BARRIER_I32]], %[[C3]] + // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) + // CHECK: %[[ATOMIC_BARRIER:.+]] = llvm.and 
%[[ATOMIC_BARRIER_NO_3_LSB]], %[[MASK]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1_0:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + // CHECK: %[[SGPR1:.+]] = llvm.or %[[ATOMIC_BARRIER]], %[[SGPR1_0]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] + globalStride [64, 1] + sharedSize [128, 64] + atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu_lds_addrspace>) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} + +// ----- + +// CHECK-LABEL: func @make_dma_descriptor_workgroup_mask +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1) +func.func @make_dma_descriptor_workgroup_mask(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[WG_MASK_EXT:.+]] = llvm.zext %[[WG_MASK]] + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[DATA_SIZE_SHIFTED:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + // CHECK: %[[SGPR0_BASE:.+]] = llvm.or %[[WG_MASK_EXT]], %[[DATA_SIZE_SHIFTED]] + // CHECK-DAG: %[[C21:.+]] = llvm.mlir.constant(21 : i32) + // CHECK: %[[TIMEOUT_SHIFTED:.+]] = llvm.shl %[[C1]], %[[C21]] + // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_BASE]], %[[TIMEOUT_SHIFTED]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]] + // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]] + + // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]] + // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]] + + // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32) + + // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK-DAG: %[[MASK:.+]] = 
llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]] + // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32 + + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64) + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]] + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]] + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]] + // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32] + // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32] + // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32] + // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32] + // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32] + // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32] + + // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] workgroupMask %wg_mask earlyTimeout %timeout : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir index 1016ee8..537ef59 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10 +// RUN: mlir-opt %s 
--convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12 // CHECK-LABEL: func @memory_counter_wait func.func @memory_counter_wait() { diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir new file mode 100644 index 0000000..5b29e01 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s + +// CHECK-LABEL: func @memory_counter_wait_tensor +func.func @memory_counter_wait_tensor() { + // CHECK: rocdl.s.wait.tensorcnt 3 + amdgpu.memory_counter_wait tensor(3) + + return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir new file mode 100644 index 0000000..1d2f692 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx942 +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1030 +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1100 + +func.func @memory_counter_wait_tensor() { + // expected-error @below{{failed to legalize operation 'amdgpu.memory_counter_wait'}} + // expected-error @below{{'amdgpu.memory_counter_wait' op unsupported chipset}} + amdgpu.memory_counter_wait tensor(0) + + return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir index 9fcc147..4e6aa17 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir @@ -6,30 +6,30 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>, %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 {subwordOffset = 0 : i32} : vector<16xbf16>, 
vector<16xbf16>, vector<4xf32> + // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} {opsel = true} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> - // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} {opsel = true} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16> amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> - // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> return diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir index 5788347..978227b 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -20,15 +20,15 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32> amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : 
vector<4xbf16>, vector<4xbf16>, vector<4xf32> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>) -> vector<8xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>) -> vector<4xf16> amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16> - // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>) -> vector<4xi16> amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16> // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> @@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, 
vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32> func.return diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir index 5e77a3ad..37259f6 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -14,13 +14,13 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2 amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> - // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) + // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>) amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16> return @@ -29,31 +29,31 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec // CHECK-LABEL: @wmma_k64 func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>, %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) { - // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}} + // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {clamp = true, signA = true, signB = true} amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32> // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : 
(vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16> return @@ -65,25 +65,25 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16> return diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir new file mode 100644 index 0000000..bd4a9da --- /dev/null +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -0,0 +1,329 @@ +// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 + +// CHECK-LABEL: func.func @foo() -> f8E4M3FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN +// CHECK: return 
%[[CONSTANT_0]] : f8E4M3FN +// CHECK: } + +// CHECK-LABEL: func.func @bar() -> f6E3M2FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN +// CHECK: return %[[CONSTANT_0]] : f6E3M2FN +// CHECK: } + +// Illustrate that both f8E4M3FN and f6E3M2FN calling the same _mlir_apfloat_add is fine +// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width. +// CHECK-LABEL: func.func @full_example() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN +// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64 +// // fltSemantics semantics for f8E4M3FN +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 +// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN +// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN + +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN +// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN +// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64 +// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64 +// // fltSemantics semantics for f6E3M2FN +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32 +// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6 +// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN +// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN +// CHECK: return +// CHECK: } + +// Put rhs into separate function so that it won't be constant-folded. 
+func.func @foo() -> f8E4M3FN { + %cst = arith.constant 2.2 : f8E4M3FN + return %cst : f8E4M3FN +} + +func.func @bar() -> f6E3M2FN { + %cst = arith.constant 3.2 : f6E3M2FN + return %cst : f6E3M2FN +} + +func.func @full_example() { + %a = arith.constant 1.4 : f8E4M3FN + %b = func.call @foo() : () -> (f8E4M3FN) + %c = arith.addf %a, %b : f8E4M3FN + vector.print %c : f8E4M3FN + + %d = arith.constant 2.4 : f6E3M2FN + %e = func.call @bar() : () -> (f6E3M2FN) + %f = arith.addf %d, %e : f6E3M2FN + vector.print %f : f6E3M2FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// Test decl collision (different type) +// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} +func.func private @_mlir_apfloat_add(i32, i32, f32) -> index +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.subf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.mulf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.divf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.remf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @extf(%arg0: f4E2M1FN) { + %0 = arith.extf %arg0 : f4E2M1FN to f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @truncf(%arg0: bf16) { + %0 = arith.truncf %arg0 : bf16 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : 
i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptosi(%arg0: f16) { + %0 = arith.fptosi %arg0 : f16 to i4 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptoui(%arg0: f16) { + %0 = arith.fptoui %arg0 : f16 to i4 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @sitofp(%arg0: i32) { + %0 = arith.sitofp %arg0 : i32 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @uitofp(%arg0: i32) { + %0 = arith.uitofp %arg0 : i32 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8 +// CHECK: %[[c3:.*]] = arith.constant 3 : i8 +// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8 +// CHECK: %[[c0:.*]] = arith.constant 0 : i8 +// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8 +// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1 +func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64 +func.func @negf(%arg0: f32) { + %0 = arith.negf %arg0 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_minimum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_minimum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @minimumf(%arg0: f32, %arg1: f32) { + %0 = arith.minimumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_maximum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_maximum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @maximumf(%arg0: f32, %arg1: f32) { + %0 = arith.maximumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_minnum(i32, i64, i64) -> i64 
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_minnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @minnumf(%arg0: f32, %arg1: f32) { + %0 = arith.minnumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_maxnum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_maxnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @maxnumf(%arg0: f32, %arg1: f32) { + %0 = arith.maxnumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK-LABEL: func.func @unsupported_bitwidth +// CHECK: arith.addf {{.*}} : f128 +// CHECK: arith.negf {{.*}} : f128 +// CHECK: arith.cmpf {{.*}} : f128 +// CHECK: arith.extf {{.*}} : f32 to f128 +// CHECK: arith.truncf {{.*}} : f128 to f32 +// CHECK: arith.fptosi {{.*}} : f128 to i32 +// CHECK: arith.fptosi {{.*}} : f32 to i92 +// CHECK: arith.sitofp {{.*}} : i1 to f128 +// CHECK: arith.sitofp {{.*}} : i92 to f32 +func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) { + %0 = arith.addf %arg0, %arg1 : f128 + %1 = arith.negf %arg0 : f128 + %2 = arith.cmpf "ult", %arg0, %arg1 : f128 + %3 = arith.extf %arg2 : f32 to f128 + %4 = arith.truncf %arg0 : f128 to f32 + %5 = arith.fptosi %arg0 : f128 to i32 + %6 = arith.fptosi %arg2 : f32 to i92 + %7 = arith.sitofp %2 : i1 to f128 + %8 = arith.sitofp %6 : i92 to f32 + return +} + +// ----- + +// CHECK-LABEL: func.func @addf_vector +// CHECK-COUNT-2: vector.to_elements + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: vector.from_elements +func.func @addf_vector(%arg0: vector<4xf4E2M1FN>, %arg1: vector<4xf4E2M1FN>) { + %0 = arith.addf %arg0, %arg1 : vector<4xf4E2M1FN> + return +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index ba12ff2..b53c52d 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -738,6 +738,22 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { // ----- +// CHECK-LABEL: @ops_supporting_exact +func.func @ops_supporting_exact(i32, i32) { +^bb0(%arg0: i32, %arg1: i32): +// CHECK: = llvm.ashr exact %arg0, %arg1 : i32 + %0 = arith.shrsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.lshr exact %arg0, %arg1 : i32 + %1 = arith.shrui %arg0, %arg1 exact : i32 +// CHECK: = llvm.sdiv exact %arg0, %arg1 : i32 + %2 = arith.divsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.udiv exact %arg0, %arg1 : i32 + %3 = arith.divui %arg0, %arg1 exact : i32 + return +} + +// ----- + // CHECK-LABEL: func @memref_bitcast // CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>) // CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -747,3 +763,35 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> { %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16> func.return %2 : memref<?xbf16> } + +// ----- + 
+// CHECK-LABEL: func @unsupported_fp_type +// CHECK: arith.addf {{.*}} : f4E2M1FN +// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> +// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> +// CHECK: arith.cmpf {{.*}} : f4E2M1FN +// CHECK: llvm.select {{.*}} : i1, i4 +func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) { + %0 = arith.addf %arg0, %arg0 : f4E2M1FN + %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> + %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> + %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN + %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN + return +} + +// ----- + +// CHECK-LABEL: func @supported_fp_type +// CHECK: llvm.fadd {{.*}} : f32 +// CHECK: llvm.fadd {{.*}} : vector<4xf32> +// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32> +// CHECK: llvm.fcmp {{.*}} : f32 +func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) { + %0 = arith.addf %arg0, %arg0 : f32 + %1 = arith.addf %arg1, %arg1 : vector<4xf32> + %2 = arith.addf %arg2, %arg2 : vector<4x8xf32> + %3 = arith.cmpf oeq, %arg0, %arg3 : f32 + return +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index a75f30d..cd8cfc8 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -275,6 +275,42 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { // ----- +// CHECK-LABEL: spirv.func @reduction_minnumf( +// CHECK-SAME: %[[V:.*]]: vector<3xf32>, +// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" { +// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.*]] = spirv.GL.FMin %[[S0]], %[[S1]] : f32 +// CHECK: %[[MIN1:.*]] = spirv.GL.FMin %[[MIN0]], %[[S2]] : f32 +// CHECK: %[[MIN2:.*]] = spirv.GL.FMin %[[MIN1]], %[[S]] : f32 +// CHECK: spirv.ReturnValue %[[MIN2]] : f32 +// CHECK: } +func.func @reduction_minnumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <minnumf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: spirv.func @reduction_maxnumf( +// CHECK-SAME: %[[V:.*]]: vector<3xf32>, +// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" { +// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.*]] = spirv.GL.FMax %[[S0]], %[[S1]] : f32 +// CHECK: %[[MAX1:.*]] = spirv.GL.FMax %[[MAX0]], %[[S2]] : f32 +// CHECK: %[[MAX2:.*]] = spirv.GL.FMax %[[MAX1]], %[[S]] : f32 +// CHECK: spirv.ReturnValue %[[MAX2]] : f32 +// CHECK: } +func.func @reduction_maxnumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <maxnumf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + // CHECK-LABEL: func @reduction_maxsi // CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> diff --git a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir index cb6bc35..c16702f 100644 --- a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir +++ 
b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt -set-llvm-module-datalayout -convert-func-to-llvm %s | FileCheck %s -// RUN-32: mlir-opt -set-llvm-module-datalayout='data-layout=p:32:32:32' -convert-func-to-llvm %s \ -// RUN-32: | FileCheck %s +// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=p:32:32:32' -convert-func-to-llvm %s \ +// RUN: | FileCheck %s --check-prefix=CHECK-32 // CHECK: module attributes {llvm.data_layout = ""} -// CHECK-32: module attributes {llvm.data_layout ="p:32:32:32"} +// CHECK-32: module attributes {llvm.data_layout = "p:32:32:32"} module {} diff --git a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir index ae1dc70..bd28162 100644 --- a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir +++ b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir @@ -32,7 +32,7 @@ func.func @pass_through(%arg0: () -> ()) -> (() -> ()) { func.func private @llvmlinkage(i32) attributes { "llvm.linkage" = #llvm.linkage<extern_weak> } // CHECK-LABEL: llvm.func @llvmreadnone(i32) -// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none> +// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none> func.func private @llvmreadnone(i32) attributes { llvm.readnone } // CHECK-LABEL: llvm.func @body(i32) diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir index 15a9654..16ed484 100644 --- a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm),convert-func-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm),convert-func-to-llvm{use-bare-ptr-memref-call-conv=1},convert-cf-to-llvm,reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s --check-prefix=BAREPTR +// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32},convert-cf-to-llvm{index-bitwidth=32},reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s --check-prefix=CHECK32 // BAREPTR-LABEL: func @check_noalias // BAREPTR-SAME: %{{.*}}: !llvm.ptr {llvm.noalias}, %{{.*}}: !llvm.ptr {llvm.noalias} diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index a4b5dde..f1cc1eb 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allow-pattern-rollback=0' -split-input-file | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE // RUN: mlir-opt %s -transform-interpreter | FileCheck %s diff --git a/mlir/test/Conversion/GPUToNVVM/memref.mlir b/mlir/test/Conversion/GPUToNVVM/memref.mlir index e164ca9..a4e8ead 100644 --- a/mlir/test/Conversion/GPUToNVVM/memref.mlir +++ 
b/mlir/test/Conversion/GPUToNVVM/memref.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-gpu-to-nvvm | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm="allow-pattern-rollback=0" | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='use-bare-ptr-memref-call-conv=1' \ // RUN: | FileCheck %s --check-prefix=BARE diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir index b479467..a080144 100644 --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-gpu-to-nvvm="allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s gpu.module @test_module { @@ -81,6 +82,28 @@ gpu.module @test_module { gpu.module @test_module { + // CHECK-LABEL: func @gpu_wmma_f64_load_op() -> + // CHECK-SAME: f64 + // CHECK32-LABEL: func @gpu_wmma_f64_load_op() -> + func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp"> + return %0 : !gpu.mma_matrix<8x4xf64, "AOp"> + // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64 + // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64 + // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64 + // CHECK: llvm.return %[[LOAD]] : f64 + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: func @gpu_wmma_store_op // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) // CHECK32-LABEL: func @gpu_wmma_store_op diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir index c1627a0..19e1c7a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefixes=NOCPP,CHECK func.func @alloc_copy(%arg0: memref<999xi32>) { %alloc = memref.alloc() : memref<999xi32> @@ -9,42 +9,46 @@ func.func @alloc_copy(%arg0: memref<999xi32>) { return } -// CHECK: module { // NOCPP: emitc.include <"stdlib.h"> // NOCPP-NEXT: emitc.include <"string.h"> // CPP: emitc.include <"cstdlib"> // CPP-NEXT: emitc.include <"cstring"> -// CHECK-LABEL: alloc_copy -// CHECK-SAME: %[[arg0:.*]]: memref<999xi32> -// CHECK-NEXT: 
builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> -// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> -// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32> -// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> -// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> -// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32> -// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () -// CHECK-NEXT: return +// CHECK-LABEL: func.func @alloc_copy( +// CHECK-SAME: %[[ARG0:.*]]: memref<999xi32>) { +// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<999xi32> to !emitc.array<999xi32> +// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_0]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: %[[CALL_OPAQUE_1:.*]] = emitc.call_opaque "malloc"(%[[MUL_0]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CHECK: %[[CAST_0:.*]] = emitc.cast %[[CALL_OPAQUE_1]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[CAST_0]] : !emitc.ptr<i32> to !emitc.array<999xi32> +// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript 
%[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<i32> +// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_2]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<i32> +// CHECK: %[[CALL_OPAQUE_2:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_3:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_1:.*]] = emitc.mul %[[CALL_OPAQUE_2]], %[[VAL_3]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_1]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () +// CHECK: %[[CALL_OPAQUE_3:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_2:.*]] = emitc.mul %[[CALL_OPAQUE_3]], %[[VAL_4]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: %[[CALL_OPAQUE_4:.*]] = emitc.call_opaque "malloc"(%[[MUL_2]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CHECK: %[[CAST_1:.*]] = emitc.cast %[[CALL_OPAQUE_4]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CAST_1]] : !emitc.ptr<i32> to !emitc.array<999xi32> +// CHECK: %[[VAL_5:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_2:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_5]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_2:.*]] = emitc.address_of %[[SUBSCRIPT_2]] : !emitc.lvalue<i32> +// CHECK: %[[VAL_6:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_3:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_2]]{{\[}}%[[VAL_6]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_3:.*]] = emitc.address_of %[[SUBSCRIPT_3]] : !emitc.lvalue<i32> +// CHECK: %[[CALL_OPAQUE_5:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_7:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_3:.*]] = emitc.mul %[[CALL_OPAQUE_5]], %[[VAL_7]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_3]], %[[ADDRESS_OF_2]], %[[MUL_3]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () +// CHECK: return +// CHECK: } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir index d151d1b..3de2d25 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s 
--check-prefixes=NOCPP,CHECK func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> @@ -10,20 +10,21 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { // NOCPP: emitc.include <"string.h"> // CPP: emitc.include <"cstring"> -// CHECK-LABEL: copying -// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> -// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> -// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32> -// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> -// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32> -// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t -// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index -// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> () -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT:} +// CHECK-LABEL: func.func @copying( +// CHECK-SAME: %[[ARG0:.*]]: memref<9x4x5x7xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<9x4x5x7xf32>) { +// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_0]], %[[VAL_0]], %[[VAL_0]], %[[VAL_0]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> +// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<f32> +// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> +// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<f32> +// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_2]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_0]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> () +// CHECK: return +// CHECK: } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 2b4eda3..c7b043b 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ 
b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -53,7 +53,7 @@ module @globals { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32> - // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + // CHECK-NEXT: emitc.address_of %1 : !emitc.lvalue<i32> %1 = memref.get_global @__constant_xi32 : memref<i32> return } diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index 68c3e9f..e0e4a61 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm -lower-affine -convert-arith-to-llvm -cse %s -split-input-file | FileCheck %s +// RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm='index-bitwidth=32' -lower-affine -convert-arith-to-llvm='index-bitwidth=32' -cse %s -split-input-file | FileCheck %s --check-prefix=CHECK32 // // This test demonstrates a full "memref to llvm" pipeline where // we first expand some of the memref operations (using affine, @@ -441,10 +442,31 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK: } -// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( -// CHECK32: llvm.mlir.constant(1 : index) : i32 -// CHECK32: llvm.mlir.constant(4 : index) : i32 -// CHECK32: llvm.mlir.constant(1 : index) : i32 +// CHECK32-LABEL: func.func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK32-SAME: %[[ARG:.*]]: memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> memref<4x?xf32, strided<[?, ?], offset: ?>> { +// CHECK32: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> to !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)> +// CHECK32: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i32, +// CHECK32: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i32, +// CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)> +// CHECK32: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)> +// CHECK32: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)> +// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)> +// CHECK32: %[[FINAL_SIZE1_I32:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw> : i32 +// CHECK32: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1_I32]] : i32 to index +// CHECK32: %[[FINAL_SIZE1_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i32 +// CHECK32: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: 
%[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[C4_I32:.*]] = llvm.mlir.constant(4 : index) : i32 +// CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[C4_I32]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1_CAST]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[C1_I32:.*]] = llvm.mlir.constant(1 : index) : i32 +// CHECK32: %[[DESC6:.*]] = llvm.insertvalue %[[C1_I32]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> +// CHECK32: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK32: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK32: } // ----- diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 8cce630..0eb4478 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -491,12 +491,12 @@ func.func @mbarrier() { // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.shared %[[barPtr2]] + // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive %[[barPtr2]] %token = nvgpu.mbarrier.arrive %barrier[%c0] : !barrierType -> !tokenType // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]] + // CHECK: nvvm.mbarrier.test.wait %[[barPtr3]], %[[token]] %isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType func.return @@ -521,12 +521,12 @@ func.func @mbarrier_nocomplete() { // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.nocomplete.shared %[[barPtr2]] + // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.nocomplete %[[barPtr2]] %token = nvgpu.mbarrier.arrive.nocomplete %barrier[%c0], %count : !barrierType -> !tokenType // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]] + // CHECK: nvvm.mbarrier.test.wait %[[barPtr3]], %[[token]] %isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType func.return @@ -572,7 +572,7 @@ func.func @mbarrier_wait(%barriers : 
!nvgpu.mbarrier.group<memorySpace = #gpu.ad // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64 // CHECK: %[[S4:.+]] = llvm.extractvalue %[[CARG0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 -// CHECK: nvvm.mbarrier.test.wait.shared {{.*}}, %[[CARG1]] +// CHECK: nvvm.mbarrier.test.wait {{.*}}, %[[CARG1]] %mbarId = arith.remui %i, %numBarriers : index %isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType } @@ -603,14 +603,14 @@ func.func @mbarrier_txcount() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } else { %txcount = arith.constant 0 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } @@ -620,7 +620,7 @@ func.func @mbarrier_txcount() { %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return @@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType %phase_c0 = arith.constant 0 : i1 %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index fbc4c0a..8fb36ac 100644 --- 
a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -16,17 +16,13 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 llvm.return } // CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" - nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b" nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 llvm.return @@ -44,7 +40,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, // CHECK-SAME: DONE: // CHECK-SAME: }", // CHECK-SAME: "r,r,r" - nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 + nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 llvm.return } @@ -88,10 +84,10 @@ func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true} nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr - // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} - nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3> - // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true} - nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3> + // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} + nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3> + // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true} + nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3> llvm.return } @@ -544,8 +540,8 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @init_mbarrier_arrive_expect_tx -llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { +// CHECK-LABEL: @test_nvvm_prefetch +llvm.func @test_nvvm_prefetch(%desc : !llvm.ptr, %pred : i1) { //CHECK: nvvm.prefetch tensormap, %{{.*}} nvvm.prefetch tensormap, %desc : !llvm.ptr //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b" @@ -588,29 +584,6 @@ func.func @cp_async_bulk_wait_group() { // ----- -func.func @fence_mbarrier_init() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;" - nvvm.fence.mbarrier.init - func.return -} -// ----- - -func.func @fence_proxy() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>} - //CHECK: 
llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>} - func.return -} - -// ----- - // CHECK-LABEL: @llvm_nvvm_barrier_arrive // CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32) llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) { diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index f2fbe91..b122f42 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -615,3 +615,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr) omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) } + +// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) { +func.func @omp_dist_schedule(%arg0: i32) { + %c1_i32 = arith.constant 1 : i32 + // CHECK: %1 = llvm.mlir.constant(1024 : i32) : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c16_i32 = arith.constant 16 : i32 + %c8_i32 = arith.constant 8 : i32 + omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) { + // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) { + omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) { + omp.terminator + } + } + omp.terminator + } + return +} diff --git a/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir new file mode 100644 index 0000000..3bd9bb4 --- /dev/null +++ b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -convert-openmp-to-llvm -split-input-file -verify-diagnostics %s + +// Indicates that the TypeConversion has failed for the MPMapInfoOp. +// In this specific case, the `tensor` type (used in a TypeAttr) cannot be converted +// to an LLVM type. This test ensures that the conversion fails gracefully with a +// legalization error instead of crashing. 
+func.func @fail_map_info_tensor_type(%arg0: memref<?xf32>) { + // expected-error@+1 {{failed to legalize operation 'omp.map.info' that was explicitly marked illegal}} + %map_info = omp.map.info var_ptr(%arg0: memref<?xf32>, tensor<?xf32>) map_clauses(to) capture(ByRef) -> memref<?xf32> + omp.target_update map_entries(%map_info: memref<?xf32>) { + omp.terminator + } + return +} diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 483c7b3..0c4f20e 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="allow-pattern-rollback=0" -split-input-file %s | FileCheck %s // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index) diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir index 26f5a3e..2f192df 100644 --- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir +++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir @@ -673,3 +673,51 @@ func.func @nested_parallel_with_side_effect() { // CHECK: gpu.launch // CHECK-NOT: scf.parallel + +// ----- + +func.func @scf2gpu_index_creation_2d() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + // Single 2-D scf.parallel mapped to block_x and thread_x. + // Use both IVs so the conversion must compute indices. + scf.parallel (%bx, %tx) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + %u = arith.addi %bx, %c0 : index + %v = arith.addi %tx, %c0 : index + } { + mapping = [ + #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>, + #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)> + ] + } + return +} + +// CHECK-LABEL: func @scf2gpu_index_creation_2d +// CHECK: gpu.launch +// CHECK: %[[IDX:.*]] = affine.apply +// CHECK: arith.addi %[[IDX]], + +// ----- + +func.func @scf2gpu_index_creation_1d() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.parallel (%t) = (%c0) to (%c64) step (%c1) { + %w = arith.addi %t, %c0 : index + } { + mapping = [ + #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)> + ] + } + return +} + +// CHECK-LABEL: func @scf2gpu_index_creation_1d +// CHECK: gpu.launch +// CHECK: %[[IDX:.*]] = affine.apply +// CHECK: arith.addi %[[IDX]], diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir index e1936e2..b17e1c4 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir @@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" { // CHECK-LABEL: @tan spirv.func @tan(%arg0: f32) "None" { - // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32 - // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32 - // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32 + // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tan %arg0 : f32 spirv.Return } @@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" { // CHECK-LABEL: @tanh spirv.func @tanh(%arg0: f32) "None" { - // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 
: f32) : f32 - // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32 - // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32 - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 - // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32 - // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32 + // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tanh %arg0 : f32 spirv.Return } diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir index b69c2d0..65c6e05 100644 --- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -79,3 +79,12 @@ func.func @tensor_2d_empty() -> () { %x = arith.constant dense<> : tensor<2x0xi32> return } + +// Tensors with more than UINT32_MAX elements cannot fit in a spirv.array. +// Test that they are not lowered. +// CHECK-LABEL: func @very_large_tensor +// CHECK-NEXT: arith.constant dense<1> +func.func @very_large_tensor() -> () { + %x = arith.constant dense<1> : tensor<4294967296xi32> + return +} diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir index 6c0b111..0fe63f5 100644 --- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir +++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir @@ -17,3 +17,9 @@ func.func @check_poison() { %3 = ub.poison : !llvm.ptr return } + +// CHECK-LABEL: @check_unreachable +func.func @check_unreachable() { +// CHECK: llvm.unreachable + ub.unreachable +} diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir index f497eb3..9c277cf 100644 --- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir +++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-ub-to-spirv -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-ub-to-spirv %s | FileCheck %s module attributes { spirv.target_env = #spirv.target_env< @@ -19,3 +19,20 @@ func.func @check_poison() { } } + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>> +} { +// CHECK-LABEL: @check_unreachable +func.func @check_unreachable(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: +// CHECK: spirv.Unreachable + ub.unreachable +^bb2: + return +} +} diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir index 9908205..ae5141d 100644 --- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir @@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto // CHECK-LABEL: @load_1D_vector( // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index +// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32, +// CHECK-SAME: %[[COLLAPSED]] +// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32, // CHECK-SAME: boundary_check = false -// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32> +// CHECK: %[[VEC:.+]] = xegpu.load_nd 
%[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32> // CHECK: return %[[VEC]] // ----- @@ -28,35 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>, // CHECK-LABEL: @load_2D_vector( // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index +// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// CHECK-SAME: %[[COLLAPSED]] +// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32> +// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] // ----- func.func @load_dynamic_source(%source: memref<?x?x?xf32>, - %offset: index) -> vector<8x16xf32> { - %0 = vector.load %source[%offset, %offset, %offset] + %i: index, %j: index, %k: index) -> vector<8x16xf32> { + %0 = vector.load %source[%i, %j, %k] : memref<?x?x?xf32>, vector<8x16xf32> return %0 : vector<8x16xf32> } // CHECK-LABEL: @load_dynamic_source( // CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>, -// CHECK-SAME: %[[OFFSET:.+]]: index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]] -// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]] -// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]] -// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] -// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] -// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0] +// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]] +// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]] +// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] // ----- @@ -72,9 +67,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] +// CHECK-SAME: %[[SRC]] // CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] // ----- diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir index 2c498dc..1a10d91 100644 --- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir @@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>, // CHECK-SAME: %[[VEC:.+]]: vector<8xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index +// CHECK: 
%[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32, +// CHECK-SAME: %[[COLLAPSED]] +// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32, // CHECK-SAME: boundary_check = false -// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32> +// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32> // ----- @@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>, // CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index +// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// CHECK-SAME: %[[COLLAPSED]] +// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32> +// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32> // ----- func.func @store_dynamic_source(%vec: vector<8x16xf32>, - %source: memref<?x?x?xf32>, %offset: index) { - vector.store %vec, %source[%offset, %offset, %offset] + %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) { + vector.store %vec, %source[%i, %j, %k] : memref<?x?x?xf32>, vector<8x16xf32> return } @@ -47,18 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>, // CHECK-LABEL: @store_dynamic_source( // CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>, -// CHECK-SAME: %[[OFFSET:.+]]: index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]] -// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]] -// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]] -// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] -// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] -// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0] +// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]] +// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]] +// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32> // ----- @@ -74,9 +69,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc -// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] +// CHECK-SAME: %[[SRC]] // CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32> -// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32> // ----- diff --git 
a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index c4ca79a..8bb272b 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -11,13 +11,15 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector // LOAD-ND-LABEL: @load_1D_vector( // LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, -// LOAD-ND-SAME: %[[OFFSET:.+]]: index -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32, -// LOAD-ND-SAME: boundary_check = false -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32> -// LOAD-ND: return %[[VEC]] +// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> +// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-ND-COUNT2: arith.muli {{.*}} : index +// LOAD-ND-COUNT2: arith.addi {{.*}} : index +// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32> // LOAD-GATHER-LABEL: @load_1D_vector( // LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, @@ -46,11 +48,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>, // LOAD-ND-LABEL: @load_2D_vector( // LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // LOAD-ND-SAME: %[[OFFSET:.+]]: index +// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0] // LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, +// LOAD-ND-SAME: %[[COLLAPSED]] +// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, // LOAD-ND-SAME: boundary_check = false -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32> // LOAD-ND: return %[[VEC]] // LOAD-GATHER-LABEL: @load_2D_vector( @@ -83,9 +86,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>, // LOAD-ND-LABEL: @load_zero_pad_out_of_bounds( // LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>, // LOAD-ND-SAME: %[[OFFSET:.+]]: index -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]] +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]] // LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32> -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32> // LOAD-ND: return %[[VEC]] // LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds( @@ -109,9 +112,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>, // LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>, // LOAD-ND-SAME: %[[OFFSET1:.+]]: index, // LOAD-ND-SAME: %[[OFFSET2:.+]]: index -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]] +// LOAD-ND: %[[DESC:.+]] = 
xegpu.create_nd_tdesc %[[SRC]] // LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32 -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}> +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}> // LOAD-ND-SAME: -> vector<8x16xf32> // LOAD-ND: return %[[VEC]] @@ -143,16 +146,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>, } // LOAD-ND-LABEL: @load_dynamic_source( // LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>, -// LOAD-ND-SAME: %[[OFFSET:.+]]: index -// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index -// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index -// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index -// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]] -// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]] -// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]] -// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]] -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> +// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0] +// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]] +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]] +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32> // LOAD-ND: return %[[VEC]] @@ -184,10 +182,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>, } // LOAD-ND-LABEL: @load_dynamic_source2( -// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index -// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32> -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32> +// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>, +// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0] +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32> // LOAD-ND: return %[[VEC]] : vector<8x16xf32> // LOAD-GATHER-LABEL: @load_dynamic_source2( @@ -406,7 +405,7 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>, // ----- gpu.module @xevm_module { -gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { +gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { %c0 = arith.constant 0.0 : f16 %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> %0 = vector.transfer_read %subview[%off2, %off2], %c0 @@ -414,18 +413,23 @@ gpu.func @load_from_subview(%source: 
memref<4096x4096xf16>, %off1: index, %off2: gpu.return %0 : vector<8xf16> } -// LOAD-ND-LABEL: @load_from_subview( +// LOAD-ND-LABEL: @load_from_subview_1D( // LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, // LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> // LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] -// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, -// LOAD-ND-SAME: boundary_check = false -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16> -// LOAD-ND: return %[[VEC]] - -// LOAD-GATHER-LABEL: @load_from_subview( +// LOAD-ND: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-ND: arith.muli {{.*}} : index +// LOAD-ND: arith.addi %[[OFFSET]]{{.*}} : index +// LOAD-ND: arith.addi {{.*}} : index +// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> + +// LOAD-GATHER-LABEL: @load_from_subview_1D( // LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, // LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index // LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> @@ -441,3 +445,42 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> } + +// ----- +gpu.module @xevm_module { +gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> { + %c0 = arith.constant 0.0 : f16 + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> + %0 = vector.transfer_read %subview[%off2, %off2], %c0 + {in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8x16xf16> + gpu.return %0 : vector<8x16xf16> +} + +// LOAD-ND-LABEL: @load_from_subview_2D( +// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// LOAD-ND-SAME: %[[SUBVIEW]] +// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16, +// LOAD-ND-SAME: boundary_check = false +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16> 
+// LOAD-ND: return %[[VEC]] + +// LOAD-GATHER-LABEL: @load_from_subview_2D( +// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1> +// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// LOAD-GATHER-COUNT2: vector.step +// LOAD-GATHER-COUNT2: vector.shape_cast +// LOAD-GATHER-COUNT2: vector.broadcast +// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index +// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index +// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex> +// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> +} diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir index fcfc941..43a1a72 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND +// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND // RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER @@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>, // STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>, // STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // STORE-ND-SAME: %[[OFFSET:.+]]: index +// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0] // STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32, +// STORE-ND-SAME: %[[COLLAPSED]] +// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32, // STORE-ND-SAME: boundary_check = false -// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32> +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32> // STORE-SCATTER-LABEL: @store_1D_vector( // STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>, @@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>, // STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>, // STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, // STORE-ND-SAME: %[[OFFSET:.+]]: index +// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0] // STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, +// STORE-ND-SAME: 
%[[COLLAPSED]] +// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, // STORE-ND-SAME: boundary_check = false -// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32> // STORE-SCATTER-LABEL: @store_2D_vector( // STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>, @@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>, // ----- gpu.module @xevm_module { gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, - %source: memref<?x?x?xf32>, %offset: index) { - vector.transfer_write %vec, %source[%offset, %offset, %offset] + %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) { + vector.transfer_write %vec, %source[%i, %j, %k] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?x?xf32> gpu.return @@ -83,18 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, // STORE-ND-LABEL: @store_dynamic_source( // STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>, // STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>, -// STORE-ND-SAME: %[[OFFSET:.+]]: index -// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index -// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index -// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index -// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]] -// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]] -// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]] -// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] -// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] -// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32 -// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0] +// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]] +// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]] +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32> // STORE-SCATTER-LABEL: @store_dynamic_source( // STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>, @@ -126,9 +121,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>, // STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>, // STORE-ND-SAME: %[[OFFSET:.+]]: index // STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] +// STORE-ND-SAME: %[[SRC]] // STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32> -// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32> // STORE-SCATTER-LABEL: @store_out_of_bounds( // STORE-SCATTER: vector.transfer_write @@ -298,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>, // STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>, // STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, // STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index -// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] -// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : 
memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0] // STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] -// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, +// STORE-ND-SAME: %[[COLLAPSED]] +// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16, // STORE-ND-SAME: boundary_check = false -// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16> +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16> // STORE-SCATTER-LABEL: @store_to_subview( // STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>, diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index 09ef76c..9a1e2cb 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -7,42 +7,41 @@ gpu.module @create_nd_tdesc { // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel { gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel { + // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index + // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32 - // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32 // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32 // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32 + // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32 // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64> // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32> // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32> // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32> - // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32> + // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32> %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2] : ui64 -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32> %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32> - // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index - // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32 - // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32 + // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: 
%[[C32_I64:.*]] = arith.constant 32 : i64 // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32 // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64 // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32 - // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64 + // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32 // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64> // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32> // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32> // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32> - // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32> + // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -51,20 +50,16 @@ gpu.module @create_nd_tdesc { %size_x = arith.constant 64 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index %BLOCK_DMODEL = arith.constant 16 : index - // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index - // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32 - // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32 - // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32 - // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32 - // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 - // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64> - // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64> - // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32> - // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32> - // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32> + // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32 + // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32 + // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32 + // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32> %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, 
%BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16> gpu.return } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir new file mode 100644 index 0000000..aebec7f --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @load_store_check { + // CHECK-LABEL: @load_store( + // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1> + gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel { + // CHECK: %[[C512:.*]] = arith.constant 512 : i64 + // CHECK: %[[C384:.*]] = arith.constant 384 : i64 + + // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32> + %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index + // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32> + %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32> + // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index + // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64 + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32> + // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64 + // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1> + // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> + // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32> + %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<32xf32> -> vector<2xf32> + + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>> + // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64 + // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1> + // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}> + // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>) + xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index d4cb493..3a3769f 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { // e.g. 
for mem_desc<32x32xf16, @strides=[1, 16]> // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) - //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + //CHECK-LABEL: load_store_matrix_plain + gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> //CHECK: %[[TID:.*]] = gpu.thread_id x @@ -26,12 +26,40 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { gpu.return %1: f32 } + //CHECK-LABEL: load_store_matrix_plain_2d_input + gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3> + + %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> + + %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK-LABEL: load_store_matrix_blocked_strided + gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> - //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x //CHECK: %[[c13:.*]] = arith.constant 13 : index //CHECK: %[[c16:.*]] = arith.constant 16 : index @@ -39,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -53,39 +81,39 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 - + %tid_x = gpu.thread_id x %c13 = arith.constant 13 : index %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16 //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> - - xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, 
!xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index gpu.return %1: f16 } // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + //CHECK-LABEL: load_store_matrix_blocked_nostride + gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 { + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> - + //CHECK: %[[tid_x:.*]] = gpu.thread_id x //CHECK: %[[c19:.*]] = arith.constant 19 : index %tid_x = gpu.thread_id x %c19 = arith.constant 19: index - - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -97,32 +125,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[c1:.*]] = arith.constant 1 : index //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16 - + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index - + //CHECK: gpu.return %[[loaded]] : f16 gpu.return %1: f16 } // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_4 - gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector + gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> - //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[tid_x:.*]] = gpu.thread_id x - //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -136,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> - + %tid_x = gpu.thread_id x %c16 = arith.constant 16 : index %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16> @@ -147,28 +172,26 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { gpu.return %1: vector<8xf16> } - + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_5 - gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> - - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> - + //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio + gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32 + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[c48:.*]] = arith.constant 48 : index - %c16 = arith.constant 16 : index %c48 = arith.constant 48 : index - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ 
-183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 //CHECK: %[[c2:.*]] = arith.constant 2 : i32 //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 - //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32 //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> @@ -191,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16> //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> - //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index gpu.return %1: vector<8xf16> } + gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) { + // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16> + // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16> + // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16> + // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3> + %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16> + %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16> + xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + gpu.return + } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir index 4c6bbf2..4c73c9c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir @@ -1,72 +1,32 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s gpu.module @load_store_check { + // CHECK-LABEL: gpu.func @load_store( gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32 + // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32 + // CHECK: %[[H:.*]] = arith.constant 8 : i32 %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> - // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: 
%[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - - //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> - //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> - //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64 - //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32 - //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64 - //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32 - //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32 - //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32 - //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]], - //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]] + //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]] //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> - //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32> %tid_x = gpu.thread_id x %tid_x_i32 = arith.index_cast %tid_x : index to i32 %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 - //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32> %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32> - // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32> %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> - //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : 
vector<8xi32> to vector<4xi64> - //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32> - //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64 - //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32 - //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64 - //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32 - //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32 - //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32 - //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32> - //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]], - //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]] - //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32, + //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{ + //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir new file mode 100644 index 0000000..97e5ce1 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir @@ -0,0 +1,80 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @load_store_check { + // CHECK-LABEL: gpu.func @load_store_matrix_a + // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1> + gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel { + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64> + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32 + // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]] + // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] + // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64 + %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4> + // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]] + // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] + // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64 + %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4> + + // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64> + // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD2_SRC:.*]] = 
vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4> + + // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64> + // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1> + // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]], + // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{ + // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, + // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4> + + // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64> + // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32> + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>> + + // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64> + // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1> + // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]], + // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{ + // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) + xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>> + gpu.return + } + + // CHECK-LABEL: gpu.func @load_matrix_b_request_pack + gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel { + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32 + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4> + %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4> + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4> + + // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{ + // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : 
i32, + // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false, + // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4> + + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 9c552d8..d606cf5 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -1,15 +1,16 @@ -// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s +// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm -canonicalize | FileCheck %s gpu.module @test { // CHECK-LABEL: @load_gather_i64_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> -gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16> +// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1> +gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16 + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1> // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> @@ -17,11 +18,12 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16 // CHECK: scf.yield %[[VAR7]] : f16 // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16 // CHECK: scf.yield %[[CST_0]] : f16 // CHECK: } - %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> + %c0 = arith.constant 0 : index + vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16> gpu.return } } @@ -30,16 +32,16 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) gpu.module @test { // CHECK-LABEL: @source_materialize_single_elem_vec // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16> -gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) { - %1 = arith.constant dense<1>: vector<1xi1> - %3 = 
xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> +// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1> +gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) { + %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> + // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[VAR_IF:.*]] = scf.if // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16> - // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16> %c0 = arith.constant 0 : index - vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16> + vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16> gpu.return } } @@ -48,24 +50,21 @@ gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex> gpu.module @test { // CHECK-LABEL: @store_scatter_i64_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> -gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xi1> +gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %mask: vector<1xi1>) { + // CHECK: %[[CST_0:.*]] = arith.constant 2.900000e+00 : f32 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG2]][0] : i1 from vector<1xi1> // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> - // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> - %2 = arith.constant dense<2.9>: vector<1xf32> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + %0 = arith.constant dense<2.9>: vector<1xf32> // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64 // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> // CHECK: scf.if %[[VAR2]] { - // CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1> + // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1> // CHECK: } - xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + xegpu.store %0, %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, i64, vector<1xindex>, vector<1xi1> gpu.return } @@ -76,9 +75,9 @@ gpu.module @test { // CHECK-LABEL: @prefetch_i64_src_value_offset // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : 
i64 // CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64 // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1> @@ -94,11 +93,11 @@ gpu.module @test { // CHECK-LABEL: @prefetch_memref_src_value_offset // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir index 873478a..43df721 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir @@ -1,34 +1,18 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s -gpu.module @fence_check { - gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { +gpu.module @prefetch_nd_check { + // CHECK-LABEL: gpu.func @prefetch_nd + gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32 + // CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32 + // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32 %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> - %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> - - // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> - //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> - //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64 - //CHECK: 
%[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32 - //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64 - //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32 - //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32 - //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32 - //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]], - //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]] + //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1> + //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]], + //CHECK-SAME: %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]] //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32) diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir new file mode 100644 index 0000000..f925472 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @prefetch_check { + // CHECK-LABEL: gpu.func @prefetch_matrix_a + gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel { + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4> + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4> + + // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] + // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1> + xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<8x64xi4> + + gpu.return + } +} diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index 72e70ff..7f01526 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -175,7 +175,7 @@ llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %bas // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind} +// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind} // CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, // CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, 
%base_pitch: i32, %x: i32, %y: i32) { @@ -187,7 +187,7 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, @@ -200,13 +200,13 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i // CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( // CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes // CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, -// CHECK-SAME: inaccessibleMem = none>, no_unwind, will_return} +// CHECK-SAME: inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, will_return} // CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = // CHECK-SAME: !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted @@ -230,13 +230,13 @@ llvm.func @memfence() { // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind} +// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind} // CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { llvm.func @prefetch(%ptr: !llvm.ptr<1>) { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem 
= read, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>) llvm.return @@ -352,7 +352,7 @@ llvm.func @local_id.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.local_id.x : i32 llvm.return %1 : i32 @@ -380,7 +380,7 @@ llvm.func @local_size.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.local_size.x : i32 llvm.return %1 : i32 @@ -408,7 +408,7 @@ llvm.func @group_id.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.group_id.x : i32 llvm.return %1 : i32 @@ -436,7 +436,7 @@ llvm.func @group_count.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.group_count.x : i32 llvm.return %1 : i32 @@ -463,7 +463,7 @@ llvm.func @group_count.z() -> i32 { llvm.func @lane_id() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other 
= none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.lane_id : i32 llvm.return %1 : i32 @@ -474,7 +474,7 @@ llvm.func @lane_id() -> i32 { llvm.func @subgroup_size() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.subgroup_size : i32 llvm.return %1 : i32 @@ -485,7 +485,7 @@ llvm.func @subgroup_size() -> i32 { llvm.func @subgroup_id() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.subgroup_id : i32 llvm.return %1 : i32 diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir new file mode 100644 index 0000000..9d43c99 --- /dev/null +++ b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt --canonicalize %s | FileCheck %s + +// CHECK-LABEL: @make_dma_descriptor_fold +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[IDX:.+]]: index) +func.func @make_dma_descriptor_fold(%base: !amdgpu.tdm_base<i32>, %idx: index) -> !amdgpu.tdm_descriptor { + %c64 = arith.constant 64 : index + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + %0 = amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [%c64, %c64] + // CHECK-SAME: globalStride [64, 1] + globalStride [%c64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [%c64, %c64] + iterate %idx, %idx, %idx + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %0 : !amdgpu.tdm_descriptor +} diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir index fee0c00..cff1d3f 100644 --- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir +++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir @@ -244,3 +244,39 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4 %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> } + +// ----- + +// CHECK-LABEL fuse_memory_counter_wait +func.func @fuse_memory_counter_wait() { + // CHECK: amdgpu.memory_counter_wait + // CHECK-SAME: load(1) store(2) ds(2) exp(1) tensor(0) + // 
CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) tensor(0) + return +} + +// CHECK-LABEL fuse_memory_counter_wait_different_counters +func.func @fuse_memory_counter_wait_different_counters() { + // CHECK: amdgpu.memory_counter_wait + // CHECK-SAME: load(1) store(2) ds(3) exp(4) + // CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) + amdgpu.memory_counter_wait ds(3) exp(4) + return +} + +func.func private @use() + +// CHECK-LABEL fuse_memory_counter_wait_not_adjacent +func.func @fuse_memory_counter_wait_not_adjacent() { + // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) + // CHECK-NEXT: call @use() + // CHECK-NEXT: amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) + // CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) + func.call @use() : () -> () + amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) + return +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 4c6f62a..6308ea9 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -333,48 +333,86 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return +func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> } // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return +func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> +} + +// ----- + +func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}} + %0 = amdgpu.scaled_mfma 
32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> +} + +// ----- + +func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op lds memref must have workgroup address space attribute.}} + amdgpu.make_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_base<i32> +} + +// ----- + +func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}} + amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> +} + +// ----- + +func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}} + amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor } // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> +// CHECK-LABEL: func @make_dma_descriptor_invalid_empty_strides +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_empty_strides(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides must not be empty.}} + amdgpu.make_dma_descriptor %base globalSize [0, 1] globalStride [] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor func.return } // ----- -func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} - %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_innermost_stride +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_innermost_stride(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides for the innermost dimension must be 1.}} + amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [1, 2] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } // ----- -func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} - %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * 
%arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_size_and_stride_sizes +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_size_and_stride_sizes(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides and sizes must have same rank.}} + amdgpu.make_dma_descriptor %base globalSize [1, 1, 1] globalStride [1, 1] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } // ----- -func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}} - %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_shared_and_global_rank +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_shared_and_global_rank(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op tensor must have same rank as tile.}} + amdgpu.make_dma_descriptor %base globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2, 3] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } + diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 09134cb..651aff4 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -221,58 +221,58 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) -> func.return %ret : vector<2xbf16> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp4 -func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp4 +func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) 
: vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp8 -func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp8 +func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_bf8 -func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf8 +func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, 
vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp6 -func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp6 +func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_bf16 -func.func @scaled_ext_packed816_bf16(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf6 +func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32> } @@ -671,17 
+671,105 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, % // CHECK-LABEL: func @memory_counter_wait func.func @memory_counter_wait() { - // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) - // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) + // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) tensor(0) // CHECK: amdgpu.memory_counter_wait load(1) // CHECK: amdgpu.memory_counter_wait store(2) // CHECK: amdgpu.memory_counter_wait ds(3) // CHECK: amdgpu.memory_counter_wait exp(4) - amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) - amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4) + // CHECK: amdgpu.memory_counter_wait tensor(5) + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + amdgpu.memory_counter_wait tensor(0) exp(1) store(2) ds(3) load(4) amdgpu.memory_counter_wait load(1) amdgpu.memory_counter_wait store(2) amdgpu.memory_counter_wait ds(3) amdgpu.memory_counter_wait exp(4) + amdgpu.memory_counter_wait tensor(5) + func.return +} + +// CHECK-LABEL: func @make_dma_base +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space<workgroup>>) +func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) { + // CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + func.return +} + +// CHECK-LABEL: func @make_dma_descriptor +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1, %[[BARRIER:.+]]: memref<8xi32, #gpu.address_space<workgroup>>, %[[IDX:.+]]: index) +func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1, %barrier: memref<8xi32, #gpu.address_space<workgroup>>, %idx: index) { + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: padShared(%[[IDX]] every %[[IDX]]) + padShared(%idx every %idx) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: workgroupMask %[[WG_MASK]] + workgroupMask %wg_mask + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: 
workgroupMask %[[WG_MASK]] + workgroupMask %wg_mask + // CHECK-SAME: earlyTimeout %[[TIMEOUT]] + earlyTimeout %timeout + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: atomicBarrier(%[[BARRIER]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>) + atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu.address_space<workgroup>>) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: iterate %[[IDX]], %[[IDX]], %[[IDX]] + iterate %idx, %idx, %idx + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir index b616632..b062736 100644 --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir @@ -243,6 +243,106 @@ func.func @vecdim_reduction_ori(%in: memref<256x512xi32>, %out: memref<256xi32>) // CHECK: affine.store %[[final_red]], %{{.*}} : memref<256xi32> // CHECK: } +// ----- + +func.func @vecdim_reduction_xori(%in: memref<256x512xi32>, %out: memref<256xi32>) { + %cst = arith.constant 0 : i32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) { + %ld = affine.load %in[%i, %j] : memref<256x512xi32> + %xor = arith.xori %red_iter, %ld : i32 + affine.yield %xor : i32 + } + affine.store %final_red, %out[%i] : memref<256xi32> + } + return +} + +// CHECK-LABEL: func.func @vecdim_reduction_xori( +// CHECK-SAME: %[[input:.*]]: memref<256x512xi32>, +// CHECK-SAME: %[[output:.*]]: memref<256xi32>) { +// CHECK: %[[cst:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = arith.constant dense<0> : vector<128xi32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xi32>) { +// CHECK: %[[poison:.*]] = ub.poison : i32 +// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xi32>, vector<128xi32> +// CHECK: %[[xor:.*]] = arith.xori %[[red_iter]], %[[ld]] : vector<128xi32> +// CHECK: affine.yield %[[xor]] : vector<128xi32> +// CHECK: } +// CHECK: %[[final_red:.*]] = vector.reduction <xor>, %[[vred]] : vector<128xi32> into i32 +// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xi32> +// CHECK: } +// CHECK: return +// CHECK: } + +// ----- + +func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = arith.constant 0xFF800000 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %min = arith.minnumf %red_iter, %ld : f32 + affine.yield %min : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: func.func @vecdim_reduction_minnumf( +// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>, +// 
CHECK-SAME: %[[output:.*]]: memref<256xf32>) { +// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32 +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[poison:.*]] = ub.poison : f32 +// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: affine.yield %[[min]] : vector<128xf32> +// CHECK: } +// CHECK: %[[red_scalar:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32 +// CHECK: %[[final_red:.*]] = arith.minnumf %[[red_scalar]], %[[cst]] : f32 +// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32> +// CHECK: } +// CHECK: return +// CHECK: } + +// ----- + +func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = arith.constant 0xFF800000 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %max = arith.maxnumf %red_iter, %ld : f32 + affine.yield %max : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: func.func @vecdim_reduction_maxnumf( +// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>, +// CHECK-SAME: %[[output:.*]]: memref<256xf32>) { +// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32 +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[poison:.*]] = ub.poison : f32 +// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: affine.yield %[[max]] : vector<128xf32> +// CHECK: } +// CHECK: %[[red_scalar:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32 +// CHECK: %[[final_red:.*]] = arith.maxnumf %[[red_scalar]], %[[cst]] : f32 +// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32> +// CHECK: } +// CHECK: return +// CHECK: } // ----- diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir index 3be14ea..6a82532 100644 --- a/mlir/test/Dialect/Affine/loop-coalescing.mlir +++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir @@ -416,3 +416,31 @@ func.func @test_loops_do_not_get_coalesced() { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @inner_loop_has_iter_args +// CHECK-SAME: %[[ALLOC:.*]]: memref<?xi64>) +func.func @inner_loop_has_iter_args(%alloc : memref<?xi64>) { + %c17 = arith.constant 17 : index + affine.for %arg0 = 0 to 79 { + %0 = affine.for %arg1 = 0 to 64 iter_args(%arg2 = %alloc) -> (memref<?xi64>) { + %1 = arith.remui %arg1, %c17 : index + %2 = arith.index_cast %arg1 : index to i64 + memref.store %2, %arg2[%1] : memref<?xi64> + affine.yield %arg2 : memref<?xi64> + } + } + return +} + +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 17 : index +// CHECK: %[[APPLY_0:.*]] = affine.apply affine_map<() -> (79)>() +// CHECK: %[[APPLY_1:.*]] = 
affine.apply affine_map<() -> (64)>() +// CHECK: %[[APPLY_2:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%[[APPLY_0]]){{\[}}%[[APPLY_1]]] +// CHECK: affine.for %[[IV:.*]] = 0 to %[[APPLY_2]] { +// CHECK: %[[APPLY_3:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%[[IV]]){{\[}}%[[APPLY_1]]] +// CHECK: %[[REMUI_0:.*]] = arith.remui %[[APPLY_3]], %[[CONSTANT_0]] : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[APPLY_3]] : index to i64 +// CHECK: memref.store %[[INDEX_CAST_0]], %[[ALLOC]]{{\[}}%[[REMUI_0]]] : memref<?xi64> +// CHECK: } diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir index 817614b..2e80102 100644 --- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index // CHECK: "test.some_use"(%[[c5]]) // CHECK: %[[c5:.*]] = arith.constant 5 : index // CHECK: "test.some_use"(%[[c5]]) -func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) { +func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index scf.for %iv = %c0 to %ub step %c4 { %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub] %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32> - %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> + %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) "test.some_use"(%bound) : (index) -> () diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 2fe0995..3ad1530 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2958,6 +2958,19 @@ func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 { return %hi : i32 } +// CHECK-LABEL: @truncIShrSIExactToTrunciShrUIExact +// CHECK-SAME: (%[[A:.+]]: i64) +// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64 +// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] exact : i64 +// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32 +// CHECK-NEXT: return %[[TRU]] : i32 +func.func @truncIShrSIExactToTrunciShrUIExact(%a: i64) -> i32 { + %c32 = arith.constant 32: i64 + %sh = arith.shrsi %a, %c32 exact : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + // CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1 // CHECK: arith.shrsi func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 { diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 1e656e8..58eadfd 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -151,6 +151,12 @@ func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divui_exact +func.func @test_divui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divui_tensor func.func @test_divui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divui %arg0, %arg1 : tensor<8x8xi64> @@ -175,6 +181,12 @@ func.func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divsi_exact +func.func 
@test_divsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divsi_tensor func.func @test_divsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divsi %arg0, %arg1 : tensor<8x8xi64> @@ -391,6 +403,12 @@ func.func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrui_exact +func.func @test_shrui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrui_tensor func.func @test_shrui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrui %arg0, %arg1 : tensor<8x8xi64> @@ -415,6 +433,12 @@ func.func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrsi_exact +func.func @test_shrsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrsi_tensor func.func @test_shrsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrsi %arg0, %arg1 : tensor<8x8xi64> diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 8249d59..3929f5b 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32 // ----- -// `EmptyTensorElimination` fails to find a valid insertion -// point for the new injected `SubsetExtraction`. 
-// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors -func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { +// CHECK-LABEL: func.func @eliminate_all_empty_tensors +func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 - // CHECK: memref.alloc - // CHECK: memref.alloc - // CHECK: memref.alloc + // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> + // CHECK-NOT: memref.alloc %empty_1 = tensor.empty() : tensor<5x6x64xf32> %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %empty_2 = tensor.empty() : tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> - // CHECK: memref.copy + // CHECK-NOT: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] @@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { // ----- -// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor -func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { +// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors +func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> - // CHECK: memref.alloc // CHECK-NOT: memref.alloc - %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> + %concatenated_empty = tensor.empty() : tensor<5x6x128xf32> %empty_1 = tensor.empty() : tensor<5x6x64xf32> %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %empty_2 = tensor.empty() : tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> - // CHECK: memref.copy - %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] + // CHECK-NOT: memref.copy + %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> @@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { // CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty // CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty +// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> +// CHECK-NOT: memref.alloc +// CHECK-NOT: memref.copy +// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0] +// CHECK-ELIM: linalg.fill +// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64] +// CHECK-ELIM: linalg.fill func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> %empty_1 = tensor.empty() : tensor<5x6x64xf32> - // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice - // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]] - // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]] %res_1 
= linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> - // CHECK: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> - // CHECK-NOT: memref.copy %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_2 : tensor<5x6x128xf32> @@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_1 : tensor<5x6x128xf32> } + +// ----- + +// Test that dependent pure operations are moved before the +// insertion point to enable empty tensor elimination. + +// CHECK-LABEL: func.func @move_dependent_arith_op( +// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32> +// CHECK-SAME: %[[ARG1:.*]]: index +// CHECK-NOT: memref.alloc +// CHECK: %[[C5:.*]] = arith.constant 5 : index +// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]] +// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1] +// CHECK: linalg.fill {{.*}} outs(%[[SV]] +// CHECK: return %[[ARG0]] +// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op( +// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32> +// CHECK-ELIM-SAME: %[[ARG1:.*]]: index +// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index +// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]] +// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1] +// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]] +// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]] +func.func @move_dependent_arith_op( + %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, + %arg1: index, %f: f32) -> tensor<10xf32> +{ + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %c5 = arith.constant 5 : index + %offset = arith.addi %arg1, %c5 : index + %2 = tensor.insert_slice %1 into %arg0[%offset][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +} + +// ----- + +// Test that side-effecting operations are not moved, preventing empty +// tensor elimination. 
+ +// CHECK-LABEL: func.func @side_effecting_op_blocks_movement( +// CHECK: memref.alloc +// CHECK: linalg.fill +// CHECK: memref.load +// CHECK: memref.subview +// CHECK: memref.copy +// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement( +// CHECK-ELIM: tensor.empty +// CHECK-ELIM: linalg.fill +// CHECK-ELIM: memref.load +// CHECK-ELIM: tensor.insert_slice +func.func @side_effecting_op_blocks_movement( + %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, + %mem: memref<index>, %f: f32) -> tensor<10xf32> +{ + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %offset = memref.load %mem[] : memref<index> + %2 = tensor.insert_slice %1 into %arg0[%offset][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +} diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b..9884b04 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() { // expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}} arith.constant {bufferization.manual_deallocation} 0 : index } + +// ----- + +func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3xf32> + return +} + +// ----- + +func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32> + return +} + +// ----- + +func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16> + return +} + +// ----- + +func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %t2 = bufferization.to_tensor %b + : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32> + return +} diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index 
fc6df4a..b0db1bb 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>, bufferization.dealloc return %0#0, %0#1 : i1, i1 } + +// CHECK: func.func @test_builtin_custom_builtin_type_conversion +// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32> +func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>) + -> tensor<42xf32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to !test.test_memref<[42], f32> + %buffer = bufferization.to_buffer %t + : tensor<42xf32> to !test.test_memref<[42], f32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to tensor<42xf32> + %tensor = bufferization.to_tensor %buffer + : !test.test_memref<[42], f32> to tensor<42xf32> + + // CHECK: return %[[tensor]] + return %tensor : tensor<42xf32> +} + +// CHECK: func.func @test_custom_builtin_custom_type_conversion +// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>) +// CHECK-SAME: -> !test.test_tensor<[42], f32> +func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>) + -> !test.test_tensor<[42], f32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to memref<42xf32> + %buffer = bufferization.to_buffer %t + : !test.test_tensor<[42], f32> to memref<42xf32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to !test.test_tensor<[42], f32> + %tensor = bufferization.to_tensor %buffer + : memref<42xf32> to !test.test_tensor<[42], f32> + + // CHECK: return %[[tensor]] + return %tensor : !test.test_tensor<[42], f32> +} diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 17f7d28..21a1678 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -634,3 +634,25 @@ func.func @unsimplified_cycle_2(%c : i1) { ^bb7: cf.br ^bb6 } + +// CHECK-LABEL: @drop_unreachable_branch_1 +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: return +func.func @drop_unreachable_branch_1(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + return +^bb2: + ub.unreachable +} + +// CHECK-LABEL: @drop_unreachable_branch_2 +// CHECK-NEXT: ub.unreachable +func.func @drop_unreachable_branch_2(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + ub.unreachable +^bb2: + ub.unreachable +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 5f594fb..d1601be 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -876,3 +876,57 @@ func.func @test_do(%arg0 : !emitc.ptr<i32>) { return } + +// ----- + +func.func @test_for_none_block_argument(%arg0: index) { + // expected-error@+1 {{expected body to have a single block argument for the induction variable}} + "emitc.for"(%arg0, %arg0, %arg0) ( + { + emitc.yield + } + ) : (index, index, index) -> () + return +} + +// ----- + +func.func @test_for_more_than_one_block_argument(%arg0: index) { + // expected-error@+1 {{expected body to have a single block argument for the induction variable}} + "emitc.for"(%arg0, %arg0, %arg0) ( + { + ^bb0(%i0 : index, %i1 : index): + emitc.yield + } + ) : (index, index, index) -> () + return +} + +// ----- + +func.func @test_for_unmatch_type(%arg0: index) { + // expected-error@+1 {{expected induction variable to be same type as 
bounds}} + "emitc.for"(%arg0, %arg0, %arg0) ( + { + ^bb0(%i0 : f32): + emitc.yield + } + ) : (index, index, index) -> () + return +} + +// ----- + +func.func @address_of(%arg0: !emitc.lvalue<i32>) { + // expected-error @+1 {{failed to verify that input and result reference the same type}} + %1 = "emitc.address_of"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i8> + return +} + +// ----- + +func.func @dereference(%arg0: !emitc.ptr<i32>) { + // expected-error @+1 {{failed to verify that input and result reference the same type}} + %1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> !emitc.lvalue<i8> + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 1259748..b2c8b84 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -355,3 +355,13 @@ func.func @do(%arg0 : !emitc.ptr<i32>) { return } + +func.func @address_of(%arg0: !emitc.lvalue<i32>) { + %1 = emitc.address_of %arg0 : !emitc.lvalue<i32> + return +} + +func.func @dereference(%arg0: !emitc.ptr<i32>) { + %1 = emitc.dereference %arg0 : !emitc.ptr<i32> + return +} diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir index 1f8da78..bc04e8f 100644 --- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir +++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s --split-input-file --duplicate-function-elimination | \ -// RUN: FileCheck %s +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-2,CHECK-3 func.func @identity(%arg0: tensor<f32>) -> tensor<f32> { return %arg0 : tensor<f32> diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 35381da..26bcf94 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){ func.func @mmamatrix_invalid_element_type(){ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index - // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}} + // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}} %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp"> return } @@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){ // ----- func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) { - // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}} + // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}} %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp"> return } diff --git a/mlir/test/Dialect/IRDL/variadics.mlir b/mlir/test/Dialect/IRDL/variadics.mlir index a8871fc..873f248 100644 --- a/mlir/test/Dialect/IRDL/variadics.mlir +++ b/mlir/test/Dialect/IRDL/variadics.mlir @@ -133,7 +133,7 @@ func.func @testOptOperandFail(%x: i16) { // Check that an operation with multiple variadics 
expects the segment size // attribute func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) { - // expected-error@+1 {{'operand_segment_sizes' attribute is expected but not provided}} + // expected-error@+1 {{'operandSegmentSizes' attribute is expected but not provided}} "testvar.var_and_opt_operand"(%x, %x, %z) : (i16, i16, i64) -> () return } @@ -143,8 +143,8 @@ func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) { // Check that an operation with multiple variadics expects the segment size // attribute of the right type func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) { - // expected-error@+1 {{'operand_segment_sizes' attribute is expected to be a dense i32 array}} - "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = i32} : (i16, i16, i64) -> () + // expected-error@+1 {{'operandSegmentSizes' attribute is expected to be a dense i32 array}} + "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = i32} : (i16, i16, i64) -> () return } @@ -153,12 +153,12 @@ func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) { // Check that an operation with multiple variadics with the right segment size // verifies. func.func @testMultOperands(%x: i16, %y: i32, %z: i64) { - "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () - // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () - "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () - // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () - "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> () - // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> () + "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () + // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () + "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () + // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () + "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> () + // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> () return } @@ -166,8 +166,8 @@ func.func @testMultOperands(%x: i16, %y: i32, %z: i64) { // Check that the segment sizes expects non-negative values func.func @testMultOperandsSegmentNegative() { - // expected-error@+1 {{'operand_segment_sizes' attribute for specifying operand segments must have non-negative values}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 2, -1, 1>} : () -> () + // expected-error@+1 {{'operandSegmentSizes' attribute for specifying operand segments must have non-negative values}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 2, -1, 1>} : () -> () return } @@ -175,8 +175,8 @@ func.func @testMultOperandsSegmentNegative() { // Check that the segment sizes expects 1 for single values func.func 
@testMultOperandsSegmentWrongSingle() { - // expected-error@+1 {{element 2 in 'operand_segment_sizes' attribute must be equal to 1}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 0, 0>} : () -> () + // expected-error@+1 {{element 2 in 'operandSegmentSizes' attribute must be equal to 1}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 0, 0>} : () -> () return } @@ -184,8 +184,8 @@ func.func @testMultOperandsSegmentWrongSingle() { // Check that the segment sizes expects not more than 1 for optional values func.func @testMultOperandsSegmentWrongOptional() { - // expected-error@+1 {{element 1 in 'operand_segment_sizes' attribute must be equal to 0 or 1}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 2, 0>} : () -> () + // expected-error@+1 {{element 1 in 'operandSegmentSizes' attribute must be equal to 0 or 1}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 2, 0>} : () -> () return } @@ -193,8 +193,8 @@ func.func @testMultOperandsSegmentWrongOptional() { // Check that the sum of the segment sizes should be equal to the number of operands func.func @testMultOperandsSegmentWrongOptional(%y: i32, %z: i64) { - // expected-error@+1 {{sum of elements in 'operand_segment_sizes' attribute must be equal to the number of operands}} - "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 0, 1>} : (i32, i64) -> () + // expected-error@+1 {{sum of elements in 'operandSegmentSizes' attribute must be equal to the number of operands}} + "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 0, 1>} : (i32, i64) -> () return } @@ -334,7 +334,7 @@ func.func @testOptResultFail() { // Check that an operation with multiple variadics expects the segment size // attribute func.func @testMultResultsMissingSegment() { - // expected-error@+1 {{'result_segment_sizes' attribute is expected but not provided}} + // expected-error@+1 {{'resultSegmentSizes' attribute is expected but not provided}} "testvar.var_and_opt_result"() : () -> (i16, i16, i64) return } @@ -344,8 +344,8 @@ func.func @testMultResultsMissingSegment() { // Check that an operation with multiple variadics expects the segment size // attribute of the right type func.func @testMultResultsWrongSegmentType() { - // expected-error@+1 {{'result_segment_sizes' attribute is expected to be a dense i32 array}} - "testvar.var_and_opt_result"() {result_segment_sizes = i32} : () -> (i16, i16, i64) + // expected-error@+1 {{'resultSegmentSizes' attribute is expected to be a dense i32 array}} + "testvar.var_and_opt_result"() {resultSegmentSizes = i32} : () -> (i16, i16, i64) return } @@ -354,12 +354,12 @@ func.func @testMultResultsWrongSegmentType() { // Check that an operation with multiple variadics with the right segment size // verifies. 
func.func @testMultResults() { - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) - // CHECK: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) - // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64) - // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) + // CHECK: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) + // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64) + // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64) return } @@ -367,8 +367,8 @@ func.func @testMultResults() { // Check that the segment sizes expects non-negative values func.func @testMultResultsSegmentNegative() { - // expected-error@+1 {{'result_segment_sizes' attribute for specifying result segments must have non-negative values}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, -1, 1>} : () -> () + // expected-error@+1 {{'resultSegmentSizes' attribute for specifying result segments must have non-negative values}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, -1, 1>} : () -> () return } @@ -376,8 +376,8 @@ func.func @testMultResultsSegmentNegative() { // Check that the segment sizes expects 1 for single values func.func @testMultResultsSegmentWrongSingle() { - // expected-error@+1 {{element 2 in 'result_segment_sizes' attribute must be equal to 1}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 0>} : () -> () + // expected-error@+1 {{element 2 in 'resultSegmentSizes' attribute must be equal to 1}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 0>} : () -> () return } @@ -385,8 +385,8 @@ func.func @testMultResultsSegmentWrongSingle() { // Check that the segment sizes expects not more than 1 for optional values func.func @testMultResultsSegmentWrongOptional() { - // expected-error@+1 {{element 1 in 'result_segment_sizes' attribute must be equal to 0 or 1}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 2, 0>} : () -> () + // expected-error@+1 {{element 1 in 'resultSegmentSizes' attribute must be equal to 0 or 1}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 2, 0>} : () -> () return } @@ -394,7 +394,7 @@ func.func @testMultResultsSegmentWrongOptional() { // Check that the sum of the segment sizes should be equal to the number of results func.func @testMultResultsSegmentWrongOptional() { - // expected-error@+1 {{sum of elements in 'result_segment_sizes' attribute must be equal to the number of results}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 1>} : () -> (i32, i64) + // expected-error@+1 {{sum of elements in 'resultSegmentSizes' attribute must be equal to the number 
of results}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 1>} : () -> (i32, i64) return } diff --git a/mlir/test/Dialect/Index/inliner-interface.mlir b/mlir/test/Dialect/Index/inliner-interface.mlir new file mode 100644 index 0000000..4c3d106 --- /dev/null +++ b/mlir/test/Dialect/Index/inliner-interface.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -inline | FileCheck %s + +// CHECK-LABEL: @main +func.func @main(%arg0: i32) -> index { + // CHECK-NOT: call + // CHECK: index.castu + %0 = call @f(%arg0) : (i32) -> index + return %0 : index +} + +// CHECK-LABEL: @f +func.func @f(%arg0: i32) -> index { + %0 = index.castu %arg0 : i32 to index + return %0 : index +} diff --git a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir index dfbf992..ffeb871 100644 --- a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir +++ b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir @@ -141,3 +141,22 @@ module { llvm.func @func_callsiteloc() loc(callsite("foo" at "mysource.cc":10:8)) } loc(unknown) +// ----- + +// CHECK-LABEL: llvm.func @func_cross_file_op() +// CHECK: #di_file = #llvm.di_file<"<unknown>" in ""> +// CHECK: #di_file1 = #llvm.di_file<"caller.py" in ""> +// CHECK: #di_file2 = #llvm.di_file<"callee.py" in ""> +// CHECK: #di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_normal> +// CHECK: #di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, compileUnit = #di_compile_unit, scope = #di_file1, name = "func_cross_file_op", linkageName = "func_cross_file_op", file = #di_file1, line = 5, scopeLine = 5, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type> +// CHECK: #di_lexical_block_file = #llvm.di_lexical_block_file<scope = #di_subprogram, file = #di_file2, discriminator = 0> + +#loc = loc("caller.py":5:1) +#loc1 = loc("callee.py":10:5) + +module { + llvm.func @func_cross_file_op() { + llvm.return loc(#loc1) + } loc(#loc) +} loc(unknown) + diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index cec4586..094313c 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -210,8 +210,8 @@ module { } // CHECK-LABEL: llvm.func @memory_attr - // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} { - llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} { + // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { + llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir index 8e292f4..9a77c5e 100644 --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -422,7 +422,7 @@ llvm.func @test_byval(%ptr : !llvm.ptr) { // ----- -llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} { +llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, 
targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -436,7 +436,7 @@ llvm.func @test_byval_read_only(%ptr : !llvm.ptr) { // ----- -llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite>} { +llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -451,7 +451,7 @@ llvm.func @test_byval_write_only(%ptr : !llvm.ptr) { // ----- -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -472,7 +472,7 @@ llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -496,7 +496,7 @@ module attributes { llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -524,7 +524,7 @@ module attributes { llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -550,7 +550,7 @@ llvm.func @test_alignment_exceeded_anyway() { llvm.mlir.global private @unaligned_global(42 : i64) : i64 llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64 -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = 
none, targetMem0 = none, targetMem1 = none>} { llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir new file mode 100644 index 0000000..bdc98ed --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate %s -mlir-to-llvmir | FileCheck %s +// CHECK: !llvm.module.flags = !{![[CG_FLAG:[0-9]+]], ![[DBG_FLAG:[0-9]+]]} +// CHECK: ![[CG_FLAG]] = !{i32 5, !"CG Profile", ![[CG_LIST:[0-9]+]]} +// CHECK: ![[CG_LIST]] = distinct !{![[CG_ENTRY:[0-9]+]], ![[CG_ENTRY]], ![[CG_ENTRY]]} +// CHECK: ![[CG_ENTRY]] = !{null, null, i64 222} +// CHECK: ![[DBG_FLAG]] = !{i32 2, !"Debug Info Version", i32 3} + +module { + llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [ + #llvm.cgprofile_entry<from = @from, to = @to, count = 222>, + #llvm.cgprofile_entry<from = @from, count = 222>, + #llvm.cgprofile_entry<from = @to, to = @from, count = 222> + ]>] +} diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index aaf9f80..49b6342 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() { // ----- func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32 } // ----- func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected return type to be a two-element struct}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)> } // ----- func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { - // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} + // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}} %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)> } diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir new file mode 100644 index 0000000..ff3e91b --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir @@ -0,0 +1,221 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with KIND variants. +// The kind::f8f6f4 variant was introduced in PTX ISA 8.7 for sm_90+ architectures. 
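+// The name f8f6f4 reflects the operand formats covered: 8-bit (e4m3, e5m2), 6-bit (e3m2, e2m3),
+// and 4-bit (e2m1) floats, which the sections below group loosely under "FP8".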
+// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// KIND::F8F6F4 enables: +// - Additional FP8 types: e3m2, e2m3, e2m1 +// - F16 accumulator for m16n8k64 FP8 operations +// - Mixed-precision FP8 computations +// +// Requirements: +// - ONLY works with ordered metadata (sp::ordered_metadata) +// - ONLY for shape m16n8k64 +// - ONLY for FP8 types (not integers or other floats) + +// ============================================================================= +// FP8 e4m3 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e5m2 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + 
sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e3m2 Sparse MMA with KIND (m16n8k64) +// NOTE: e3m2 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e3m2>, + multiplicandBPtxType = #nvvm.mma_type<e3m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e3m2>, + multiplicandBPtxType = #nvvm.mma_type<e3m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// 
============================================================================= +// FP8 e2m3 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m3 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m3>, + multiplicandBPtxType = #nvvm.mma_type<e2m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m3>, + multiplicandBPtxType = #nvvm.mma_type<e2m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e2m1 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m1 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m1>, + multiplicandBPtxType = #nvvm.mma_type<e2m1>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32 +func.func 
@nvvm_mma_sp_kind_m16n8k64_e2m1_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m1>, + multiplicandBPtxType = #nvvm.mma_type<e2m1>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir new file mode 100644 index 0000000..a4e2812 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir @@ -0,0 +1,411 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with ORDERED metadata. +// The ordered metadata variant was introduced in PTX ISA 8.5 for sm_90+ architectures. +// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// Ordered metadata provides an alternative metadata ordering for 2:4 structured sparsity +// that can offer better performance on newer architectures. + +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k16) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + 
return +} + +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k32) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// BF16 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : 
i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// TF32 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 8>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// Integer (s8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, 
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Integer (u8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : 
(i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (s4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, 
%a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (u4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// FP8 (e4m3) Sparse MMA Operations with Ordered Metadata +// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, 
i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 (e5m2) Sparse MMA Operations with Ordered Metadata +// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir 
b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir new file mode 100644 index 0000000..e7122aa --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir @@ -0,0 +1,390 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for all sparse MMA (mma.sp.sync) operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// Sparse MMA operations follow 2:4 structured sparsity where 2 out of every 4 elements +// in the A operand are non-zero. The A operand is provided in compressed form, +// and sparseMetadata provides the sparsity indices. +// +// NOTE: These tests use the default (standard) metadata ordering. +// For ordered metadata tests (PTX ISA 8.5+, sm_90+), see nvvm-mma-sp-ordered.mlir. + +// ============================================================================= +// F16 Sparse MMA Operations (m16n8k16) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16 +func.func @nvvm_mma_sp_m16n8k16_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32 +func.func @nvvm_mma_sp_m16n8k16_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// F16 Sparse MMA Operations (m16n8k32) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16 +func.func @nvvm_mma_sp_m16n8k32_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = 
#nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32 +func.func @nvvm_mma_sp_m16n8k32_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// BF16 Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32 +func.func @nvvm_mma_sp_m16n8k16_bf16_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32 +func.func @nvvm_mma_sp_m16n8k32_bf16_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// TF32 Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32 +func.func @nvvm_mma_sp_m16n8k8_tf32_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, 
{{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 8>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32 +func.func @nvvm_mma_sp_m16n8k16_tf32_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// Integer (s8) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32 +func.func @nvvm_mma_sp_m16n8k32_s8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite +func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = 
#nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32 +func.func @nvvm_mma_sp_m16n8k64_s8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Integer (u8) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32 +func.func @nvvm_mma_sp_m16n8k32_u8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32 +func.func @nvvm_mma_sp_m16n8k64_u8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer 
(s4) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32 +func.func @nvvm_mma_sp_m16n8k64_s4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32 +func.func @nvvm_mma_sp_m16n8k128_s4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (u4) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32 +func.func @nvvm_mma_sp_m16n8k64_u4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32 +func.func @nvvm_mma_sp_m16n8k128_u4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + 
%meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// FP8 (e4m3) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 (e5m2) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = 
nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir new file mode 100644 index 0000000..c2cfa76 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir @@ -0,0 +1,11 @@ +// RUN: not mlir-opt %s 2>&1 | FileCheck %s +// CHECK: 'nvvm.tcgen05.alloc' op is not supported on sm_90 + +module { + gpu.module @mod [#nvvm.target<chip = "sm_90">] { + func.func @tcgen05_alloc(%arg0: !llvm.ptr<7>, %arg1: i32) { + nvvm.tcgen05.alloc %arg0, %arg1 : !llvm.ptr<7>, i32 + return + } + } +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 2505e56..579f0ac 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -92,13 +92,6 @@ func.func @llvm_nvvm_cluster_wait() { llvm.return } -// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster -func.func @llvm_nvvm_fence_sc_cluster() { - // CHECK: nvvm.fence.sc.cluster - nvvm.fence.sc.cluster - llvm.return -} - // CHECK-LABEL: @nvvm_shfl func.func @nvvm_shfl( %arg0 : i32, %arg1 : i32, %arg2 : i32, @@ -445,8 +438,8 @@ llvm.func private @mbarrier_arrive(%barrier: !llvm.ptr) { } llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) { - // CHECK: nvvm.mbarrier.arrive.shared %{{.*}} : !llvm.ptr<3> - %0 = nvvm.mbarrier.arrive.shared %barrier : !llvm.ptr<3> -> i64 + // CHECK: nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<3> + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64 llvm.return } @@ -459,21 +452,8 @@ llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) { llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { %count = nvvm.read.ptx.sreg.ntid.x : i32 - // CHECK: nvvm.mbarrier.arrive.nocomplete.shared %{{.*}} : !llvm.ptr<3> - %0 = nvvm.mbarrier.arrive.nocomplete.shared %barrier, %count : !llvm.ptr<3>, i32 -> i64 - llvm.return -} - -llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { - // CHECK: nvvm.mbarrier.test.wait %{{.*}} - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1 - llvm.return %isComplete : i1 -} - -llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) { - %count = nvvm.read.ptx.sreg.ntid.x : i32 - // CHECK: nvvm.mbarrier.test.wait.shared %{{.*}} - %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : 
!llvm.ptr<3>, i64 -> i1 + // CHECK: nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr<3> + %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir new file mode 100644 index 0000000..506b81e --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// Test invalid target architecture (sm_100 instead of sm_100a) +gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] { + func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) { + // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + return + } +} + +// ----- + +// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2) +llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E3M4) + llvm.return %res : vector<4xi8> +} + +// ----- + +llvm.func @invalid_dst_type_f8x4_e8m0(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E8M0FNU) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test invalid destination types for f6x4 (should only accept f6E2M3FN, f6E3M2FN) +llvm.func @invalid_dst_type_f6x4_f8(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}} + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test invalid destination type for f4x4 (should only accept f4E2M1FN) +llvm.func @invalid_dst_type_f4x4_f6(%src : vector<4xf32>, %rbits : i32) -> i16 { + // expected-error@+1 {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}} + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f6E2M3FN) + llvm.return %res : i16 +} + +// ----- + +// Test invalid rounding modes for non-stochastic ops +llvm.func @convert_float_to_tf32_rs_not_supported(%src : f32) -> i32 { + // expected-error @below {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}} + %res = nvvm.convert.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rs>} + llvm.return %res : i32 +} + +// ----- + +llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) { + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) { + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU) + llvm.return +} 
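For contrast with the rejected cases above, a minimal positive sketch of the same conversion is shown next. This is only an assumed illustration based on the op syntax visible in the diagnostics above and on the hint that an 'a'-suffixed chip (sm_100a rather than sm_100) satisfies the target check; the module name @valid_arch_sm_100a and function name @convert_rs_valid are hypothetical.

// Illustrative sketch (assumed valid): a supported destination type on an sm_100a target.
gpu.module @valid_arch_sm_100a [#nvvm.target<chip = "sm_100a">] {
  func.func @convert_rs_valid(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
    // f8E4M3FN is one of the two destination types the verifier accepts for f32x4 -> f8x4;
    // %rbits supplies the stochastic-rounding bits.
    %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
    return %res : vector<4xi8>
  }
}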
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index e703600..40084bc 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -14,18 +14,24 @@ func.func @rocdl_special_regs() -> i32 { %4 = rocdl.workgroup.id.y : i32 // CHECK: rocdl.workgroup.id.z : i32 %5 = rocdl.workgroup.id.z : i32 + // CHECK: rocdl.cluster.id.x : i32 + %6 = rocdl.cluster.id.x : i32 + // CHECK: rocdl.cluster.id.y : i32 + %7 = rocdl.cluster.id.y : i32 + // CHECK: rocdl.cluster.id.z : i32 + %8 = rocdl.cluster.id.z : i32 // CHECK: rocdl.workgroup.dim.x : i32 - %6 = rocdl.workgroup.dim.x : i32 + %9 = rocdl.workgroup.dim.x : i32 // CHECK: rocdl.workgroup.dim.y : i32 - %7 = rocdl.workgroup.dim.y : i32 + %10 = rocdl.workgroup.dim.y : i32 // CHECK: rocdl.workgroup.dim.z : i32 - %8 = rocdl.workgroup.dim.z : i32 + %11 = rocdl.workgroup.dim.z : i32 // CHECK: rocdl.grid.dim.x : i32 - %9 = rocdl.grid.dim.x : i32 + %12 = rocdl.grid.dim.x : i32 // CHECK: rocdl.grid.dim.y : i32 - %10 = rocdl.grid.dim.y : i32 + %13 = rocdl.grid.dim.y : i32 // CHECK: rocdl.grid.dim.z : i32 - %11 = rocdl.grid.dim.z : i32 + %14 = rocdl.grid.dim.z : i32 llvm.return %0 : i32 } @@ -43,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4 llvm.return %0 : vector<4xf16> } +func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) { + // CHECK-LABEL: rocdl.math.ops + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16 + %tanh0 = rocdl.tanh %a f32 -> f32 + %tanh1 = rocdl.tanh %b f16 -> f16 + %tanh2 = rocdl.tanh %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16 + %sin0 = rocdl.sin %a f32 -> f32 + %sin1 = rocdl.sin %b f16 -> f16 + %sin2 = rocdl.sin %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16 + %cos0 = rocdl.cos %a f32 -> f32 + %cos1 = rocdl.cos %b f16 -> f16 + %cos2 = rocdl.cos %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16 + %rcp0 = rocdl.rcp %a f32 -> f32 + %rcp1 = rocdl.rcp %b f16 -> f16 + %rcp2 = rocdl.rcp %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16 + %exp2_0 = rocdl.exp2 %a f32 -> f32 + %exp2_1 = rocdl.exp2 %b f16 -> f16 + %exp2_2 = rocdl.exp2 %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16 + %log0 = rocdl.log %a f32 -> f32 + %log1 = rocdl.log %b f16 -> f16 + %log2 = rocdl.log %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16 + %sqrt0 = rocdl.sqrt %a f32 -> f32 + %sqrt1 = rocdl.sqrt %b f16 -> f16 + %sqrt2 = rocdl.sqrt %c bf16 -> bf16 + llvm.return +} + func.func @rocdl.barrier() { // CHECK: rocdl.barrier rocdl.barrier @@ -650,6 +709,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : 
!llvm.ptr<3>) { + // CHECK-LABEL: @rocdl.load.tr.ops + // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>) + // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16> + // CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : !llvm.ptr<3> -> vector<3xi32> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xi16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xf16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xbf16> + // CHECK: llvm.return + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { // CHECK-LABEL @rocdl.load.to.lds //CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7> @@ -664,6 +756,33 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK-LABEL @rocdl.global.load.async.to.lds + // CHECK: rocdl.global.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0 + // CHECK: rocdl.global.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0 + // CHECK: rocdl.global.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0 + // CHECK: rocdl.global.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0 + rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + +llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK-LABEL @rocdl.cluster.load.async.to.lds + // CHECK: rocdl.cluster.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0, 0 + rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b32 %src, %dst, 
0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + + // CHECK-LABEL @rocdl.tensor.load.to.lds llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { @@ -1037,6 +1156,13 @@ llvm.func @rocdl.s.get.barrier.state() { llvm.return } +llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.s.get.named.barrier.state + // CHECK: rocdl.s.get.named.barrier.state %[[PTR:.+]] + %0 = rocdl.s.get.named.barrier.state %ptr : i32 + llvm.return +} + llvm.func @rocdl.s.wait.dscnt() { // CHECK-LABEL: rocdl.s.wait.dscnt // CHECK: rocdl.s.wait.dscnt 0 @@ -1292,6 +1418,26 @@ llvm.func @rocdl.cvt.scalef32.sr.pk16(%v16xf32: vector<16xf32>, // ----- +// CHECK-LABEL: @rocdl_wmma_scale_ops +llvm.func @rocdl_wmma_scale_ops(%a_f8: vector<8xi32>, %a_f4: vector<4xi32>, %c_f32: vector<4xf32>, %c16_f32: vector<16xf32>, + %scale_i32: i32, %scale_i64: i64) { + // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + %r0 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i32, %scale_i32 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + %r1 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i64, %scale_i64 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + %r2 = rocdl.wmma.scale.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i32, %scale_i32 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + %r3 = rocdl.wmma.scale16.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i64, %scale_i64 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + + llvm.return +} + +// ----- + // expected-error@below {{attribute attached to unexpected op}} func.func private @expected_llvm_func() attributes { rocdl.kernel } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 00e763a..afbf47e 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -122,8 +122,8 @@ func.func @ops(%arg0: i32, %arg1: f32, // CHECK: llvm.call @baz() {will_return} : () -> () llvm.call @baz() {will_return} : () -> () -// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> () - llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> () +// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem 
= none, targetMem0 = none, targetMem1 = none>} : () -> () // Terminator operations and their successors. // diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir new file mode 100644 index 0000000..4b2d42a --- /dev/null +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -0,0 +1,214 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. +// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test. + +// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic + +// ----------------------------- +// Convolution ops. +// ----------------------------- +func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> { + %0 = linalg.conv_1d + ins(%in, %filter : tensor<?xf32>, tensor<?xf32>) + outs(%out : tensor<?xf32>) -> tensor<?xf32> + return %0 : tensor<?xf32> +} +// CHECK: @conv_1d +// CHECK: linalg.conv_1d + +// ----- + +func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_1d_nwc_wcf + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_1d_nwc_wcf +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_1d_ncw_fcw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_1d_ncw_fcw +// CHECK: linalg.conv_1d_ncw_fcw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = linalg.conv_2d + ins(%in, %filter : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} +// CHECK: @conv_2d +// CHECK: linalg.conv_2d + +// ----- + +func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_3d + ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_3d +// CHECK: linalg.conv_3d + +// ----- + +// ----------------------------- +// Depthwise Convolution ops. 
+// ----------------------------- +func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.depthwise_conv_1d_ncw_cw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @depthwise_conv_1d_ncw_cw +// CHECK: linalg.depthwise_conv_1d_ncw_cw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> { + %0 = linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<1x25x8xi8>, tensor<3x8xi8>) + outs (%output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> + return %0 : tensor<1x10x8xi32> +} +// CHECK: @depthwise_conv_1d_nwc_wc_static +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.depthwise_conv_1d_nwc_wcm + {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @depthwise_conv_1d_nwc_wcm +// CHECK: linalg.depthwise_conv_1d_nwc_wcm +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tensor<?x?x?xf16>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.depthwise_conv_2d_nchw_chw + {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf16>, tensor<?x?x?xf16>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm + {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) + outs (%output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> + return %0 : tensor<?x?x?x?x?x?xf32> +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> + +// ----- + +// ----------------------------- +// Pooling ops. 
+// ----------------------------- +func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_max + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_min + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_sum + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?xi8>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> { + %0 = linalg.pooling_nhwc_max_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xi8>, tensor<?x?xi8>) + outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> + return %0 : tensor<?x?x?x?xi32> +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>) + outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> + return %0 : tensor<?x?x?x?xi32> +} +// CHECK: @pooling_nhwc_min_unsigned_integer +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_min_unsigned_float +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, 
strides = dense<1> : tensor<2xi64> diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 2bf3d21..77c7d7d 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x // ----- +func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32> + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + return %padded_0 : tensor<8x3x4xi32> +} +// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> %cst = arith.constant 0 : i32 @@ -640,6 +658,63 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>, // CHECK: return %[[EXPAND]] // ----- + +func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, + %arg5: index, %arg6: index, %arg7: index, %arg8: index): + tensor.yield %cst : i32 + } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32> + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32> + return %collapsed : tensor<8x12x17x336x14xi32> +} +// CHECK: func @collapse_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>, + %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32> + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]] + : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32> + return %collapsed : tensor<?x?x?x?xf32> +} +// CHECK: func @collapse_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32> +// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape 
%[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0] +// CHECK: return %[[PAD]] + +// ----- + +func.func @collapse_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> { + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32> + return %collapsed : tensor<8x12xi32> +} +// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]] +// CHECK: return %[[COLLAPSED]] + +// ----- // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes. #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)> diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index bc55c12..6f1a422 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) { // ----- -// CHECK-LABEL: func @fold_fill_generic_different_dtype -// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> { -// CHECK-NOT: linalg.fill -// CHECK: %[[GENERIC_OP:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>) -// CHECK-SAME: outs({{.*}} : tensor<?xf16>) { -#map0 = affine_map<(d0) -> (d0)> -func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) { - %c0 = arith.constant 0 : index - %cst = arith.constant 7.0 : f32 - %0 = tensor.dim %arg0, %c0 : tensor<?xf16> - %1 = tensor.empty(%0) : tensor<?xf16> - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16> - %3 = tensor.empty(%0) : tensor<?xf16> - %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) { - ^bb0(%arg1: f16, %arg2: f16, %arg3: f16): - %5 = arith.addf %arg1, %arg2 : f16 - linalg.yield %5 : f16 - } -> tensor<?xf16> - return %4 : tensor<?xf16> -} - -// ----- - // CHECK-LABEL: func @fold_fill_generic_mixedaccess // CHECK-NOT: linalg.fill // CHECK: %[[GENERIC_OP:.*]] = linalg.generic @@ -1079,4 +1055,4 @@ module { // CHECK-NOT: linalg.generic // CHECK: tensor.expand_shape // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} -// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file +// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir index 290c6c7..4526dc9 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t // ----- -func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> { - %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32> +func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> { + %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32> return %0: tensor<f32> } @@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> { // ----- -func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) { - linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>) +func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) { + linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>) return } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index fabc8e6..1f554e6 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return // ----- +func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32> +{ + // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}} + %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32> +{ + // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}} + %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32> { // expected-error @+1 {{expected op with scalar input}} diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 67b4f2b..3fb7225 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -822,6 +822,23 @@ func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor< // ----- +func.func @no_fuse_by_expanding_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> { + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xi32> into tensor<2x12xi32> + %padded_0 = tensor.pad %collapse low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + return %padded_0 : tensor<8x12xi32> +} +// CHECK: func @no_fuse_by_expanding_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> 
tensor<8x12x17x339x14xi32> { %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32> %cst = arith.constant 0 : i32 @@ -863,6 +880,64 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i // ----- +func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index): + tensor.yield %cst : i32 + } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32> + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14] + : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32> + return %expanded : tensor<8x3x4x17x6x7x8x14xi32> +} +// CHECK: func @expand_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?xf32>, + %s0: index, %s1: index, %s2: index, %s3: index, %s4: index, %s5: index, + %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?x?x?xf32> { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32> + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] + : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32> + return %expanded : tensor<?x?x?x?x?x?xf32> +} +// CHECK: func @expand_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32> +// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0:.+]] : tensor<?x?x?x?xf32> +// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2:.+]] : tensor<?x?x?x?xf32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] output_shape [%[[DIM0]], %[[S1]], %[[S2]], %[[DIM2]], %[[S4]], %[[S5]]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0] +// CHECK: return %[[PAD]] + +// ----- + +func.func @expand_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %padded_0 = tensor.pad %arg0 low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + %expand = tensor.expand_shape %padded_0 [[0], [1, 2]] output_shape [8, 3, 4] : tensor<8x12xi32> into tensor<8x3x4xi32> + return %expand : tensor<8x3x4xi32> +} +// CHECK: func @expand_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] +// CHECK: return %[[EXPAND]] + +// ----- + func.func 
@move_operand_deps(%arg0 : tensor<?x128xf16>, %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> { %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index e521608..ab38f9f 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -253,6 +253,40 @@ module { // ----- +#map = affine_map<(d0) -> (d0 * 2)> +#map1 = affine_map<(d0) -> (d0 * 4)> +module { + // CHECK-LABEL: func.func @fuse_tileable_op_no_dps + func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { + %0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32> + %1 = tensor.empty() : tensor<4x4x4xf32> + // CHECK: scf.forall + %2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + // CHECK: "test.tiling_no_dps_op" + // CHECK: "test.unregistered_op" + %extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32> + %5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32> + } + } + return %2 : tensor<4x4x4xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } +} + +// ----- + module { // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested // CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32> diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b..d72ab08 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. 
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. transform.yield diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index aa2c1da..95959fc 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} { ///---------------------------------------------------------------------------------------- /// Tests for linalg.pack +/// +/// TODO: Add similar tests for linalg.unpack ///---------------------------------------------------------------------------------------- // Note, see a similar test in: @@ -1479,23 +1481,23 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @reduce_1d( -// CHECK-SAME: %[[A:.*]]: tensor<32xf32> -func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> { +// CHECK-LABEL: func @reduce_to_rank_0( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_0(%arg0: tensor<32xf32>) -> tensor<f32> { // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 - // CHECK: %[[init:.*]] = tensor.empty() : tensor<f32> + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<f32> %0 = tensor.empty() : tensor<f32> %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32> - // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> - // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0] + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[F0]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32> - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][] + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: %[[RES:.*]] = vector.transfer_write %[[RED_V1]], %[[INIT]][] // CHECK-SAME: : vector<f32>, tensor<f32> %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -1523,6 +1525,58 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @reduce_to_rank_1( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_1(%arg0: tensor<32xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + %f0 = arith.constant 0.000000e+00 : f32 + + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32> + %0 = tensor.empty() : tensor<1xf32> + + // CHECK: %[[INIT_ZERO:.*]] = vector.transfer_write %[[F0]], %[[INIT]][%[[C0]]] + // CHECK-SAME: : vector<1xf32>, tensor<1xf32> + %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32> + + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] + // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: 
%[[INIT_ZERO_VEC:.*]] = vector.transfer_read %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : tensor<1xf32>, vector<f32> + // CHECK: %[[INIT_ZERO_SCL:.*]] = vector.extract %[[INIT_ZERO_VEC]][] + // CHECK-SAME: : f32 from vector<f32> + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[INIT_ZERO_SCL]] [0] + // CHECK-SAME: : vector<32xf32> to f32 + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: vector.transfer_write %[[RED_V1]], %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : vector<f32>, tensor<1xf32> + + %2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (0)>], + iterator_types = ["reduction"]} + ins(%arg0 : tensor<32xf32>) + outs(%1 : tensor<1xf32>) { + ^bb0(%a: f32, %b: f32): + %3 = arith.addf %a, %b : f32 + linalg.yield %3 : f32 + } -> tensor<1xf32> + + return %2 : tensor<1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + + +// ----- + // This test checks that vectorization does not occur when an input indexing map // is not a projected permutation. In the future, this can be converted to a // positive test when support is added. diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 1304a90..170bae6 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -1335,7 +1335,7 @@ func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2x module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.pack"]} in %src : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [4, 1, 32, 16, 2] : !transform.any_op transform.yield } } @@ -1378,7 +1378,7 @@ func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [32, 4, 1, 16, 2] : !transform.any_op transform.yield } } @@ -1424,8 +1424,13 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @pack_with_dynamic_dims // CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32> -func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { - %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32> +func.func @pack_with_dynamic_dims( + %src: tensor<?x?xf32>, + %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { + %pack = linalg.pack %src + inner_dims_pos = 
[1, 0]
+ inner_tiles = [16, 2]
+ into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
 return %pack : tensor<?x?x16x2xf32>
}
@@ -1433,30 +1438,108 @@ func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2x
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+
+/// Compute mask for xfer_read
// CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1>
+
+/// --= read =---
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] {
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+
+/// --= shape_cast =---
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+
+/// --= transpose =---
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+
+/// Compute mask for xfer_write
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
// CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[C16]], %[[C2]] : vector<4x1x16x2xi1>
+
+/// --= write =---
// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] {
// CHECK-SAME: vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]]
// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+
// CHECK: return %[[WRITE]] : tensor<?x?x16x2xf32>
module attributes {transform.with_named_sequence} {
 transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
 %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op
 transform.yield
 }
}
+
+// -----
+
+/// Similar to the test above, but one of the inner tile sizes is dynamic. As a
+/// result, more output dims are dynamic (and, e.g., output mask calculation is a bit different).
+ +// CHECK-LABEL: func @pack_with_dynamic_dims_and_dynamic_inner_tile +// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x2xf32> +func.func @pack_with_dynamic_dims_and_dynamic_inner_tile( + %src: tensor<?x?xf32>, + %dest: tensor<?x?x?x2xf32>) -> tensor<?x?x?x2xf32> { + %c16 = arith.constant 16 : index + %pack = linalg.pack %src + inner_dims_pos = [1, 0] + inner_tiles = [%c16, 2] + into %dest : tensor<?x?xf32> -> tensor<?x?x?x2xf32> + return %pack : tensor<?x?x?x2xf32> +} + +// CHECK-DAG: %[[CST:.*]] = ub.poison : f32 +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index + +/// Compute mask for xfer_read +// CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32> +// CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32> +// CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1> + +/// --= read =--- +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { +// CHECK-SAME: vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]] +// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32> +// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32> + +/// --= shape_cast =--- +// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32> + +/// --= transpose =--- +// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32> + +/// Compute mask for xfer_write +// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32> +// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32> +// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32> +// CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[D4]], %[[C2_2]] : vector<4x1x16x2xi1> + +/// --= write =--- +// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] { +// CHECK-SAME: vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]] +// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x?x2xf32> + +// CHECK: return %[[WRITE]] : tensor<?x?x?x2xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op transform.yield } } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 3130902..e02717a 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -208,19 +208,6 @@ func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, stri // ----- -// CHECK-LABEL: func @dim_of_sized_view -// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref<?xi8> -// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index -// CHECK: return %[[SIZE]] : index -func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index { - %c0 = arith.constant 0 : index - %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8> - %1 = memref.dim %0, %c0 : memref<?xi8> - return %1 : 
index -} - -// ----- - // CHECK-LABEL: func @no_fold_subview_negative_size // CHECK: %[[SUBVIEW:.+]] = memref.subview // CHECK: return %[[SUBVIEW]] diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 18cdfb7..4ed8d4b 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2 // CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base // CHECK-NOT: memref.memory_space_cast + +// ----- + +func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index { + // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer + // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer + // CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>) + // CHECK: %[[C10:.*]] = arith.constant 10 : index + // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref<?xi8> to memref<f32> + // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index + // CHECK: return %[[PTR]] : index + + %c10 = arith.constant 10 : index + %0 = memref.view %arg0[%c10][] : memref<?xi8> to memref<f32> + %1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index + return %1 : index +} diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 1066526..ca91b01 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -992,6 +992,55 @@ func.func @fold_vector_maskedstore_expand_shape( // ----- +func.func @fold_vector_transfer_read_expand_shape( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[PAD:.*]] = ub.poison : f32 +// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8) +// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]} + +// ----- + +func.func @fold_vector_transfer_read_with_perm_map( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + +// ----- + +func.func @fold_vector_transfer_read_rank_mismatch( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1, 2]] 
output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + %1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + +// ----- + func.func @fold_vector_load_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32> diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 5ff2920..d10651f 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -992,6 +992,22 @@ func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) { // ----- +func.func @invalid_alloc_alignment() { + // expected-error @below {{'memref.alloc' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %0 = memref.alloc() {alignment = 3} : memref<4xf32> + return +} + +// ----- + +func.func @invalid_realloc_alignment(%src: memref<4xf32>) { + // expected-error @below {{'memref.realloc' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %0 = memref.realloc %src {alignment = 7} : memref<4xf32> to memref<8xf32> + return +} + +// ----- + func.func @test_alloc_memref_map_rank_mismatch() { ^bb0: // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}} diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir index d300699..dd68675 100644 --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -18,7 +18,7 @@ func.func @basic() -> i32 { // CHECK-LABEL: func.func @basic_default func.func @basic_default() -> i32 { // CHECK-NOT: = memref.alloca - // CHECK: %[[RES:.*]] = arith.constant 0 : i32 + // CHECK: %[[RES:.*]] = ub.poison : i32 // CHECK-NOT: = memref.alloca %0 = arith.constant 5 : i32 %1 = memref.alloca() : memref<i32> diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index 3b37c62..7fc84d4 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -306,6 +306,23 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @dead_alloc_escaped +func.func @dead_alloc_escaped() -> memref<8x64xf32, 3> { + // CHECK: %{{.+}} = memref.alloc + %0 = memref.alloc() : memref<8x64xf32, 3> + return %0 : memref<8x64xf32, 3> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @dead_alloc func.func @dead_alloc() { // CHECK-NOT: %{{.+}} = memref.alloc @@ -378,6 +395,73 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @dead_store_through_subview +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: 
vector.transfer_write +func.func @dead_store_through_subview(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32> + %subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]} + : vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @dead_store_through_expand +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +func.func @dead_store_through_expand(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32> + %expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32> + vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @dead_store_through_collapse +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +func.func @dead_store_through_collapse(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32> + %collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32> + vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + // CHECK-LABEL: func @lower_to_llvm // CHECK-NOT: memref.alloc // CHECK: llvm.call @malloc diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir new file mode 100644 index 0000000..fed0a4b --- /dev/null +++ b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s -acc-implicit-data=enable-implicit-reduction-copy=true -split-input-file | FileCheck %s --check-prefix=COPY +// RUN: mlir-opt %s -acc-implicit-data=enable-implicit-reduction-copy=false -split-input-file | FileCheck %s --check-prefix=FIRSTPRIVATE + +// Test case: scalar reduction variable in parallel loop +// When enable-implicit-reduction-copy=true: expect copyin/copyout for reduction variable +// When enable-implicit-reduction-copy=false: expect firstprivate for reduction variable + +acc.reduction.recipe @reduction_add_memref_i32 : 
memref<i32> reduction_operator <add> init { +^bb0(%arg0: memref<i32>): + %c0_i32 = arith.constant 0 : i32 + %alloc = memref.alloca() : memref<i32> + memref.store %c0_i32, %alloc[] : memref<i32> + acc.yield %alloc : memref<i32> +} combiner { +^bb0(%arg0: memref<i32>, %arg1: memref<i32>): + %0 = memref.load %arg0[] : memref<i32> + %1 = memref.load %arg1[] : memref<i32> + %2 = arith.addi %0, %1 : i32 + memref.store %2, %arg0[] : memref<i32> + acc.yield %arg0 : memref<i32> +} + +func.func @test_reduction_implicit_copy() { + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %c0_i32 = arith.constant 0 : i32 + %r = memref.alloca() : memref<i32> + memref.store %c0_i32, %r[] : memref<i32> + + acc.parallel { + %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32> {name = "r"} + acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { + %load = memref.load %red_var[] : memref<i32> + %add = arith.addi %load, %c1_i32 : i32 + memref.store %add, %red_var[] : memref<i32> + acc.yield + } attributes {independent = [#acc.device_type<none>]} + acc.yield + } + return +} + +// When enable-implicit-reduction-copy=true: expect copyin/copyout for reduction variable +// COPY-LABEL: func.func @test_reduction_implicit_copy +// COPY: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {dataClause = #acc<data_clause acc_reduction>, implicit = true, name = ""} +// COPY: acc.copyout accPtr(%[[COPYIN]] : memref<i32>) to varPtr({{.*}} : memref<i32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// When enable-implicit-reduction-copy=false: expect firstprivate for reduction variable +// FIRSTPRIVATE-LABEL: func.func @test_reduction_implicit_copy +// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} +// FIRSTPRIVATE-NOT: acc.copyin +// FIRSTPRIVATE-NOT: acc.copyout + +// ----- + +// Test case: reduction variable used both in loop and outside +// Should be firstprivate regardless of the flag setting + +acc.reduction.recipe @reduction_add_memref_i32_2 : memref<i32> reduction_operator <add> init { +^bb0(%arg0: memref<i32>): + %c0_i32 = arith.constant 0 : i32 + %alloc = memref.alloca() : memref<i32> + memref.store %c0_i32, %alloc[] : memref<i32> + acc.yield %alloc : memref<i32> +} combiner { +^bb0(%arg0: memref<i32>, %arg1: memref<i32>): + %0 = memref.load %arg0[] : memref<i32> + %1 = memref.load %arg1[] : memref<i32> + %2 = arith.addi %0, %1 : i32 + memref.store %2, %arg0[] : memref<i32> + acc.yield %arg0 : memref<i32> +} + +func.func @test_reduction_with_usage_outside_loop() { + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %c0_i32 = arith.constant 0 : i32 + %r = memref.alloca() : memref<i32> + %out = memref.alloca() : memref<i32> + memref.store %c0_i32, %r[] : memref<i32> + + %out_create = acc.create varPtr(%out : memref<i32>) -> memref<i32> {dataClause = #acc<data_clause acc_copyout>, name = "out"} + acc.parallel dataOperands(%out_create : memref<i32>) { + %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32_2) -> memref<i32> {name = "r"} + acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { + %load = memref.load %red_var[] : memref<i32> + %add = arith.addi %load, %c1_i32 : i32 + memref.store %add, %red_var[] : memref<i32> + acc.yield + } attributes 
{independent = [#acc.device_type<none>]} + // out = r (usage of r outside the loop) + %final_r = memref.load %r[] : memref<i32> + memref.store %final_r, %out_create[] : memref<i32> + acc.yield + } + acc.copyout accPtr(%out_create : memref<i32>) to varPtr(%out : memref<i32>) {dataClause = #acc<data_clause acc_copyout>, name = "out"} + return +} + +// In this case, r should be firstprivate regardless of the flag setting +// because it's used outside the reduction context +// COPY-LABEL: func.func @test_reduction_with_usage_outside_loop +// COPY: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} +// COPY-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = "" + +// FIRSTPRIVATE-LABEL: func.func @test_reduction_with_usage_outside_loop +// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} +// FIRSTPRIVATE-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = "" + diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir new file mode 100644 index 0000000..6909fe6 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir @@ -0,0 +1,224 @@ +// RUN: mlir-opt %s -acc-implicit-data -split-input-file | FileCheck %s + +// ----- + +// Test scalar in serial construct - should generate firstprivate +func.func @test_scalar_in_serial() { + %alloc = memref.alloca() : memref<i64> + acc.serial { + %load = memref.load %alloc[] : memref<i64> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test_scalar_in_serial +// CHECK: acc.firstprivate varPtr({{.*}} : memref<i64>) recipe({{.*}}) -> memref<i64> {implicit = true, name = ""} + +// ----- + +// Test scalar in parallel construct - should generate firstprivate +func.func @test_scalar_in_parallel() { + %alloc = memref.alloca() : memref<f32> + acc.parallel { + %load = memref.load %alloc[] : memref<f32> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test_scalar_in_parallel +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} + +// ----- + +// Test scalar in kernels construct - should generate copyin/copyout +func.func @test_scalar_in_kernels() { + %alloc = memref.alloca() : memref<f64> + acc.kernels { + %load = memref.load %alloc[] : memref<f64> + acc.terminator + } + return +} + +// CHECK-LABEL: func.func @test_scalar_in_kernels +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<f64>) -> memref<f64> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<f64>) to varPtr({{.*}} : memref<f64>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test scalar in parallel with default(none) - should NOT generate implicit data +func.func @test_scalar_parallel_defaultnone() { + %alloc = memref.alloca() : memref<f32> + acc.parallel { + %load = memref.load %alloc[] : memref<f32> + acc.yield + } attributes {defaultAttr = #acc<defaultvalue none>} + return +} + +// CHECK-LABEL: func.func @test_scalar_parallel_defaultnone +// CHECK-NOT: acc.firstprivate +// CHECK-NOT: acc.copyin + +// ----- + +// Test array in parallel - should generate copyin/copyout +func.func @test_array_in_parallel() { + %alloc = memref.alloca() : memref<10xf32> + acc.parallel { + %c0 = arith.constant 0 : index + %load = memref.load %alloc[%c0] : memref<10xf32> + acc.yield + } + return +} + 
+// CHECK-LABEL: func.func @test_array_in_parallel +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xf32>) to varPtr({{.*}} : memref<10xf32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test array in kernels - should generate copyin/copyout +func.func @test_array_in_kernels() { + %alloc = memref.alloca() : memref<20xi32> + acc.kernels { + %c0 = arith.constant 0 : index + %load = memref.load %alloc[%c0] : memref<20xi32> + acc.terminator + } + return +} + +// CHECK-LABEL: func.func @test_array_in_kernels +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<20xi32>) -> memref<20xi32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<20xi32>) to varPtr({{.*}} : memref<20xi32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test array with default(present) - should generate present +func.func @test_array_parallel_defaultpresent() { + %alloc = memref.alloca() : memref<10xf32> + acc.parallel { + %c0 = arith.constant 0 : index + %load = memref.load %alloc[%c0] : memref<10xf32> + acc.yield + } attributes {defaultAttr = #acc<defaultvalue present>} + return +} + +// CHECK-LABEL: func.func @test_array_parallel_defaultpresent +// CHECK: %[[PRESENT:.*]] = acc.present varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {acc.from_default, implicit = true, name = ""} +// CHECK: acc.delete accPtr(%[[PRESENT]] : memref<10xf32>) {dataClause = #acc<data_clause acc_present>, implicit = true, name = ""} + +// ----- + +// Test scalar with default(present) - should still generate firstprivate (scalars ignore default(present)) +func.func @test_scalar_parallel_defaultpresent() { + %alloc = memref.alloca() : memref<f32> + acc.parallel { + %load = memref.load %alloc[] : memref<f32> + acc.yield + } attributes {defaultAttr = #acc<defaultvalue present>} + return +} + +// CHECK-LABEL: func.func @test_scalar_parallel_defaultpresent +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} + +// ----- + +// Test multidimensional array +func.func @test_multidim_array_in_parallel() { + %alloc = memref.alloca() : memref<8x16xf32> + acc.parallel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %load = memref.load %alloc[%c0, %c1] : memref<8x16xf32> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test_multidim_array_in_parallel +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<8x16xf32>) -> memref<8x16xf32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<8x16xf32>) to varPtr({{.*}} : memref<8x16xf32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test dynamic size array +func.func @test_dynamic_array(%size: index) { + %alloc = memref.alloca(%size) : memref<?xf64> + acc.parallel { + %c0 = arith.constant 0 : index + %load = memref.load %alloc[%c0] : memref<?xf64> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test_dynamic_array +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<?xf64>) -> memref<?xf64> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<?xf64>) to varPtr({{.*}} : memref<?xf64>) {dataClause = 
#acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test variable with explicit data clause - implicit should recognize it +func.func @test_with_explicit_copyin() { + %alloc = memref.alloca() : memref<100xf32> + %copyin = acc.copyin varPtr(%alloc : memref<100xf32>) -> memref<100xf32> {name = "explicit"} + acc.parallel dataOperands(%copyin : memref<100xf32>) { + %c0 = arith.constant 0 : index + %load = memref.load %alloc[%c0] : memref<100xf32> + acc.yield + } + acc.copyout accPtr(%copyin : memref<100xf32>) to varPtr(%alloc : memref<100xf32>) {name = "explicit"} + return +} + +// CHECK-LABEL: func.func @test_with_explicit_copyin +// CHECK: acc.present varPtr({{.*}} : memref<100xf32>) -> memref<100xf32> {implicit = true, name = ""} + +// ----- + +// Test multiple variables +func.func @test_multiple_variables() { + %alloc1 = memref.alloca() : memref<f32> + %alloc2 = memref.alloca() : memref<10xi32> + acc.parallel { + %load1 = memref.load %alloc1[] : memref<f32> + %c0 = arith.constant 0 : index + %load2 = memref.load %alloc2[%c0] : memref<10xi32> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test_multiple_variables +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} +// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr({{.*}} : memref<10xi32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} + +// ----- + +// Test memref.view aliasing - view of explicitly copied buffer should generate present +func.func @test_memref_view(%size: index) { + %c0 = arith.constant 0 : index + %buffer = memref.alloca(%size) : memref<?xi8> + %copyin = acc.copyin varPtr(%buffer : memref<?xi8>) -> memref<?xi8> {name = "buffer"} + %view = memref.view %buffer[%c0][] : memref<?xi8> to memref<8x64xf32> + acc.kernels dataOperands(%copyin : memref<?xi8>) { + %c0_0 = arith.constant 0 : index + %c0_1 = arith.constant 0 : index + %load = memref.load %view[%c0_0, %c0_1] : memref<8x64xf32> + acc.terminator + } + acc.copyout accPtr(%copyin : memref<?xi8>) to varPtr(%buffer : memref<?xi8>) {name = "buffer"} + return +} + +// CHECK-LABEL: func.func @test_memref_view +// CHECK: acc.present varPtr({{.*}} : memref<8x64xf32>) -> memref<8x64xf32> {implicit = true, name = ""} + diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir new file mode 100644 index 0000000..74ff338 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir @@ -0,0 +1,175 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(acc-implicit-declare)" -split-input-file 2>&1 | FileCheck %s + +// ----- + +// Test that non-constant scalar globals in compute regions are hoisted +// instead of being marked with acc declare + +memref.global @gscalar : memref<f32> = dense<0.0> + +func.func @test_scalar_in_serial() { + acc.serial { + %addr = memref.get_global @gscalar : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// Expected to hoist this global access out of acc region instead of marking +// with `acc declare`. 
+// CHECK-LABEL: func.func @test_scalar_in_serial +// CHECK: memref.get_global @gscalar +// CHECK: acc.serial +// CHECK-NOT: acc.declare + +// ----- + +// Test that constant globals are marked with acc declare + +memref.global constant @gscalarconst : memref<f32> = dense<1.0> + +func.func @test_constant_in_serial() { + acc.serial { + %addr = memref.get_global @gscalarconst : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// This is expected to be `acc declare`'d since it is a constant. +// CHECK: memref.global constant @gscalarconst {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test globals referenced in acc routine functions + +memref.global @gscalar_routine : memref<f32> = dense<0.0> + +acc.routine @acc_routine_0 func(@test_scalar_in_accroutine) +func.func @test_scalar_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} { + %addr = memref.get_global @gscalar_routine : memref<f32> + %load = memref.load %addr[] : memref<f32> + return +} + +// Global should be acc declare'd because it's in an acc routine +// CHECK: memref.global @gscalar_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test constant globals in acc routine + +memref.global constant @gscalarconst_routine : memref<f32> = dense<1.0> + +acc.routine @acc_routine_0 func(@test_constant_in_accroutine) +func.func @test_constant_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} { + %addr = memref.get_global @gscalarconst_routine : memref<f32> + %load = memref.load %addr[] : memref<f32> + return +} + +// CHECK: memref.global constant @gscalarconst_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test acc.private.recipe with global reference - referenced variant + +memref.global @global_for_private : memref<f32> = dense<0.0> + +acc.private.recipe @private_recipe_with_global : memref<f32> init { +^bb0(%arg0: memref<f32>): + %0 = memref.alloc() : memref<f32> + %global_addr = memref.get_global @global_for_private : memref<f32> + %global_val = memref.load %global_addr[] : memref<f32> + memref.store %global_val, %0[] : memref<f32> + acc.yield %0 : memref<f32> +} destroy { +^bb0(%arg0: memref<f32>): + memref.dealloc %arg0 : memref<f32> + acc.terminator +} + +func.func @test_private_recipe_referenced() { + %var = memref.alloc() : memref<f32> + %priv = acc.private varPtr(%var : memref<f32>) recipe(@private_recipe_with_global) -> memref<f32> + acc.parallel private(%priv : memref<f32>) { + %load = memref.load %var[] : memref<f32> + acc.yield + } + memref.dealloc %var : memref<f32> + return +} + +// Global should be acc declare'd because the recipe is referenced +// CHECK: memref.global @global_for_private {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test acc.private.recipe with global reference - unreferenced variant + +memref.global @global_for_private_unused : memref<f32> = dense<0.0> + +acc.private.recipe @private_recipe_unused : memref<f32> init { +^bb0(%arg0: memref<f32>): + %0 = memref.alloc() : memref<f32> + %global_addr = memref.get_global @global_for_private_unused : memref<f32> + %global_val = memref.load %global_addr[] : memref<f32> + memref.store %global_val, %0[] : memref<f32> + acc.yield %0 : memref<f32> +} destroy { +^bb0(%arg0: memref<f32>): + memref.dealloc %arg0 : memref<f32> + acc.terminator +} + +func.func @test_private_recipe_not_referenced() { + %var = memref.alloc() : memref<f32> 
+ acc.parallel { + %load = memref.load %var[] : memref<f32> + acc.yield + } + memref.dealloc %var : memref<f32> + return +} + +// Global should NOT be acc declare'd because the recipe is not referenced +// CHECK-NOT: memref.global @global_for_private_unused {{.*}} {acc.declare + +// ----- + +// Test globals in different compute constructs (parallel, kernels, serial) + +memref.global @global_parallel : memref<f32> = dense<0.0> +memref.global @global_kernels : memref<f32> = dense<0.0> +memref.global constant @global_serial_const : memref<f32> = dense<1.0> + +func.func @test_multiple_constructs() { + acc.parallel { + %addr = memref.get_global @global_parallel : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + acc.kernels { + %addr = memref.get_global @global_kernels : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.terminator + } + acc.serial { + %addr = memref.get_global @global_serial_const : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// Non-constant globals ARE hoisted before their compute regions +// Constant global should be marked with acc.declare +// CHECK: memref.global constant @global_serial_const {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} +// CHECK-LABEL: func.func @test_multiple_constructs +// CHECK: memref.get_global @global_parallel +// CHECK-NEXT: acc.parallel +// CHECK: memref.get_global @global_kernels +// CHECK-NEXT: acc.kernels + diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir index fdc8e6b..38d3df3 100644 --- a/mlir/test/Dialect/OpenACC/canonicalize.mlir +++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir @@ -219,3 +219,30 @@ func.func @update_unnecessary_computations(%x: memref<i32>) { // CHECK-LABEL: func.func @update_unnecessary_computations // CHECK-NOT: acc.atomic.update // CHECK: acc.atomic.write + +// ----- + +func.func @kernel_environment_canonicalization(%q1: i32, %q2: i32, %q3: i32) { + // Empty kernel_environment (no wait) - should be removed + acc.kernel_environment { + } + + acc.kernel_environment wait({%q1 : i32, %q2 : i32}) { + } + + acc.kernel_environment wait { + } + + acc.kernel_environment wait({%q3 : i32} [#acc.device_type<nvidia>]) { + } + + return +} + +// CHECK-LABEL: func.func @kernel_environment_canonicalization +// CHECK-SAME: ([[Q1:%.*]]: i32, [[Q2:%.*]]: i32, [[Q3:%.*]]: i32) +// CHECK-NOT: acc.kernel_environment wait({{.*}}[#acc.device_type<none>]) +// CHECK: acc.wait([[Q1]], [[Q2]] : i32, i32) +// CHECK: acc.wait{{$}} +// CHECK: acc.kernel_environment wait({{.*}}[#acc.device_type<nvidia>]) +// CHECK: return diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 26b63fb..d1a1c93 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -76,27 +76,65 @@ acc.loop { // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in gang attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in gang attribute}} acc.loop { acc.yield } attributes {gang = [#acc.device_type<none>, #acc.device_type<none>]} // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in worker attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in worker attribute}} acc.loop { acc.yield } attributes {worker = [#acc.device_type<none>, #acc.device_type<none>]} // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in vector 
attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vector attribute}} acc.loop { acc.yield } attributes {vector = [#acc.device_type<none>, #acc.device_type<none>]} // ----- +// expected-error@+1 {{'acc.loop' op duplicate device_type `nvidia` found in gang attribute}} +acc.loop { + acc.yield +} attributes {gang = [#acc.device_type<nvidia>, #acc.device_type<nvidia>]} + +// ----- + +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in collapseDeviceType attribute}} +acc.loop { + acc.yield +} attributes {collapse = [1, 1], collapseDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +%i64value = arith.constant 1 : i64 +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in workerNumOperandsDeviceType attribute}} +acc.loop worker(%i64value: i64, %i64value: i64) { + acc.yield +} attributes {workerNumOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +%i64value = arith.constant 1 : i64 +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vectorOperandsDeviceType attribute}} +acc.loop vector(%i64value: i64, %i64value: i64) { + acc.yield +} attributes {vectorOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +func.func @acc_routine_parallelism() -> () { + return +} +// expected-error@+1 {{only one of `gang`, `worker`, `vector`, `seq` can be present at the same time for device_type `nvidia`}} +"acc.routine"() <{func_name = @acc_routine_parallelism, sym_name = "acc_routine_parallelism_rout", gang = [#acc.device_type<nvidia>], worker = [#acc.device_type<nvidia>]}> : () -> () + +// ----- + %1 = arith.constant 1 : i32 %2 = arith.constant 10 : i32 // expected-error@+1 {{only one of auto, independent, seq can be present at the same time}} @@ -483,12 +521,12 @@ acc.loop gang({static=%i64Value: i64, ) control(%iv : i32) = (%1 : i32) to (%2 : // ----- -func.func @fct1(%0 : !llvm.ptr) -> () { - // expected-error@+1 {{expected symbol reference @privatization_i32 to point to a private declaration}} - acc.serial private(@privatization_i32 -> %0 : !llvm.ptr) { - } - return -} +%i1 = arith.constant 1 : i32 +%i2 = arith.constant 10 : i32 +// expected-error@+1 {{unstructured acc.loop must not have induction variables}} +acc.loop control(%iv : i32) = (%i1 : i32) to (%i2 : i32) step (%i1 : i32) { + acc.yield +} attributes {independent = [#acc.device_type<none>], unstructured} // ----- @@ -834,6 +872,76 @@ func.func @acc_loop_container() { // ----- +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{expected symbol reference @privatization_i32 to point to a private declaration}} + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + return +} + +// ----- + +acc.private.recipe @privatization_i32 : !llvm.ptr init { +^bb0(%arg0: !llvm.ptr): + %c1 = arith.constant 1 : i32 + %c0 = arith.constant 0 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + llvm.store %c0, %0 : i32, !llvm.ptr + acc.yield %0 : !llvm.ptr +} + +func.func @fct1(%0 : !llvm.ptr) -> () { + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + // expected-error @below {{expected firstprivate as defining op}} + acc.serial firstprivate(%priv : !llvm.ptr) { + } + return +} + +// ----- + +acc.private.recipe @privatization_i32 : !llvm.ptr init 
{ +^bb0(%arg0: !llvm.ptr): + %c1 = arith.constant 1 : i32 + %c0 = arith.constant 0 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + llvm.store %c0, %0 : i32, !llvm.ptr + acc.yield %0 : !llvm.ptr +} + +func.func @fct1(%0 : !llvm.ptr) -> () { + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + // expected-error @below {{op private operand appears more than once}} + acc.serial private(%priv, %priv : !llvm.ptr, !llvm.ptr) { + } + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for private}} + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for firstprivate}} + %priv = acc.firstprivate varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for reduction}} + %priv = acc.reduction varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + func.func @verify_declare_enter(%arg0 : memref<i32>) { // expected-error @below {{expect valid declare data entry operation or acc.getdeviceptr as defining op}} %0 = acc.declare_enter dataOperands(%arg0 : memref<i32>) diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index 40604dc..c7ef47c 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -129,8 +129,8 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.parallel private(%p1 : memref<10xf32>) { acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield @@ -142,8 +142,8 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: acc.parallel private(%[[PRIVATE]] : memref<10xf32>) { // CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield @@ -167,9 +167,9 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> acc.parallel { - acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + acc.loop private(%p1 : memref<10xf32>) 
control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield } attributes {independent = [#acc.device_type<none>]} @@ -180,9 +180,9 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> // CHECK: acc.parallel { -// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// CHECK: acc.loop private(%[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield // CHECK: } attributes {independent = [#acc.device_type<none>]} @@ -205,8 +205,8 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.serial private(%p1 : memref<10xf32>) { acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield @@ -218,8 +218,8 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: acc.serial private(%[[PRIVATE]] : memref<10xf32>) { // CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir new file mode 100644 index 0000000..774c6b6 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %0 = memref.alloc() : memref<10x10xf32> + acc.yield %0 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} 
copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { +^bb0(%0: i64): + %1 = arith.constant 0 : i64 + acc.yield %1 : i64 +} combiner { +^bb0(%0: i64, %1: i64): + %2 = arith.addi %0, %1 : i64 + acc.yield %2 : i64 +} + +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init { +^bb0(%arg0: memref<i64>): + %0 = memref.alloca() : memref<i64> + %c0 = arith.constant 0 : i64 + memref.store %c0, %0[] : memref<i64> + acc.yield %0 : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.terminator +} + +// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) { +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) { +// CHECK: } +// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} +// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64> +// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) { +// CHECK: } +// 
CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) { +// CHECK: acc.yield +// CHECK: } attributes {seq = [#acc.device_type<none>]} +// CHECK: acc.terminator +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {selfAttr} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.yield +// CHECK: } attributes {selfAttr} +// CHECK: return +// CHECK: } + +func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %i64value = arith.constant 1 : i64 + %i32value = arith.constant 1 : i32 + %idxValue = arith.constant 1 : index + acc.serial async(%i64value: i64) { + } + acc.serial async(%i32value: i32) { + } + acc.serial async(%idxValue: index) { + } + acc.serial wait({%i64value: i64}) { + } + acc.serial wait({%i32value: i32}) { + } + acc.serial wait({%idxValue: index}) { + } + acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + } + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { + } + %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} + acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) { + } + %i64mem = memref.alloca() : memref<i64> + memref.store %i64value, %i64mem[] : memref<i64> + %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%i64reduction : memref<i64>) { + } + acc.serial combined(loop) { + acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) { + acc.yield + } attributes {seq = [#acc.device_type<none>]} + acc.terminator + } + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue none>} + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue present>} + acc.serial { + } attributes {asyncAttr} + acc.serial { + } attributes {waitAttr} + acc.serial { + } attributes {selfAttr} + acc.serial { + acc.yield + } attributes {selfAttr} + return +} + diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 042ee25..d31397c 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt -split-input-file %s | FileCheck %s +// RUN: 
mlir-opt -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-3 // Verify the printed output can be parsed. -// RUN: mlir-opt -split-input-file %s | mlir-opt -split-input-file | FileCheck %s +// RUN: mlir-opt -split-input-file %s | mlir-opt -split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-3 // Verify the generic form can be parsed. -// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt -split-input-file | FileCheck %s +// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt -split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-3 func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { %c0 = arith.constant 0 : index @@ -120,8 +120,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x %pc = acc.present varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> %pd = acc.present varPtr(%d : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { - %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { + %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(%private : memref<10xf32>) { acc.loop gang control(%x : index) = (%lb : index) to (%c10 : index) step (%st : index) { acc.loop worker control(%y : index) = (%lb : index) to (%c10 : index) step (%st : index) { %axy = memref.load %a[%x, %y] : memref<10x10xf32> @@ -157,8 +157,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64 // CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64 // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { -// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { +// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(%[[P_ARG2]] : memref<10xf32>) { // CHECK-NEXT: acc.loop gang control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK-NEXT: acc.loop worker control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> @@ -375,8 +375,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () { %c0 = arith.constant 0 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) 
control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { + %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.loop firstprivate(%firstprivate : memref<10xf32>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { "test.openacc_dummy_op"() : () -> () acc.yield } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} @@ -385,8 +385,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () { // CHECK-LABEL: func.func @testloopfirstprivate( // CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>, %[[ARG1:.*]]: memref<10xf32>) -// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { +// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: acc.loop firstprivate(%[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK: "test.openacc_dummy_op"() : () -> () // CHECK: acc.yield // CHECK: } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} @@ -464,7 +464,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x } acc.parallel vector_length(%idxValue: index) { } - acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@privatization_memref_10xf32 -> %b: memref<10xf32>) { + %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstprivate_b = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32> + acc.parallel private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate_b : memref<10xf32>) { } acc.parallel { } attributes {defaultAttr = #acc<defaultvalue none>} @@ -517,7 +520,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x // CHECK-NEXT: } // CHECK: acc.parallel vector_length([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel firstprivate(@privatization_memref_10xf32 -> [[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { +// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK-NEXT: %[[FIRSTPRIVATE_B:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32> +// CHECK-NEXT: acc.parallel firstprivate(%[[FIRSTPRIVATE_B]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) { // CHECK-NEXT: } // CHECK: acc.parallel { // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>} @@ -596,8 +602,10 @@ func.func @testserialop(%a: 
memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { } - %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) { + %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.serial private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { } acc.serial { } attributes {defaultAttr = #acc<defaultvalue none>} @@ -633,8 +641,10 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { +// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK-NEXT: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK-NEXT: acc.serial firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) { // CHECK-NEXT: } // CHECK: acc.serial { // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>} @@ -721,6 +731,59 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // ----- +// Test acc.kernels with private and firstprivate operands, similar to acc.serial. 
+ +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %1 = memref.alloc() : memref<10x10xf32> + acc.yield %1 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %2 = memref.alloca() : memref<10xf32> + acc.yield %2 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + memref.copy %arg0, %arg1 : memref<10xf32> to memref<10xf32> + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + acc.terminator +} + +func.func @testkernelspriv(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %priv_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %priv_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstp = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.kernels firstprivate(%firstp : memref<10xf32>) private(%priv_a, %priv_c : memref<10xf32>, memref<10x10xf32>) { + } + return +} + +// CHECK-LABEL: func.func @testkernelspriv( +// CHECK: %[[PRIV_A:.*]] = acc.private varPtr(%{{.*}} : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: %[[PRIV_C:.*]] = acc.private varPtr(%{{.*}} : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: acc.kernels firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIV_A]], %[[PRIV_C]] : memref<10xf32>, memref<10x10xf32>) { +// CHECK-NEXT: } + +// ----- + func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () { %ifCond = arith.constant true @@ -1511,32 +1574,43 @@ acc.private.recipe @privatization_struct_i32_i64 : !llvm.struct<(i32, i32)> init // ----- -acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { -^bb0(%arg0: i64): - %0 = arith.constant 0 : i64 - acc.yield %0 : i64 +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> } combiner { -^bb0(%arg0: i64, %arg1: i64): - %0 = arith.addi %arg0, %arg1 : i64 - acc.yield %0 : i64 +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> } -// CHECK-LABEL: acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator <add> init { -// CHECK: ^bb0(%{{.*}}: i64): +// CHECK-LABEL: acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +// CHECK: ^bb0(%{{.*}}: memref<i64>): // CHECK: %[[C0:.*]] = arith.constant 0 : i64 -// CHECK: acc.yield %[[C0]] : 
i64 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[C0]], %[[ALLOCA]][] : memref<i64> +// CHECK: acc.yield %[[ALLOCA]] : memref<i64> // CHECK: } combiner { -// CHECK: ^bb0(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64): -// CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64 -// CHECK: acc.yield %[[RES]] : i64 +// CHECK: ^bb0(%[[ARG0:.*]]: memref<i64>, %[[ARG1:.*]]: memref<i64>): +// CHECK: %[[LOAD0:.*]] = memref.load %[[ARG0]][] : memref<i64> +// CHECK: %[[LOAD1:.*]] = memref.load %[[ARG1]][] : memref<i64> +// CHECK: %[[RES:.*]] = arith.addi %[[LOAD0]], %[[LOAD1]] : i64 +// CHECK: memref.store %[[RES]], %[[ARG0]][] : memref<i64> +// CHECK: acc.yield %[[ARG0]] : memref<i64> // CHECK: } -func.func @acc_reduc_test(%a : i64) -> () { +func.func @acc_reduc_test(%a : memref<i64>) -> () { %c0 = arith.constant 0 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - acc.parallel reduction(@reduction_add_i64 -> %a : i64) { - acc.loop reduction(@reduction_add_i64 -> %a : i64) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.parallel reduction(%reduction_a : memref<i64>) { + acc.loop reduction(%reduction_a : memref<i64>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { acc.yield } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} acc.yield @@ -1545,31 +1619,68 @@ func.func @acc_reduc_test(%a : i64) -> () { } // CHECK-LABEL: func.func @acc_reduc_test( -// CHECK-SAME: %[[ARG0:.*]]: i64) -// CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64) -// CHECK: acc.loop reduction(@reduction_add_i64 -> %[[ARG0]] : i64) +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.parallel reduction(%[[REDUCTION_A]] : memref<i64>) +// CHECK: acc.loop reduction(%[[REDUCTION_A]] : memref<i64>) // ----- -acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { -^bb0(%0: i64): - %1 = arith.constant 0 : i64 - acc.yield %1 : i64 +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> } combiner { -^bb0(%0: i64, %1: i64): +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> %2 = arith.addi %0, %1 : i64 - acc.yield %2 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> } -func.func @acc_reduc_test(%a : i64) -> () { - acc.serial reduction(@reduction_add_i64 -> %a : i64) { +func.func @acc_reduc_test(%a : memref<i64>) -> () { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%reduction_a : memref<i64>) { } return } // CHECK-LABEL: func.func @acc_reduc_test( -// CHECK-SAME: %[[ARG0:.*]]: i64) -// CHECK: acc.serial reduction(@reduction_add_i64 -> %[[ARG0]] : i64) +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.serial reduction(%[[REDUCTION_A]] : memref<i64>) + +// ----- 
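// Editor's note: a minimal illustrative sketch (not part of the patch hunks
// above) of the decomposed recipe form these updated OpenACC tests exercise:
// the data-clause entry op now names its recipe via recipe(@...), and the
// compute construct only lists the resulting SSA value instead of a
// `@recipe -> %var` pair. All symbol and value names below are made up.
acc.private.recipe @sketch_priv_f32 : memref<f32> init {
^bb0(%arg0: memref<f32>):
  %0 = memref.alloca() : memref<f32>
  acc.yield %0 : memref<f32>
}

func.func @sketch(%a: memref<f32>) {
  %p = acc.private varPtr(%a : memref<f32>) recipe(@sketch_priv_f32) -> memref<f32>
  acc.parallel private(%p : memref<f32>) {
    acc.yield
  }
  return
}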
+ +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> +} + +func.func @acc_kernels_reduc_test(%a : memref<i64>) -> () { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.kernels reduction(%reduction_a : memref<i64>) { + } + return +} + +// CHECK-LABEL: func.func @acc_kernels_reduc_test( +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.kernels reduction(%[[REDUCTION_A]] : memref<i64>) // ----- @@ -1699,6 +1810,59 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang( // ----- +// Test acc.specialized_routine attribute for specialized device functions +acc.routine @routine_seq func(@device_func_seq) seq +acc.routine @routine_gang func(@device_func_gang) gang +acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64) +acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64) +acc.routine @routine_worker func(@device_func_worker) worker +acc.routine @routine_vector func(@device_func_vector) vector + +func.func @device_func_seq() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} { + return +} + +func.func @device_func_gang() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} { + return +} + +func.func @device_func_gang_dim2() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">} { + return +} + +func.func @device_func_gang_dim3() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} { + return +} + +func.func @device_func_worker() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} { + return +} + +func.func @device_func_vector() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} { + return +} + +// CHECK: acc.routine @routine_seq func(@device_func_seq) seq +// CHECK: acc.routine @routine_gang func(@device_func_gang) gang +// CHECK: acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64) +// CHECK: acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64) +// CHECK: acc.routine @routine_worker func(@device_func_worker) worker +// CHECK: acc.routine @routine_vector func(@device_func_vector) vector +// CHECK-LABEL: func.func @device_func_seq() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} +// CHECK-LABEL: func.func @device_func_gang() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} +// CHECK-LABEL: func.func @device_func_gang_dim2() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, 
<gang_dim2>, "host_func_gang_dim2">} +// CHECK-LABEL: func.func @device_func_gang_dim3() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} +// CHECK-LABEL: func.func @device_func_worker() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} +// CHECK-LABEL: func.func @device_func_vector() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} + +// ----- + func.func @acc_func() -> () { "test.openacc_dummy_op"() {acc.declare_action = #acc.declare_action<postAlloc = @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_alloc>} : () -> () return @@ -2143,6 +2307,20 @@ func.func @acc_loop_container() { // ----- +func.func @acc_unstructured_loop() { + acc.loop { + acc.yield + } attributes {independent = [#acc.device_type<none>], unstructured} + return +} + +// CHECK-LABEL: func.func @acc_unstructured_loop +// CHECK: acc.loop +// CHECK: acc.yield +// CHECK: } attributes {independent = [#acc.device_type<none>], unstructured} + +// ----- + // Test private recipe with data bounds for array slicing acc.private.recipe @privatization_memref_slice : memref<10x10xf32> init { ^bb0(%arg0: memref<10x10xf32>, %bounds0: !acc.data_bounds_ty, %bounds1: !acc.data_bounds_ty): diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir new file mode 100644 index 0000000..36df6a1 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=load}))" 2>&1 | FileCheck %s + +func.func @test_memref_load_scalar() { + %ptr = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Loaded value type: f32 + // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<f32> + return +} + +// ----- + +func.func @test_memref_load_int() { + %ptr = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Loaded value type: i64 + // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<i64> + return +} + +// ----- + +func.func @test_memref_load_dynamic() { + %c10 = arith.constant 10 : index + %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32> + // CHECK: Failed to generate load for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32> + return +} + diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir new file mode 100644 index 0000000..0fee431 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=store}))" 2>&1 | FileCheck %s + +func.func @test_memref_store_scalar() { + %ptr = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 4.200000e+01 : f32 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<f32> + return +} + +// ----- + 
+func.func @test_memref_store_int() { + %ptr = memref.alloca() {test.ptr} : memref<i32> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i32> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i32 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i32> + return +} + +// ----- + +func.func @test_memref_store_i64() { + %ptr = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i64 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i64> + return +} + +// ----- + +func.func @test_memref_store_dynamic() { + %c10 = arith.constant 10 : index + %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32> + // CHECK: Failed to generate store for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32> + return +} + diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir new file mode 100644 index 0000000..154d44e --- /dev/null +++ b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=private_from_firstprivate})" | FileCheck %s + +// Verify that we can create a private recipe using the convenience overload +// that takes an existing firstprivate recipe as input. For a simple scalar +// alloca-backed memref, only an init region is expected (no destroy). +// CHECK: acc.private.recipe @private_from_firstprivate_scalar : memref<f32> init { +// CHECK: ^bb0(%{{.*}}: memref<f32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> +// CHECK: acc.yield %[[ALLOC]] : memref<f32> +// CHECK: } + +func.func @test_scalar_from_firstprivate() { + %0 = memref.alloca() {test.var = "scalar"} : memref<f32> + return +} + +// ----- + +// Verify that destroy regions are also present when creating a private recipe +// from a firstprivate recipe that requires dynamic deallocation. 
+// CHECK: acc.private.recipe @private_from_firstprivate_dynamic_d2 : memref<?x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>): +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_d2">} : memref<?x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_dynamic_from_firstprivate(%arg0: index, %arg1: index) { + %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_d2"} : memref<?x?xf32> + return +} diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 084c3fc..ac590fc 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // ----- +// CHECK-LABEL: @while_move_if_down +func.func @while_move_if_down() -> i32 { + %defined_outside = "test.get_some_value0" () : () -> (i32) + %0 = scf.while () : () -> (i32) { + %used_value = "test.get_some_value1" () : () -> (i32) + %used_by_subregion = "test.get_some_value2" () : () -> (i32) + %else_value = "test.get_some_value3" () : () -> (i32) + %condition = "test.condition"() : () -> i1 + %res = scf.if %condition -> (i32) { + "test.use0" (%defined_outside) : (i32) -> () + "test.use1" (%used_value) : (i32) -> () + test.alloca_scope_region { + "test.use2" (%used_by_subregion) : (i32) -> () + } + %then_value = "test.get_some_value4" () : () -> (i32) + scf.yield %then_value : i32 + } else { + scf.yield %else_value : i32 + } + scf.condition(%condition) %res : i32 + } do { + ^bb0(%res_arg: i32): + "test.use3" (%res_arg) : (i32) -> () + scf.yield + } + return %0 : i32 +} +// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32 +// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) { +// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32): +// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> () +// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> () +// CHECK: test.alloca_scope_region { +// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> () +// CHECK: } +// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32 +// CHECK: "test.use3"(%[[then_value]]) : (i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[WHILE_RES]]#0 : i32 +// CHECK: } + +// ----- + // CHECK-LABEL: @while_cond_true func.func @while_cond_true() -> i1 { %0 = scf.while () : () -> i1 { diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 8e29ff6..b70bb40 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ 
b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -795,6 +795,53 @@ func.func @selection(%cond: i1) -> () { // ----- +func.func @selection_switch(%selector: i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function> + + // CHECK: spirv.mlir.selection { + spirv.mlir.selection { + // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1 + ] + // CHECK: ^bb1 + ^default: + spirv.Store "Function" %var, %one : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb2 + ^case0: + spirv.Store "Function" %var, %two : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb3 + ^case1: + spirv.Store "Function" %var, %three : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb4 + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + spirv.Return +} + +// ----- + // CHECK-LABEL: @empty_region func.func @empty_region() -> () { // CHECK: spirv.mlir.selection @@ -918,3 +965,171 @@ func.func @kill() { // CHECK: spirv.Kill spirv.Kill } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +func.func @switch(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3, + // CHECK-NEXT: 2: ^bb4 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +^default: + spirv.Branch ^merge + +^case0: + spirv.Branch ^merge + +^case1: + spirv.Branch ^merge + +^case2: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_only_default(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1 + spirv.Switch %selector : i32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_operands(%selector : i32, %operand : i32) { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1({{%.*}} : i32), + // CHECK-NEXT: 0: ^bb2({{%.*}} : i32), + // CHECK-NEXT: 1: ^bb3({{%.*}} : i32) + spirv.Switch %selector : i32, [ + default: ^default(%operand : i32), + 0: ^case0(%operand : i32), + 1: ^case1(%operand : i32) + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case0(%arg0 : i32): + spirv.Branch ^merge + +^case1(%arg1 : i32): + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: f32) -> () { + // expected-error@+1 {{expected builtin.integer, but found 'f32'}} + spirv.Switch %selector : f32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: i32) -> () { + // expected-error@+3 {{expected integer value}} + spirv.Switch %selector : i32, [ + default: ^default, + 0.0: ^case0 + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_default(%selector: i32) -> () { + // expected-error@+2 {{expected 'default'}} + spirv.Switch %selector : i32, [ + 0: ^case0 + ] +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func 
@switch_default_no_target(%selector: i32) -> () { + // expected-error@+2 {{expected block name}} + spirv.Switch %selector : i32, [ + default: + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_case_no_target(%selector: i32) -> () { + // expected-error@+3 {{expected block name}} + spirv.Switch %selector : i32, [ + default: ^default, + 0: + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_operand_type(%selector: i32) -> () { + %0 = spirv.Constant 0 : i32 + // expected-error@+2 {{expected ':'}} + spirv.Switch %selector : i32, [ + default: ^default (%0), + 0.0: ^case0 + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 5eb2360..be8ce20 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3 // CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref<8x?xf32> -// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref<?x?xf32> // CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]] -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<? 
// ----- +// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> + +// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static( +// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>, +// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>) +// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] +// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]] +// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]] +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32> +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] +// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] +// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] +// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]] +// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1] +// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]] +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: return %[[RET]] +// CHECK: } +func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> { + %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32> + return %0 : tensor<8x10xf32> +} + +// ----- + // CHECK-LABEL: func @tensor.splat_dynamic( // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index a05f423..6ef8b3e 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { // CHECK-LABEL: cast func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] - // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ] %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5a40f3f..84776c4 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -362,6 +362,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens // ----- +// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8} +func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %1 : 
tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} + +// ----- + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8} +func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> +} + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp +// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} +// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8} +func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> +} + + +// ----- + // CHECK-LABEL: @concat_fold func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> { // CHECK: return %arg0 @@ -643,6 +673,48 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: // ----- +// CHECK-LABEL: @select_broadcast_same_value_no_fold +func.func @select_broadcast_same_value_no_fold(%arg0: tensor<2x2xi1>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: tosa.select %arg0, %arg1, %arg1 + %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_true_value_no_fold +func.func @select_broadcast_true_value_no_fold(%arg0: tensor<1x1xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<1> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_no_fold +func.func @select_broadcast_false_value_no_fold(%arg0: tensor<2x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} 
+ +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_dynamic_operand_no_fold +func.func @select_broadcast_false_value_dynamic_operand_no_fold(%arg0: tensor<2x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: @reduce_all_fold func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> { // CHECK: return %arg0 diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 119991ca..3d24928 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> { @@ -306,6 +306,14 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens // ----- +func.func @test_concat_input_output_rank_mismatch(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2xf32> { + // expected-error@+1 {{'tosa.concat' op expect output rank to match inputs rank, got 1 vs 2}} + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) { %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> %1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32> @@ -2036,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens // ----- +// CHECK-LABEL: test_scatter_duplicate_indices_int64 +func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}} + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} + +// ----- + func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 68a9578..177192b 100644 --- 
a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -563,13 +563,6 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { } // ----- -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - // expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}} - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 22fde3b..652447bd 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -280,6 +280,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0 } // ----- +// CHECK-LABEL: clamp_quantized_unsigned +func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} + +// ----- // CHECK-LABEL: sigmoid func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -343,6 +350,13 @@ func.func @test_intdiv(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) - } // ----- +// CHECK-LABEL: intdiv_i64 +func.func @test_intdiv_i64(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + %0 = tosa.intdiv %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- // CHECK-LABEL: logical_and func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> { %0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1> @@ -750,10 +764,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> } // ----- -// CHECK-LABEL: scatter -func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> - return %0 : tensor<13x52x3xf32> +// CHECK-LABEL: gather_int64 +func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> { + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32> + return %0 : tensor<13x26x3xf32> } // ----- @@ -764,6 +778,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso } // ----- +// CHECK-LABEL: scatter +func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> 
+} + +// ----- +// CHECK-LABEL: scatter_int64 +func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> +} + +// ----- // CHECK-LABEL: scatter_unranked_indices func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -1277,6 +1305,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, } // ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e +func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e +func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e +func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e +func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> 
(tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- // CHECK-LABEL: test_cast_from_block_scaled_static func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> @@ -1307,7 +1371,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<* // ----- // CHECK-LABEL: test_cast_to_block_scaled_mxint8 func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { - %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index f0ad4eb..88dffe7 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -1,14 +1,22 @@ // RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s // ----- -// CHECK-LABEL: test_build_qtype -func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { +// CHECK-LABEL: test_build_qtype_unsigned +func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { // CHECK: tosa.negate - %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> } // ----- +// CHECK-LABEL: test_build_qtype_signed +func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> { + // CHECK: tosa.negate + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> + return %0 : 
tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> +} + +// ----- // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> { // CHECK: tosa.conv2d diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir new file mode 100644 index 0000000..fc2d77ef --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir @@ -0,0 +1,100 @@ +// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @rewrite_f32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: return %[[CST]] +func.func @rewrite_f32_tensor() -> tensor<2xf32> { + %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + return %c : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: return %[[CST]] +func.func @rewrite_i32_tensor() -> tensor<3xi32> { + %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32> + return %c : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i1_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1> +func.func @rewrite_i1_tensor() -> tensor<2xi1> { + %c = arith.constant dense<[true, false]> : tensor<2xi1> + return %c : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_rank0_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32> +func.func @rewrite_rank0_tensor() -> tensor<f32> { + %c = arith.constant dense<1.234500e+00> : tensor<f32> + return %c : tensor<f32> +} + +// ----- + +// CHECK-LABEL: func.func @preserve_scalar_i32 +// CHECK: %[[CST:.*]] = arith.constant 42 : i32 +func.func @preserve_scalar_i32() -> i32 { + %c = arith.constant 42 : i32 + return %c : i32 +} + +// ----- + +// CHECK-LABEL: func.func @preserve_index_tensor +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex> +func.func @preserve_index_tensor() -> tensor<2xindex> { + %c = arith.constant dense<[0, 1]> : tensor<2xindex> + return %c : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_resource_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32> +func.func @rewrite_resource_tensor() -> tensor<4xf32> { + %c = arith.constant dense_resource<"blob1"> : tensor<4xf32> + return %c : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8> +func.func @rewrite_quant_tensor() -> tensor<2xui8> { + %c = arith.constant dense<[10, 20]> : tensor<2xui8> + return %c : tensor<2xui8> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>> +func.func @rewrite_quant_uniform_tensor() -> 
tensor<2x!quant.uniform<i8:f32, 0.5:0>> { + %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>> + return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_fp8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN> +func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> { + %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN> + return %c : tensor<2xf8E4M3FN> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_mxint8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8> +func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> { + %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8> + return %c : tensor<2x!tosa.mxint8> +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index c7eeb52..d4c4595 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } + +// ----- +// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias( +// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>, +// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>, +// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> { +// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32> +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32> +// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32> +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32> +// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32> +// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32> +// CHECK: return %[[RES]] +func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) 
-> tensor<?x10x10x6xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32> + return %0 : tensor<?x10x10x6xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 810135f..61ca0ae 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32> "func.return" (%2) : (tensor<1x19x2x1xi32>) -> () } + + +// ----- +// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32> + return %0 : tensor<?x18x19x5xf32> +} + +// ----- +// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32> + return %0 : tensor<?x35x47x5xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir new file mode 100644 index 0000000..1a36177 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND + +// CHECK-LABEL: test_i64_argmax_large_axis_dim +func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { + // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_convert_input_parameters +// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64> +// FUNCBOUND: %[[IN:.*]]: 
tensor<1x513x513x3xi32> +func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> { + // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> + // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> + // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> + %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> + + // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> + %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> + return %1 : tensor<1x513x513x3xf32> +} + +// ----- + +// CHECK-LABEL: test_add +// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64> +// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32> +func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32> + // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32> + // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64> + // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64> + // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32> + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_regions +// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64> +func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> { + // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32> + // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32> + // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if + %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) { + // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64> + // COMMON: tosa.yield %[[ADD]] : tensor<i32> + tosa.yield %1 : tensor<i64> + } else { + // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64> + // COMMON: tosa.yield %[[SUB]] : tensor<i32> + tosa.yield %1 : tensor<i64> + } + // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64> + // DEFAULT: return %[[OUT]] : tensor<i64> + // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32> + return %0 : tensor<i64> +} + +// ----- + +// CHECK-LABEL: test_const +func.func @test_const() -> tensor<2xi64> { + // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> + %0 = "tosa.const"() <{values = dense<[1, 
2]> : tensor<2xi64>}> : () -> tensor<2xi64> + // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64> + // DEFAULT: return %[[OUT]] : tensor<2xi64> + // FUNCBOUND: return %[[CONST]] : tensor<2xi32> + return %0 : tensor<2xi64> +} diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir new file mode 100644 index 0000000..a14483f --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir @@ -0,0 +1,162 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND + +// ----- + +// CHECK-LABEL: test_i64_argmax +func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> + + // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64> + // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_i64_argmax_cast +func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> + // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32> + %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32> + return %1 : tensor<1x513x513xf32> +} + +// ----- + +// CHECK-LABEL: test_i64_argmax_large_axis_dim +func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { + // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}} + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_add +func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // expected-error @+1 {{failed to legalize operation 'tosa.add'}} + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_regions +func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor<i1>) -> tensor<1xi32> { + // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32> + %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64> + // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> + // COMMON: tosa.yield %[[CAST]] : tensor<1xi32> + tosa.yield %2 : tensor<1xi32> + } else { + tosa.yield %arg1 : tensor<1xi32> + } + // COMMON: return %[[IF_RESULT]] : tensor<1xi32> + return %0 : tensor<1xi32> +} + +// ----- + 
+// CHECK-LABEL: test_concat +func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> { + // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32> + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64> + return %0 : tensor<26x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_pad +func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> { + %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> + // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32> + %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64> + return %1 : tensor<15x23x5xi64> +} + +// ----- + +// CHECK-LABEL: test_reshape +func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> { + %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2> + // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32> + %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64> + return %0 : tensor<1x819xi64> +} + +// ----- + +// CHECK-LABEL: test_reverse +func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_slice +func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> { + %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32> + %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64> + return %2 : tensor<4x11x1xi64> +} + +// ----- + +// CHECK-LABEL: test_tile +func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> { + %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32> + %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64> + return %0 : tensor<39x21x6xi64> +} + +// ----- + +// CHECK-LABEL: transpose +func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> { + // COMMON: tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32> + %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64> + return %1 : tensor<3x13x21xi64> +} + +// ----- + +// CHECK-LABEL: test_transition_to_i64 +func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> { + // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32> + %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64> + // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> + %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: 
%[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64> + // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64> + // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64> + // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32> + return %2 : tensor<1xi64> +} + +// ----- + +// CHECK-LABEL: test_transition_from_i64 +func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> { + // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32> + // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> + // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32> + %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> + %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> + // COMMON: return %[[OUT_CAST]] : tensor<1xi32> + return %2 : tensor<1xi32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index 9bd7aa8..f6b1edc 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -2,6 +2,7 @@ // ----- +// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> @@ -53,14 +54,6 @@ func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> { // ----- -// CHECK-LABEL: test_cast_f4e2m1 -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- - // CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32 func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> @@ -109,14 +102,6 @@ func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> { // ----- -// CHECK-LABEL: test_cast_f4e2m1 -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- - // CHECK-LABEL: test_matmul_t_block_scaled_mxint8 func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> @@ -130,3 +115,28 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> 
(tensor< %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> } + +// ----- + +// CHECK-LABEL: test_argmax_fp8_i64 +func.func @test_argmax_fp8_i64(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> { + %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> + return %0 : tensor<12x16xi64> +} + +// ----- + +// CHECK-LABEL: test_argmax_bf16_i64 +func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64> { + %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64> + return %0 : tensor<12x16xi64> +} + +// ----- + +// CHECK-LABEL: test_scatter_const_indices_int64 +func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 6cf76cd..ea64d46 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4 %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU> } + +// ----- + +func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}} + %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} diff --git a/mlir/test/Dialect/Transform/include-failure-propagation.mlir b/mlir/test/Dialect/Transform/include-failure-propagation.mlir new file mode 100644 index 0000000..94e9d8f --- /dev/null +++ b/mlir/test/Dialect/Transform/include-failure-propagation.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --verify-diagnostics + +module attributes { transform.with_named_sequence } { + // Callee returns a silenceable failure when given a module instead of func.func. + transform.named_sequence @callee(%root: !transform.any_op {transform.consumed}) -> (!transform.any_op) { + transform.test_consume_operand_of_op_kind_or_fail %root, "func.func" : !transform.any_op + transform.yield %root : !transform.any_op + } + + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %res = transform.sequence %root : !transform.any_op -> !transform.any_op failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // This include returns a silenceable failure; it must not remap results. 
+ %included = transform.include @callee failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op) + transform.yield %included : !transform.any_op + } + + %count = transform.num_associations %res : (!transform.any_op) -> !transform.param<i64> + // expected-remark @below {{0}} + transform.debug.emit_param_as_remark %count : !transform.param<i64> + + // If the include incorrectly forwarded mappings on failure, this would run + // and produce an unexpected remark under --verify-diagnostics. + transform.foreach %res : !transform.any_op { + ^bb0(%it: !transform.any_op): + transform.debug.emit_remark_at %it, "include result unexpectedly populated" : !transform.any_op + } + transform.yield + } +} + +// ----- + +module { + func.func @payload() { + return + } +} diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index ce8f69c..4806daf7 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -386,7 +386,7 @@ module attributes {transform.with_named_sequence} { // ----- module attributes {transform.with_named_sequence} { - // expected-error @below {{trying to schedule a pass on an unsupported operation}} + // expected-error @below {{trying to schedule pass 'DuplicateFunctionEliminationPass' on an unsupported operation}} // expected-note @below {{target op}} func.func @invalid_target_op_type() { return diff --git a/mlir/test/Dialect/UB/ops.mlir b/mlir/test/Dialect/UB/ops.mlir index 724b6b4..730c1bd 100644 --- a/mlir/test/Dialect/UB/ops.mlir +++ b/mlir/test/Dialect/UB/ops.mlir @@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> { %0 = ub.poison : tensor<8x?xf64> return %0 : tensor<8x?xf64> } + +// CHECK-LABEL: func @unreachable() +// CHECK: ub.unreachable +func.func @unreachable() { + ub.unreachable +} diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir index 887fb94..70adefd 100644 --- a/mlir/test/Dialect/Vector/bufferize.mlir +++ b/mlir/test/Dialect/Vector/bufferize.mlir @@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index, // ----- +// CHECK-LABEL: func @scatter( +// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>, +// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32> +// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32> +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32> +// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32> +// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32> +// CHECK: return %[[tensor]] : tensor<16x16xf32> +func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + %c0 = arith.constant 0 : index + %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} + +// ----- + // CHECK-LABEL: func @gather( // CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>, // CHECK-SAME: %[[mask:.*]]: vector<16xi1>, 
%[[pass_thru:.*]]: vector<16xf32>) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5f035e3..79b09e1 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}} + // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}} vector.scatter %base[%c0][%indices], %mask, %pass_thru - : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index da9a1a8..de62022 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1160,3 +1160,17 @@ func.func @step() { %1 = vector.step : vector<[4]xindex> return } + +// CHECK-LABEL: func @scatter_tensor( +// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>, +// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32> +func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]] + %0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + // CHECK: return %[[RESULT]] : tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir index 1d8f440..27a3653 100644 --- a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s +// RUN: mlir-opt %s -split-input-file --test-vector-scan-lowering | FileCheck %s // CHECK-LABEL: func @scan1d_inc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<i32> @@ -18,6 +18,20 @@ func.func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector<i32> } +// ----- + +// Reducing scalable dims is not yet supported! + +// CHECK-LABEL: func @scan1d_inc_scalable +// CHECK: vector.scan +func.func @scan1d_inc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) { + %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<[2]xi32>, vector<i32> + return %0#0, %0#1 : vector<[2]xi32>, vector<i32> +} + +// ----- + // CHECK-LABEL: func @scan1d_exc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<i32> @@ -36,6 +50,20 @@ func.func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector<i32> } +// ----- + +// Reducing scalable dims is not yet supported! 
+ +// CHECK-LABEL: func @scan1d_exc_scalable +// CHECK: vector.scan +func.func @scan1d_exc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) { + %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} : + vector<[2]xi32>, vector<i32> + return %0#0, %0#1 : vector<[2]xi32>, vector<i32> +} + +// ----- + // CHECK-LABEL: func @scan2d_mul_dim0 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<3xi32> @@ -53,6 +81,27 @@ func.func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim0_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<2x[3]xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[3]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x[3]xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x[3]xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x[3]xi32> to vector<[3]xi32> +// CHECK: return %[[F]], %[[G]] : vector<2x[3]xi32>, vector<[3]xi32> +func.func @scan2d_mul_dim0_scalable(%arg0 : vector<2x[3]xi32>, %arg1 : vector<[3]xi32>) -> (vector<2x[3]xi32>, vector<[3]xi32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<2x[3]xi32>, vector<[3]xi32> + return %0#0, %0#1 : vector<2x[3]xi32>, vector<[3]xi32> +} + +// ----- + // CHECK-LABEL: func @scan2d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<2xi32> @@ -73,6 +122,30 @@ func.func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<[2]x3xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[2]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<[2]x3xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<[2]x1xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<[2]x1xi32> +// CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<[2]x1xi32> into 
vector<[2]x3xi32> +// CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<[2]x1xi32> to vector<[2]xi32> +// CHECK: return %[[I]], %[[J]] : vector<[2]x3xi32>, vector<[2]xi32> +func.func @scan2d_mul_dim1_scalable(%arg0 : vector<[2]x3xi32>, %arg1 : vector<[2]xi32>) -> (vector<[2]x3xi32>, vector<[2]xi32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} : + vector<[2]x3xi32>, vector<[2]xi32> + return %0#0, %0#1 : vector<[2]x3xi32>, vector<[2]xi32> +} + +// ----- + // CHECK-LABEL: func @scan3d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>, // CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32> @@ -89,3 +162,22 @@ func.func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) - vector<4x2x3xf32>, vector<4x3xf32> return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32> } + +// ----- + +// CHECK-LABEL: func @scan3d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x[3]xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<4x[3]xf32> +// CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x[3]xf32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x[3]xf32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x[3]xf32> to vector<4x[3]xf32> +// CHECK: return %[[F]], %[[G]] : vector<4x2x[3]xf32>, vector<4x[3]xf32> +func.func @scan3d_mul_dim1_scalable(%arg0 : vector<4x2x[3]xf32>, %arg1 : vector<4x[3]xf32>) -> (vector<4x2x[3]xf32>, vector<4x[3]xf32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} : + vector<4x2x[3]xf32>, vector<4x[3]xf32> + return %0#0, %0#1 : vector<4x2x[3]xf32>, vector<4x[3]xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index 577b06d..69fba88 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -382,6 +382,21 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> { return %r : vector<2x[4]xi32> } +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_cast_non_vector_result +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>> +// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>> +/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not +/// attempt to reorder a cast operation that produces a non-vector result type. 
+func.func @negative_broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> { + %0 = vector.broadcast %arg0 : i64 to vector<26x7xi64> + %1 = builtin.unrealized_conversion_cast %0 : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>> + return %1 : !llvm.array<26 x vector<7xi64>> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderElementwiseOpsOnTranspose] //===----------------------------------------------------------------------===// @@ -780,7 +795,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> } //----------------------------------------------------------------------------- -// [Pattern: StoreOpFromSplatOrBroadcast] +// [Pattern: StoreOpFromBroadcast] //----------------------------------------------------------------------------- // CHECK-LABEL: @store_splat diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5..805e66f 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,137 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + +func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { + %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1> +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index +// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index +// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1> +// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index +// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index +// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index +// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1> +// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index +// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index +// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index +// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1> +// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : 
index +// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index +// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index +// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index +// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index +// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1> +// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[INS11]] : vector<16x16xi1> + +func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { + %cst16 = arith.constant 16 : index + %0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1> +// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1> +// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[S3]] : vector<16x16xi1> + + +func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { + %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> + return %0 : vector<2x2x4xf32> +} + +// CHECK-LABEL: func @shape_cast_1D +// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: return %[[I1]] : vector<2x2x4xf32> + + +func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], 
strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: return %[[I1]] : vector<4x4xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because the targetShape (2x4) is not contiguous in the result vector +func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> { + %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous +// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32> +// CHECK: return %[[SC]] : vector<8x8xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because it cannot determine the extractShape from the source vector (8x3) +// to extract a contiguous targetShape (2x4) +func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> { + %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32> + return %0 : vector<6x4xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable +// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32> +// CHECK: return %[[SC]] : vector<6x4xf32> + + +// TargetShape is [1x16] +func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> { + %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32> + return %0 : vector<1x32xf32> +} + +// CHECK-LABEL: func @shape_cast_leading_unit_dim +// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: return %[[I1]] : vector<1x32xf32> diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir new file mode 100644 index 0000000..e506b16 --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir @@ -0,0 +1,344 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> 
+func.func @matmul_outer_product_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma +// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32> +// CHECK: vector.fma{{.*}}vector<64xf32> +// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<64x1xf32> +!vecB = vector<1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<1x64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + 
iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<3x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_non_unit_batch_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch dimension should've been simplified earlier. 
+ +// CHECK-LABEL: @negative_non_unit_batch_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_non_unit_batch_reduce_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch-reduce dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_reduce_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @negative_invalid_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<mul>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_kind +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xi32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_accumulator_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_accumulator_type +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + 
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir new file mode 100644 index 0000000..65676cb --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir @@ -0,0 +1,681 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x16x1x2xbf16> +!vecB = vector<1x1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// 
CHECK-LABEL: @brgemm_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x8x1x4xi8> +!vecB = vector<1x1x1x4xi8> +!vecC = vector<1x8x1xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module 
attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<16x1x2xbf16> +!vecB = vector<1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x4xi8> +!vecB = vector<1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = 
transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_invalid_vc_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<mul>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_vc_kind +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xi8> +!vecB = vector<1x1x8x2xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { 
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1x2xbf16> +!vecB = vector<3x1x16x2xbf16> +!vecC = vector<3x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_batch_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_batch_dimension +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<2x1x1x4xi8> +!vecB = vector<2x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_dimension +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_float_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_float_acc_type +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : 
!transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi8> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_int_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_int_acc_type +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vnni_blocking_factor_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xbf16> +!vecB = vector<1x1x32xbf16> +!vecC = vector<1x32xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_not_vnni( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_not_vnni +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = 
vector<1x1x16x4xi8> +!vecC = vector<1x1x16xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x32x2xbf16> +!vecC = vector<1x1x32xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index ebbe3ce..67faa60 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) { %offsets = arith.constant dense<[0]> : vector<1xindex> %mask = arith.constant dense<1>: vector<1xi1> // expected-error@+1 {{Mask should match value except the chunk size dim}} - xegpu.store %val, %src[%offsets], %mask + xegpu.store %val, %src[%offsets], %mask : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1> return } @@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() { // ----- func.func @create_mem_desc_non_slm() { %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1> - // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}} + // expected-error@+1 {{operand #0 must be reside in share memory and statically 1d shaped memref }} %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16> return } @@ -871,14 +871,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { } // ----- -func.func @load_mem_desc_invalid_attr2(%arg0: 
!xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> - return -} - - -// ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16> @@ -900,16 +892,25 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>, %arg1: vector<2x16xf32>) { + // expected-error@+1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : + vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>> return } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>, %arg1: vector<16x2xf32>) { + // expected-error@+1 {{With subgroup_block_io, the distributed dimensions must be contiguous}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>> return } +// ----- +func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>, %arg1: vector<16x2xf32>) { + // expected-error@+1 {{With subgroup_block_io, the block shape must match the lane layout}} + xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>> + return +} diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 0a10f68..1e9738f 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -278,6 +278,15 @@ gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : in gpu.return } +// CHECK: func @subgroup_load_nd_offset_2(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index) { +gpu.func @subgroup_load_nd_offset_2(%src: memref<24x32xf32>, %x : index) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32> + %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = 
#xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32> + %2 = xegpu.load_nd %1[%x, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32> + gpu.return +} + // CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) { // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32> @@ -825,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } + +// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_from_2d_memref() { + //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3> + %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_with_stride_from_2d_memref() { + //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3> + //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3> + %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + gpu.return +} + // CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir new file mode 100644 index 0000000..24a0de6 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir @@ -0,0 +1,280 @@ +// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \ +// RUN: --xegpu-optimize-block-loads --canonicalize --cse --split-input-file %s | FileCheck %s + +// CHECK-LABEL: gpu.func @no_scf( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> { +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index +// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, 
#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> +gpu.module @xevm_module { +gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + gpu.return %6 : vector<8x16xf32> +} +} + +// ----- +// CHECK-LABEL: gpu.func @no_scf_i8( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> { +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index +// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> +// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]> +#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +gpu.module @xevm_module { +gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b> + %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8> + %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8> + %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> + gpu.return %6 : vector<8x16xi32> +} +} + + +// ----- +// CHECK-LABEL: gpu.func @gemm_b_transpose( +// CHECK-SAME: %{{.*}} memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: 
%[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%c128, 1] +// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { +// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> +gpu.module @xevm_module { +gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) { + %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16> + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %8 : vector<8x16xf32> + } {layout_result_0 = #a} + xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @nested_scf( +// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: scf.for %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { +// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] {layout_result_0 = #xegpu.layout< +// 
CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>} : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> +gpu.module @xevm_module { +gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + scf.for %arg8 = %c0 to %c256 step %c16 { + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) { + %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16> + %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %8 : vector<8x16xf32> + } {layout_result_0 = #a} + xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + } + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @large_loads( +// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32> +// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]} +// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> +// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index +// CHECK: 
%[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32> +// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]} +// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> +// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> +gpu.module @xevm_module { +gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b> + %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1) + -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16> + %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = 
#a} + xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @array_length( +// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 -> +// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> +// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32> +// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32> +// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> +#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> +#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> +gpu.module @xevm_module { +gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> + -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> + %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1) + -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } + : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16> + %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16> + %20 = vector.extract 
%6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16> + %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a} + xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +} diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir index 58461b8..32fb317 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir @@ -1,18 +1,45 @@ // RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @load_store_no_array_len( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<8x32xf32>) { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK: %[[TDESC_SRC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: %[[TDESC_DST:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: xegpu.prefetch_nd %[[TDESC_SRC]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<inst_data = [8, 16]>}> : +// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: %[[LOADED:.*]] = xegpu.load_nd %0 <{layout = #xegpu.layout<inst_data = [8, 
16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : +// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x32xf32> +// CHECK: xegpu.store_nd %[[LOADED]], %[[TDESC_DST]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +gpu.module @test { +// Although the uArch allows 8x32 inst data using block count (or array_len), +// it is up to optimization passes to decide on the block count usage. +func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf32>) { + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32> + %1 = xegpu.create_nd_tdesc %arg1 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32> + xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x32xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32> + xegpu.store_nd %2, %1 : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32> + return +} +} + +// ----- + // CHECK-LABEL: func.func @dpas_f16( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { -// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32> -// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : -// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> -// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} : -// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32> +// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]> +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>> +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : +// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x16xf16> +// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : +// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, 
#xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf16> +// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<inst_data = [8, 16]>, layout_b = #xegpu.layout<inst_data = [16, 16]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>, layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> -// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]> +// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>> gpu.module @test { func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { @@ -46,18 +73,18 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> (!xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : - //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x32xf16> + //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : + //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> - //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x32xf16> + //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x32xf16> %c = arith.addf %a, %b : vector<16x32xf16> - //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>> + //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16> - //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> @@ -85,18 +112,18 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> 
(!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : - //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<12x32xf16> + //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [4, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : + //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> -> vector<12x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> - //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<12x32xf16> + //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : vector<12x32xf16> %c = arith.addf %a, %b : vector<12x32xf16> - //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>> + //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> xegpu.store_nd %c, %arg2: vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16> - //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> @@ -113,9 +140,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> -// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> -// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8], lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> -// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> +// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> +// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir index 543e119..48e77d8 100644 --- 
a/mlir/test/Dialect/XeGPU/propagate-layout.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout -split-input-file %s | FileCheck %s +// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s gpu.module @test { // CHECK-LABEL: func.func @dpas_f16( @@ -6,14 +6,14 @@ gpu.module @test { // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32> // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> -// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> // CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> @@ -32,7 +32,8 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me gpu.module @test { // CHECK-LABEL: func.func @dpas_i8( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) { -// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], 
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) { %c0 = arith.constant 0 : index %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> @@ -46,8 +47,8 @@ func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memre gpu.module @test { // CHECK-LABEL: func.func @load_with_transpose_effect( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf32>) { -// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : -// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16> +// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> @@ -108,7 +109,7 @@ gpu.module @test { // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16> func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index @@ -135,7 +136,7 @@ gpu.module @test { // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : +// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], 
lane_data = [1]>>, vector<16xi1> -> vector<16xf32> func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) { %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> @@ -183,9 +184,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> -// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> -// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> @@ -204,7 +205,7 @@ gpu.module @test { // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> -// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> @@ -215,10 +216,50 @@ func.func @scatter_ops(%src: memref<256xf16>) { } // ----- gpu.module @test { +// CHECK-LABEL: func.func @scatter_ops_custom_perm_layout( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> +// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16> +// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] +// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8],
lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) { + %1 = arith.constant dense<1>: vector<16xi1> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + %4 = arith.addf %3, %3 : vector<16xf16> + xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + return +} +} +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @scatter_ops_preserve_load_perm_layout( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> +// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16> +// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] +// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +func.func @scatter_ops_preserve_load_perm_layout(%src: memref<256xf16>) { + %1 = arith.constant dense<1>: vector<16xi1> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + %4 = arith.addf %3, %3 : vector<16xf16> + xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + return +} +} +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16( -// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<8x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xi16> -// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} +// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xi16> // CHECK: %{{.*}} = vector.bitcast %[[LOAD0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: vector<8x16xi16> to vector<8x16xf16> @@ -241,7 +282,7 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1 // ----- gpu.module @test { // CHECK-LABEL: func.func @vector_bitcast_i32_to_f16( -// CHECK: %[[LOAD:.*]] = 
xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> // CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} // CHECK-SAME: vector<16x8xi32> to vector<16x16xf16> @@ -262,7 +303,7 @@ func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8 // ----- gpu.module @test { // CHECK-LABEL: func.func @vector_bitcast_i16_to_i32( -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} // CHECK-SAME: !xegpu.tensor_desc<8x32xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xi16> // CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: vector<8x32xi16> to vector<8x16xi32> @@ -299,9 +340,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) { @@ -322,9 +363,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> -// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_result_0 = 
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> -// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) { %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> @@ -345,11 +386,11 @@ gpu.module @test { // CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32> // CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) -> // CHECK-SAME: (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> -// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: vector<8x16xf16>, 
vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> // CHECK-NEXT: %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NEXT: %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> @@ -357,7 +398,7 @@ gpu.module @test { // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32> // CHECK-NEXT: } {layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index @@ -385,11 +426,11 @@ gpu.module @test { // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) { -// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16> // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} @@ -415,11 +456,11 @@ gpu.module @test { // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) { -// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = 
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16> // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} @@ -499,7 +540,7 @@ gpu.module @test { // CHECK-LABEL: func.func @prefetch_2d( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) { // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @prefetch_2d(%arg0: memref<256x256xf16>){ %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16> @@ -512,7 +553,7 @@ gpu.module @test { // CHECK-LABEL: func.func @prefetch_1d( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> +// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> func.func @prefetch_1d(%arg0: memref<256xf16>){ %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16> @@ -559,7 +600,7 @@ gpu.module @test { // CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, 
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> @@ -581,7 +622,7 @@ gpu.module @test { // CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] @@ -599,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc return } } +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> +// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} +// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16> +func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.0000> : vector<16xf16> + %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16> + %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16> + xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +} + +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column( +// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add> +// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16> +// CHECK-NEXT: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] +// CHECK-SAME: 
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16> +// CHECK-NEXT: vector.broadcast %[[SHAPECAST]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16> + +func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.0000> : vector<16xf16> + %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16> + %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16> + %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16> + xegpu.store_nd %6, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +} + +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector( +// CHECK: %[[CST:.*]] = arith.constant 0.{{.*}} : f16 +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[CST]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + +func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) { + %cst = arith.constant 0.0000 : f16 + %6 = vector.broadcast %cst : f16 to vector<16x16xf16> + xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +}
\ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir index f233dff..216f3d1 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \ -// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s - +// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute \ +// RUN: -allow-unregistered-dialect -canonicalize -cse %s | FileCheck %s +gpu.module @xevm_module{ // CHECK-LABEL: gpu.func @store_nd_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] @@ -11,20 +11,17 @@ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32, // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2] : vector<1xf32>, !xegpu.tensor_desc<16xf32> -gpu.module @xevm_module{ - gpu.func @store_nd_1d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %cst = "some_op"() : () -> vector<16xf32> - xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - } - gpu.return +gpu.func @store_nd_1d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %cst = "some_op"() : () -> vector<16xf32> + xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> } + gpu.return } -// ----- // CHECK-LABEL: gpu.func @store_nd_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] @@ -37,22 +34,18 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16> -gpu.module @xevm_module{ - gpu.func @store_nd_2d(%laneid : index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %cst = "some_op"() : () -> vector<16x16xf16> - xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - gpu.return +gpu.func @store_nd_2d(%laneid : index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %cst = 
"some_op"() : () -> vector<16x16xf16> + xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + gpu.return } - -// ----- // CHECK-LABEL: gpu.func @load_nd_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>, @@ -63,21 +56,19 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32, // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.load_nd %[[T1]][%[[W]]#2] : !xegpu.tensor_desc<16xf32> -> vector<1xf32> -gpu.module @xevm_module{ - gpu.func @load_nd_1d(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : - !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32> - gpu.yield %1 : vector<16xf32> - } - "some_user_op"(%r) : (vector<1xf32>) -> () - gpu.return +gpu.func @load_nd_1d(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : + !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32> + gpu.yield %1 : vector<16xf32> } + "some_user_op"(%r) : (vector<1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @load_nd_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16, @@ -89,21 +80,19 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> // CHECK: vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16> -gpu.module @xevm_module{ - gpu.func @load_nd_2d(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> - gpu.yield %1 : vector<16x16xf16> - } - "some_user_op"(%r) : (vector<16x1xf16>) -> () - gpu.return +gpu.func @load_nd_2d(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : !xegpu.tensor_desc<16x16xf16, 
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> + gpu.yield %1 : vector<16x16xf16> } + "some_user_op"(%r) : (vector<16x1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @load_nd_array_length // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>, @@ -118,23 +107,21 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16> // CHECK-NEXT: vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16> -gpu.module @xevm_module{ - gpu.func @load_nd_array_length(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, - #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, - #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16> - gpu.yield %1 : vector<2x16x16xf16> - } - "some_user_op"(%r) : (vector<2x16x1xf16>) -> () - gpu.return +gpu.func @load_nd_array_length(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, + #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, + #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16> + gpu.yield %1 : vector<2x16x16xf16> } + "some_user_op"(%r) : (vector<2x16x1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @dpas // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> @@ -146,29 +133,27 @@ gpu.module @xevm_module{ // CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32> // CHECK-NEXT: %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> // CHECK-NEXT: vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32> -gpu.module @xevm_module{ - gpu.func @dpas(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { - %0 = "some_op"() : () -> vector<8x16xf16> - %1 = "some_op"() : () -> vector<16x16xf16> - %2 = "some_op"() : () -> vector<8x16xf32> - %3 = xegpu.dpas %0, %1, %2 - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - gpu.yield %3 : vector<8x16xf32> - } - "some_user_op"(%r) : (vector<8x1xf32>) -> () - gpu.return +gpu.func @dpas(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_op"() : () -> vector<8x16xf16> + 
%1 = "some_op"() : () -> vector<16x16xf16> + %2 = "some_op"() : () -> vector<8x16xf32> + %3 = xegpu.dpas %0, %1, %2 + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + gpu.yield %3 : vector<8x16xf32> } + "some_user_op"(%r) : (vector<8x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16, @@ -178,21 +163,19 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16> // CHECK-NEXT: builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch} -gpu.module @xevm_module{ - gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { - %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 -> - !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - "some_user_op"(%r) - : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> () - gpu.return +gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { + %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 -> + !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + "some_user_op"(%r) + : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @prefetch_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16, @@ -204,21 +187,19 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2] // CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -gpu.module @xevm_module{ - gpu.func @prefetch_2d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () - -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - xegpu.prefetch_nd %0[%c0, %c0] - <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : 
!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - gpu.return +gpu.func @prefetch_2d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + xegpu.prefetch_nd %0[%c0, %c0] + <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @prefetch_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16, @@ -229,44 +210,40 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>, // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -gpu.module @xevm_module{ - gpu.func @prefetch_1d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () - -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - xegpu.prefetch_nd %0[%c0] - <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - } - gpu.return +gpu.func @prefetch_1d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () + -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + xegpu.prefetch_nd %0[%c0] + <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) { // CHECK: gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) { // CHECK: gpu.yield %{{.*}} // CHECK: } // CHECK: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16xf16> -> vector<1xf16> // CHECK: gpu.barrier -gpu.module @xevm_module{ - gpu.func @gpu_barrier(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %1 = xegpu.load_nd %0[%c0] - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16> - gpu.barrier - gpu.yield %1 : vector<16xf16> - } - "some_user_op"(%r) : (vector<1xf16>) -> () - gpu.return +gpu.func @gpu_barrier(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %1 = xegpu.load_nd %0[%c0] + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16> + gpu.barrier + gpu.yield %1 : vector<16xf16> } + "some_user_op"(%r) : (vector<1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction // CHECK: %[[ACC:.*]] = arith.constant 
{{.*}} dense<0.000000e+00> : vector<32xf32> // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] @@ -285,7 +262,6 @@ gpu.module @xevm_module{ // CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32> // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32 // CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32> -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction // CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { // CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32> @@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) // CHECK-NEXT: %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32> // CHECK-NEXT: gpu.yield %[[T6]] : vector<2xf32> // CHECK-NEXT: } -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction // CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32> // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) { @@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) // CHECK: %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32> // CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32 // CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32> -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction // CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { // CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32> @@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) // CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32> // CHECK: gpu.yield %[[T7]] : vector<2xf32> // CHECK: } -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) { // CHECK: 
%[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex> // CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1> @@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) // CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> // CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}> // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> -gpu.module @xevm_module{ - gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) { - gpu.warp_execute_on_lane_0(%laneid)[16] { - %1 = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<1>: vector<16xi1> - %offset = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> - { - layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> - } - : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> - xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> - { - layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } - : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> - } - gpu.return +gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) { + gpu.warp_execute_on_lane_0(%laneid)[16] { + %1 = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<1>: vector<16xi1> + %offset = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> + { + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> + } + : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> + { + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { // CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex> // CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1> @@ -475,156 +442,144 @@ gpu.module @xevm_module{ // CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> // CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 // CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> -gpu.module @xevm_module{ - gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) { - gpu.warp_execute_on_lane_0(%laneid)[16] { - %1 = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<1> : vector<16xi1> - %offset = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], 
lane_data = [1]>} - dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 - { - layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> - xegpu.store %3, %src[%offset], %1 - { - layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } - : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) { + gpu.warp_execute_on_lane_0(%laneid)[16] { + %1 = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<1> : vector<16xi1> + %offset = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 + { + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset], %1 + { + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } - gpu.return + : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) { // CHECK: gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16> // CHECK-NEXT: } // CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index // CHECK-NEXT: arith.index_cast %[[INTPTR]] : index to i64 -gpu.module @xevm_module{ - gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) { - %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index - gpu.yield %ptr : index - } - %ptr_i64 = arith.index_cast %r : index to i64 - "some_user_op"(%ptr_i64) : (i64) -> () - gpu.return +gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) { + %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index + gpu.yield %ptr : index } + %ptr_i64 = arith.index_cast %r : index to i64 + "some_user_op"(%ptr_i64) : (i64) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_transpose( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) { // CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32> // CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32> -gpu.module @xevm_module{ - gpu.func @vector_transpose(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) { - %cst = "some_op"() - 
{layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} - : () -> (vector<16x2xf32>) - %transpose = vector.transpose %cst, [1, 0] - { - layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16x2xf32> to vector<2x16xf32> - gpu.yield %transpose : vector<2x16xf32> - } - "some_user_op"(%r) : (vector<2x1xf32>) -> () - gpu.return +gpu.func @vector_transpose(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} + : () -> (vector<16x2xf32>) + %transpose = vector.transpose %cst, [1, 0] + { + layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x2xf32> to vector<2x16xf32> + gpu.yield %transpose : vector<2x16xf32> } + "some_user_op"(%r) : (vector<2x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_bitcast( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) { // CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8> // CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8> // CHECK: } // CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16> -gpu.module @xevm_module{ - gpu.func @vector_bitcast(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} - : () -> (vector<4x32xi8>) - %bitcast = vector.bitcast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<4x32xi8> to vector<4x16xi16> - gpu.yield %bitcast : vector<4x16xi16> - } - "some_user_op"(%r) : (vector<4x1xi16>) -> () - gpu.return +gpu.func @vector_bitcast(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} + : () -> (vector<4x32xi8>) + %bitcast = vector.bitcast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<4x32xi8> to vector<4x16xi16> + gpu.yield %bitcast : vector<4x16xi16> } + "some_user_op"(%r) : (vector<4x1xi16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing // CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) { // CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32> // CHECK: } // CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32> -gpu.module @xevm_module { - gpu.func @vector_shapecast_rank_increasing(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} - : () -> (vector<16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16xf32> to vector<1x16xf32> - gpu.yield 
%cast : vector<1x16xf32> - } - "some_user_op"(%r) : (vector<1x1xf32>) -> () - gpu.return +gpu.func @vector_shapecast_rank_increasing(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} + : () -> (vector<16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16xf32> to vector<1x16xf32> + gpu.yield %cast : vector<1x16xf32> } + "some_user_op"(%r) : (vector<1x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing( // CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) { // CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32> // CHECK: } // CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32> -gpu.module @xevm_module { - gpu.func @vector_shapecast_rank_reducing(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : () -> (vector<1x16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]> - } - : vector<1x16xf32> to vector<16xf32> - gpu.yield %cast : vector<16xf32> - } - "some_user_op"(%r) : (vector<1xf32>) -> () - gpu.return +gpu.func @vector_shapecast_rank_reducing(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : () -> (vector<1x16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]> + } + : vector<1x16xf32> to vector<16xf32> + gpu.yield %cast : vector<16xf32> } + "some_user_op"(%r) : (vector<1xf32>) -> () + gpu.return } -// ----- + // NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand. 
// // CHECK-LABEL: gpu.func @vector_shapecast_unsupported @@ -634,21 +589,400 @@ gpu.module @xevm_module { // CHECK: } // CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> () // CHECK: gpu.return -gpu.module @xevm_module { - gpu.func @vector_shapecast_unsupported(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } - : () -> (vector<16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16xf32> to vector<1x16xf32> - gpu.yield %cast : vector<1x16xf32> +gpu.func @vector_shapecast_unsupported(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } + : () -> (vector<16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16xf32> to vector<1x16xf32> + gpu.yield %cast : vector<1x16xf32> + } + "some_user_op"(%r) : (vector<1x1xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted +// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x16xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<24x16xf32> to vector<8x16xf32> + gpu.yield %1 : vector<8x16xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed +// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x1xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x1xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : 
vector<24x1xf32> to vector<8x1xf32> + gpu.yield %1 : vector<8x1xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x64xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<24x64xf32> to vector<8x16xf32> + gpu.yield %1 : vector<8x16xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32> +// CHECK: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32> +// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32> +// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> () +gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) { + %0 = "some_def"() : () -> (vector<32x16xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]> + } + : vector<32x16xf32> to vector<16x16xf32> + gpu.yield %1 : vector<16x16xf32> + } + "some_use"(%r) : (vector<1x16xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) { +// CHECK: %[[S:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<2xf32>) -> () +gpu.func @vector_extract_strided_slice_1d(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<64xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + +// 
CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_offset +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.extract_strided_slice +gpu.func @vector_extract_strided_slice_unsupported_offset(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<64xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_source +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.extract_strided_slice +gpu.func @vector_extract_strided_slice_unsupported_source(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<54xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<54xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted +// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> () +gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<64x16xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x16xf32> into vector<64x16xf32> + gpu.yield %2 : vector<64x16xf32> + } + "some_use"(%r) : (vector<64x1xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed +// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x1xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x1xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 0], strides 
= [1, 1]} : vector<16x1xf32> into vector<64x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> () +gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) { + %0 = "some_def"() : () -> (vector<16x1xf32>) + %1 = "some_def"() : () -> (vector<64x1xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> } - "some_user_op"(%r) : (vector<1x1xf32>) -> () - gpu.return + : vector<16x1xf32> into vector<64x1xf32> + gpu.yield %2 : vector<64x1xf32> } + "some_use"(%r) : (vector<64x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x32xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x2xf32>) -> () +gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<64x32xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x16xf32> into vector<64x32xf32> + gpu.yield %2 : vector<64x32xf32> + } + "some_use"(%r) : (vector<64x2xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_outer_distributed +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48x32xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3x32xf32>) -> () +gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<48x32xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]> + } + : vector<16x16xf32> into 
vector<48x32xf32> + gpu.yield %2 : vector<48x32xf32> + } + "some_use"(%r) : (vector<3x32xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3xf32>) -> () +gpu.func @vector_insert_strided_slice_1d(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<16xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.insert_strided_slice +gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<8xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<8xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_offset +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.insert_strided_slice +gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<16xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [3], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane +// CHECK-SAME: (%[[ARG0:.*]]: index) { +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>) +// CHECK: %[[DEF:.*]] = "some_def"() +// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]] +// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]] +// CHECK: 
%[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16> +// CHECK: "some_use"(%[[BCAST]]) +gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) { + + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { + + %1 = "some_def"() : () -> vector<16xf16> + + %2 = vector.broadcast %1 { + layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } : vector<16xf16> to vector<16x16xf16> + + gpu.yield %2 : vector<16x16xf16> + } + "some_use"(%r) : (vector<16x1xf16>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case +// CHECK-SAME: (%[[ARG0:.*]]: index) +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>) +// CHECK: %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] +// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16> +// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16> +// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> () +gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) { + %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) { + %1 = "some_def"() : () -> vector<16x1xf16> + %2 = vector.broadcast %1 { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } : vector<16x1xf16> to vector<16x16xf16> + gpu.yield %2: vector<16x16xf16> + } + "some_use"(%0) : (vector<16x1xf16>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector +// CHECK-SAME: (%[[ARG0:.*]]: index) +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16) +// CHECK: %[[DEF:.*]] = "some_def"() +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> +// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16 +// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16> +// CHECK: "some_use"(%[[RESULT]]) +gpu.func +@vector_shape_cast_scalar_to_vector(%arg0: index) { + %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) { + %1 = "some_def"() : () -> f16 + %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + gpu.yield %2 : vector<16x16xf16> + } + "some_use"(%0) : (vector<16x1xf16>) -> () + gpu.return +} + } diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 27a3dc3..e5e3d2a 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -265,3 +265,129 @@ gpu.module @xevm_module{ gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) { +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C8]] +// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C8]] +// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C2]] +// CHECK: %[[REMU3:.*]] = arith.remui %[[REMU2]], %[[C2]] +// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C8]] +// CHECK: %[[MAT:.*]] = 
xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) { + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32> + xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) { +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C4]] +// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C4]] +// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C4]] +// CHECK: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C2]] +// CHECK: %[[REMU3:.*]] = arith.remui %[[MUL]], %[[C8]] +// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C4]] +// CHECK: %[[ADD:.*]] = arith.addi %[[REMU4]], %[[C1]] +// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32> + xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) { +// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: +// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32> +// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: +// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index +gpu.module @xevm_module{ + gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : + !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32> + xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : + vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) { 
+gpu.module @xevm_module{ + gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16> + %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> + %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> + // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16> + %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16> + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) { +gpu.module @xevm_module{ + gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1> + %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16> + + %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16> + %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16> + // CHECK-NOT: vector.broadcast + // CHECK-NOT: vector.shape_cast + + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}] + // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16> + + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) { +gpu.module @xevm_module{ + gpu.func @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %9 = gpu.block_id x + %10 = arith.index_cast %9 : index to i16 + %11 = arith.bitcast %10 : i16 to f16 + // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16> + %2 = vector.broadcast %11 {layout_result_0 = 
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + + diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir new file mode 100644 index 0000000..dce4a41 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir @@ -0,0 +1,157 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics + +func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index // expected-note {{target op}} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_bad_result_index +func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Index exceeds the number of op results}} + transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_bad_operand_index +func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Index exceeds the number of op operands}} + transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_multiple +func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + 
%3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Requires exactly one targetOp handle (got 2)}} + transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index // expected-note {{target op}} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Requires exactly one targetOp handle (got 2)}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected threads argument to consist of three values (got 2)}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_c +func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + // expected-note@below {{load op}} + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 
= xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value + // expected-error@below {{Load op is not contained in a scf.for loop.}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir new file mode 100644 index 0000000..561034f --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -0,0 +1,509 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @get_desc_op_a +func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // expected-remark @below {{found desc op}} + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op + transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @get_desc_op_c +func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + // expected-remark @below {{found desc op}} + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = 
xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value + %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op + transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_desc_layout +func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false> + // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_desc_layout_minimal +func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_desc_layout_param +func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64> + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: 
@set_desc_layout_slice +func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_result_default_index +func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas + // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_result_sg_param +func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = arith.extf %1 + // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64> + transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64> + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_result_sg_param2 +func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) { + %0 = 
xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = arith.extf %1 + // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64> + %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64> + transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64> + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_result0 +func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = arith.extf %1 + // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_result_slice +func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) { + // CHECK: = arith.extf + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>} + %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_operand_minimal +func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = arith.extf %1 + // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>} + %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op 
{transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_op_layout_attr_operand1 +func.func @set_op_layout_attr_operand1(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>) { + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = arith.addf %1, %3 + // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} + %6 = arith.addf %1, %3 : vector<256x32xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads +func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads_param +func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, 
%arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + %th1 = transform.param.constant 4 : i64 -> !transform.param<i64> + transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64> + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a +func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C32:.+]] = arith.constant 32 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]] + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]] + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2 +func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C64:.+]] = arith.constant 64 : index + // CHECK: %[[C32:.+]] = arith.constant 32 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd 
%0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]] + %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %nb = transform.param.constant 2 : i64 -> !transform.param<i64> + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @convert_layout_a +func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c0 = arith.constant 0 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16> + // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] + // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]> + // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V2]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = 
transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + // CHECK: transform.xegpu.convert_layout %{{.*}} + transform.xegpu.convert_layout %1 + input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] + target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] + : (!transform.any_value) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @convert_layout_a_sg_param +func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c0 = arith.constant 0 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16> + // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] + // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]> + // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V2]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64> + // CHECK: transform.xegpu.convert_layout %{{.*}} + transform.xegpu.convert_layout %1 + input_sg_layout = [%layout0, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] + target_sg_layout = [%layout0, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] + : (!transform.any_value, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir index b73bc69..8ce6d4d 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir @@ -1,33 +1,30 @@ // RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s -//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)> gpu.module @test { gpu.func @slice_attr() -> vector<128xindex> { - //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]] - //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]] - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]] - //CHECK: 
[[BASE:%.+]] = vector.step : vector<32xindex> - //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> - //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C8:.*]] + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU]], %[[C4:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]] + // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]] + // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex> + // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex> + // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex> %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex> gpu.return %step : vector<128xindex> } gpu.func @nested_slice_attr() -> vector<128xindex> { - //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]] - //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]] - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]] - //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex> - //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> - //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[DIVU2:.*]] = arith.divui %[[SGID]], %[[C8:.*]] + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU2]], %[[C4:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]] + // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]] + // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex> + // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex> + // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex> %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex> gpu.return %0 : vector<128xindex> } -}
\ No newline at end of file +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir index 09df1e4..9580769 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir @@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops { %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32> - // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-12: : vector<2x2xf32> + // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> // CHECK-NOT: arith.negf %negf = arith.negf %load_a {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32> - // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-12: : vector<2x2xf32> + // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> // CHECK-NOT: math.powf %powf = math.powf %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index d2d250c..4829af3 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -1,14 +1,10 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s -#map = affine_map<()[s0] -> (s0 floordiv 4)> -#map1 = affine_map<()[s0] -> (s0 mod 4)> - gpu.module @test_round_robin_assignment { // CHECK-LABEL: create_nd_tdesc // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -16,22 +12,23 @@ gpu.module @test_round_robin_assignment { } // CHECK-LABEL: create_nd_tdesc_with_shared_data - // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) { - //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]] - //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]] - //CHECK: [[C16:%.+]] = arith.constant 16 : index - //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]] - //CHECK: [[C64:%.+]] = arith.constant 64 : index - //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]] - //CHECK: [[C0:%.+]] = arith.constant 0 : index - //CHECK: [[C0_1:%.+]] = arith.constant 0 : index - //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: 
[[offY:%.+]] = index.remu [[LY]], [[C128]] - //CHECK: [[C64_2:%.+]] = arith.constant 64 : index - //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]] - //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32> + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[IDX:.*]] = arith.remui %[[SGID]], %[[C4]] + // CHECK: %[[IDY_DIV:.*]] = arith.divui %[[SGID]], %[[C4]] + // CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[IDY:.*]] = arith.remui %[[IDY_DIV]], %[[C8]] + // CHECK: %[[C16:.*]] = arith.constant 16 : index + // CHECK: %[[LY:.*]] = arith.muli %[[IDY]], %[[C16]] + // CHECK: %[[C64:.*]] = arith.constant 64 : index + // CHECK: %[[LX:.*]] = arith.muli %[[IDX]], %[[C64]] + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[OFFY:.*]] = arith.remui %[[LY]], %[[C128]] + // CHECK: %[[C64_1:.*]] = arith.constant 64 : index + // CHECK: %[[OFFX:.*]] = arith.remui %[[LX]], %[[C64_1]] + // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>> gpu.return @@ -42,9 +39,7 @@ gpu.module @test_round_robin_assignment { gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-COUNT-4: xegpu.load_nd %{{.*}} - // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-4: -> vector<16x16xf32> + // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -57,9 +52,8 @@ gpu.module @test_round_robin_assignment { gpu.func @store_nd(%src: memref<256x128xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-NOT : xegpu.store_nd + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-NOT: xegpu.store_nd %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<256x128xf32> @@ -73,8 +67,7 @@ gpu.module @test_round_robin_assignment { gpu.func @update_nd(%src: memref<256x128xf32>){ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] - // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 
1]>>> + // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.update_nd_offset %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -84,15 +77,9 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: dpas // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> + // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16> -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -113,8 +100,7 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: prefetch_nd_tdesc // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { - // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} - // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -131,9 +117,7 @@ gpu.module @test_round_robin_assignment { %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> -> vector<128x1xf32> - // CHECK-COUNT-2: vector.broadcast {{.*}} - // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32> + // CHECK-COUNT-2: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load {layout_result_0 = 
#xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>} @@ -171,10 +155,10 @@ gpu.module @test_round_robin_assignment { %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32> %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> - //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32) + // CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32) %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) { %4 = arith.cmpi slt, %arg3, %c10_i32 : i32 - //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32 + // CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32 scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32 } do { // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32) @@ -195,16 +179,16 @@ gpu.module @test_round_robin_assignment { %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> %3 = arith.cmpi eq, %0, %c10 : index // CHECK-LABEL: scf.if - // CHECK-SAME: (vector<16xf32>, vector<16xf32>) + // CHECK-SAME: (vector<16xf32>, vector<16xf32>) %4 = scf.if %3 -> (vector<256xf32>) { %5 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32> // CHECK-LABEL: scf.yield - // CHECK-SAME: vector<16xf32>, vector<16xf32> + // CHECK-SAME: vector<16xf32>, vector<16xf32> scf.yield %5 : vector<256xf32> } else { %5 = xegpu.load_nd %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32> // CHECK-LABEL: scf.yield - // CHECK-SAME: vector<16xf32>, vector<16xf32> + // CHECK-SAME: vector<16xf32>, vector<16xf32> scf.yield %5 : vector<256xf32> } {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [16]>} xegpu.store_nd %4, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> @@ -220,16 +204,16 @@ gpu.module @test_round_robin_assignment { %0 = arith.cmpi eq, %id, %c10 : index // CHECK-LABEL: scf.if - // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>) + // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>) %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) { %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> // CHECK-LABEL: scf.yield - // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> + // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> } else { %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> // CHECK-LABEL: scf.yield - // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> + // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> } xegpu.store_nd %d, %1 : vector<256xf32>, 
!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> @@ -238,8 +222,8 @@ gpu.module @test_round_robin_assignment { gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> - //CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32> - //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32> + // CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32> + // CHECK-COUNT-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32> %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>, target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32> diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 86a021b..c95c640 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -14,13 +14,11 @@ gpu.module @test_distribution { // CHECK-LABEL: load_nd_tdesc_with_offset gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) { - // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] - // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-4: -> vector<16x16xf32> + // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> - %load = xegpu.load_nd %tdesc[0, 0] + %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<256x128xf32> gpu.return @@ -28,8 +26,7 @@ gpu.module @test_distribution { // CHECK-LABEL: store_nd_with_offset gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) { - // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] - // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.store_nd %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -42,10 +39,8 @@ gpu.module @test_distribution { } // CHECK-LABEL: prefetch_nd_tdesc_with_offset - // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func 
@prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) { - // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] - // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -57,15 +52,11 @@ gpu.module @test_distribution { // CHECK-LABEL: dpas // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> + // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> + // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16> -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -99,30 +90,57 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: non_splat_constant gpu.func @non_splat_constant() { - // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex> + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{.*}}0{{.*}}, {{.*}}16{{.*}}> : vector<2x1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]] - // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]] - // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]] - // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]] - // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]] - // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index - // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]] - // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]] - // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli 
%[[REMU1]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index - // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index - // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex> - // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex> - // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index - // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index - // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex> - // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex> + // CHECK-DAG: %[[T1:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index + // CHECK-DAG: %[[T2:.*]] = arith.muli %[[T1]], %[[C2:.*]] : index + // CHECK-DAG: %[[T3:.*]] = arith.remui %[[T2]], %[[C32:.*]] : index + // CHECK-DAG: %[[T4:.*]] = arith.addi %[[T2]], %[[C16:.*]] : index + // CHECK-DAG: %[[T5:.*]] = arith.remui %[[T4]], %[[C32_6:.*]] : index + // CHECK-DAG: %[[T6:.*]] = arith.muli %[[T3]], %[[C16_10:.*]] : index + // CHECK-DAG: %[[T7:.*]] = arith.addi %[[C0_11:.*]], %[[T6]] : index + // CHECK-DAG: %[[T8:.*]] = arith.muli %[[C0_4:.*]], %[[C0_9:.*]] : index + // CHECK-DAG: %[[T9:.*]] = arith.addi %[[T7]], %[[T8]] : index + // CHECK-DAG: %[[T10:.*]] = vector.broadcast %[[T9]] : index to vector<2x1xindex> + // CHECK-DAG: %[[T11:.*]] = arith.addi %[[CST]], %[[T10]] : vector<2x1xindex> + // CHECK-DAG: %[[T12:.*]] = arith.muli %[[T5]], %[[C16_10:.*]] : index + // CHECK-DAG: %[[T13:.*]] = arith.addi %[[C0_12:.*]], %[[T12]] : index + // CHECK-DAG: %[[T14:.*]] = arith.muli %[[C0_8:.*]], %[[C0_9:.*]] : index + // CHECK-DAG: %[[T15:.*]] = arith.addi %[[T13]], %[[T14]] : index + // CHECK-DAG: %[[T16:.*]] = vector.broadcast %[[T15]] : index to vector<2x1xindex> + // CHECK-DAG: %[[T17:.*]] = arith.addi %[[CST]], %[[T16]] : vector<2x1xindex> %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } + + // CHECK-LABEL: vector_transpose + gpu.func @vector_transpose(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>> + -> vector<256x128xf32> + // CHECK-COUNT-2: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<32x16xf32> to vector<16x32xf32> + // CHECK-NOT: vector.transpose + %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32> + gpu.return + } + + 
// CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1> + gpu.return + } + + gpu.func @vector_create_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 52acde4..69eb8ce 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -1,8 +1,5 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s -//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)> -//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> -//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)> gpu.module @test_distribution { // CHECK-LABEL: create_nd_tdesc_no_offset // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> @@ -26,13 +23,23 @@ gpu.module @test_distribution { } // CHECK-LABEL: load_nd_tdesc_with_offset - // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) { - //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index - //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]] - //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]] - //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32> - %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32> + //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + //CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + //CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + //CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4]] + //CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4]] + //CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + //CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8]] + //CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + //CHECK-DAG: %[[L_OFF_Y:.*]] = arith.muli %[[SGIDY]], %[[C32]] : index + //CHECK-DAG: %[[L_OFF_X:.*]] = arith.muli %[[SGIDX]], %[[C32_1:.*]] : index + //CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index + //CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[L_OFF_Y]], %[[C256]] : index + //CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index + //CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[L_OFF_X]], %[[C128]] : index + //CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -43,9 +50,6 @@ gpu.module @test_distribution { // 
CHECK-LABEL: store_nd_with_offsets // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) { - //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index - //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]] - //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]] //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -60,9 +64,6 @@ gpu.module @test_distribution { // CHECK-LABEL: prefetch_nd_tdesc_with_offset // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) { - //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index - //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]] - //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]] //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %cst0 = arith.constant 0 : index %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> @@ -285,14 +286,14 @@ gpu.module @test_distribution { // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16> // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex> // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1> - // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> + // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>, layout = #xegpu.layout<inst_data = [8]>}> // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>, // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>} // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16> %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1> - xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, l1_hint = #xegpu.cache_hint<cached>} @@ -319,21 +320,19 @@ gpu.module @test_distribution { gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) { //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index - //CHECK: [[c2:%.+]] = arith.constant 2 : index //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[c4_0:%.+]] = 
arith.constant 4 : index - //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] - //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]] - //CHECK: [[c32_1:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[c0_1:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index + //CHECK: [[c32_0:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> @@ -346,21 +345,19 @@ gpu.module @test_distribution { //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32> //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index - //CHECK: [[c2:%.+]] = arith.constant 2 : index //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[c4_0:%.+]] = arith.constant 4 : index - //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] - //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]] - //CHECK: [[c32_1:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[c0_2:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index + //CHECK: [[c32_0:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index %cst = arith.constant {layout_result_0 
= #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32> %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> @@ -411,14 +408,17 @@ gpu.module @test_distribution { // CHECK-LABEL: vector_step_op gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]] - //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index - //CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]] - //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index - //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index - //CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]] - //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex> - //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> + //CHECK: [[c8:%.+]] = arith.constant 8 : index + //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c8]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgId]], [[c8]] : index + //CHECK: [[c4:%.+]] = arith.constant 4 : index + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c4]] : index + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[LY:%.+]] = arith.muli [[sgidy]], [[c32]] : index + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[MODY:%.+]] = arith.remui [[LY]], [[c128]] : index + //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex> + //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex> gpu.return @@ -426,14 +426,14 @@ gpu.module @test_distribution { gpu.func @vector_step_op_layout_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index - //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index - //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]] - //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index - //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index - //CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]] - //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex> - //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> + //CHECK: [[c16:%.+]] = arith.constant 16 : index + //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c16]] : index + //CHECK: [[c8:%.+]] = arith.constant 8 : index + //CHECK: [[LOCALY:%.+]] = arith.muli [[sgidx]], [[c8]] : index + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[MODY:%.+]] = arith.remui [[LOCALY]], [[c128]] : index + //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex> + //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex> %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex> gpu.return @@ -464,40 +464,49 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: vector_transpose + gpu.func @vector_transpose(%src: memref<256x32xf32>) { + %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32> + -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 
1], order =[0, 1]>> + -> vector<256x32xf32> + //CHECK: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<64x32xf32> to vector<32x64xf32> + %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x32xf32> to vector<32x256xf32> + gpu.return + } + // CHECK-LABEL: non_splat_constant_2D gpu.func @non_splat_constant_2D() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex> - // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: affine.apply #map4()[%[[SGID]]] - // CHECK-DAG: affine.apply #map5()[%[[SGID]]] - // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]] - // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]] - // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index - // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index - // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex> - // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex> + // CHECK-DAG: %[[T0:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[T1:.*]] = arith.remui %[[T0]], %[[C32:.*]] : index + // CHECK-DAG: %[[T2:.*]] = arith.remui %[[T1]], %[[C32_4:.*]] : index + // CHECK-DAG: %[[T3:.*]] = arith.muli %[[T2]], %[[C16:.*]] : index + // CHECK-DAG: %[[T4:.*]] = arith.addi %[[C0_8:.*]], %[[T3]] : index + // CHECK-DAG: %[[T5:.*]] = arith.muli %[[C0_6:.*]], %[[C0_7:.*]] : index + // CHECK-DAG: %[[T6:.*]] = arith.addi %[[T4]], %[[T5]] : index + // CHECK-DAG: %[[T7:.*]] = vector.broadcast %[[T6]] : index to vector<1x1xindex> + // CHECK-DAG: %[[T8:.*]] = arith.addi %[[CST]], %[[T7]] : vector<1x1xindex> %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } // CHECK-LABEL: non_splat_constant_2D_non_unit_dim gpu.func @non_splat_constant_2D_non_unit_dim() { - // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex> + // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{\[}}{{\[}}0, 16{{\]}}, {{\[}}8, 24{{\]}}{{\]}}> : vector<2x2xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]] - // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]] - // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]] - // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]] - // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index - // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]] - // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %{{.*}} + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[SGIDY]], %[[C2:.*]] : index + // CHECK-DAG: 
%[[MULX:.*]] = arith.muli %[[SGIDX]], %{{.*}} : index + // CHECK-DAG: %[[REMU_Y:.*]] = arith.remui %[[MULY]], %[[C8:.*]] : index + // CHECK-DAG: %[[REMU_X:.*]] = arith.remui %[[MULX]], %{{.*}} : index + // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %{{.*}} : index // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index + // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex> // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex> %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[ @@ -517,13 +526,14 @@ gpu.module @test_distribution { gpu.func @non_splat_constant() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]] - // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[REMU]], %{{.*}} + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C16:.*]] : index // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex> // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex> %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex> - // CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex> + // CHECK: arith.constant dense<{{\[}}{{\[}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{\]}}{{\]}}> : vector<1x16xindex> %cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex> gpu.return } @@ -534,4 +544,106 @@ gpu.module @test_distribution { %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex> gpu.return } + + // CHECK-LABEL: vector_mask_1D + gpu.func @vector_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: 
%[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index + // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]] : index + // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]] : index + // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]] : index + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1> + gpu.return + } + + // CHECK-LABEL: vector_create_mask_1D + gpu.func @vector_create_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %cst8 = arith.constant 8 : index + %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_create_mask_2D + gpu.func @vector_create_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] + // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]] + // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]] + // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]] + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1> + gpu.return + } + + // CHECK-LABEL: distribute_load_slice_attr + gpu.func @distribute_load_slice_attr() { + %2 = memref.alloca() {alignment = 1024} : memref<4096xf32> + %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], 
inst_data = [16]> } dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1> + + // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}> + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} : + // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32> + %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32> + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32> + %4 = vector.broadcast %3 {layout_result_0 = + #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32> + gpu.return + } + + // CHECK-LABEL: load_nd_tdesc_with_anchor_layout + gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) { + //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32> + %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16],lane_layout = [1, 16], lane_data = [1, 1]>}> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index e83229e..a8015cc 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -1,47 +1,35 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s -//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)> -//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> gpu.module @test_1_1_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { - //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index - //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]] - //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]] - //CHECK: [[C32:%.+]] = arith.constant 32 : index - //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]] - //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]] - //CHECK: [[C0:%.+]] = arith.constant 0 : index - //CHECK: [[C0_1:%.+]] = arith.constant 0 : index - //CHECK: [[C256:%.+]] = arith.constant 256 : index - //CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]] - //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]] - //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc 
[[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]] + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]] + // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]] + // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]] + // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]] + // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[MODY]], %[[MODX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref - // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<3x256x128xf32> gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) { - //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index - //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]] - //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]] - //CHECK: [[C32:%.+]] = arith.constant 32 : index - //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]] - //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]] - //CHECK: [[C0:%.+]] = arith.constant 0 : index - //CHECK: [[C0_2:%.+]] = arith.constant 0 : index - //CHECK: [[C256:%.+]] = arith.constant 256 : index - //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]] - //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]] - //CHECK: [[C0_3:%.+]] = arith.constant 0 : index - //CHECK: [[C0_4:%.+]] = arith.constant 0 : index - //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]] + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]] + // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]] + // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]] + // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]] + // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][1, %[[MODY]], %[[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return @@ -81,25 +69,24 @@ gpu.module @test_1_1_assignment { xegpu.store_nd %load, %tdesc : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return -} + } -// CHECK-LABEL: update_nd -// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> -gpu.func @update_nd(%src: 
memref<256x128xf32>){ - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] - // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> - -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> - %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> - gpu.return -} + // CHECK-LABEL: update_nd + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @update_nd(%src: memref<256x128xf32>){ + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + %update = xegpu.update_nd_offset %tdesc, [0, 16] + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } -// CHECK-LABEL: dpas -gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { - // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> + // CHECK-LABEL: dpas + gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a @@ -110,16 +97,15 @@ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<128x128xf16> + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } - -// CHECK-LABEL: dpas_no_sg_data -gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { - // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + // CHECK-LABEL: dpas_no_sg_data + gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> -> 
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>> @@ -134,6 +120,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>> -> vector<128x128xf16> + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> @@ -196,9 +183,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { } gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) { - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[c1024:%.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index + // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c1024 = arith.constant 1024 : index @@ -211,15 +198,15 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> - // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] - // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> + // CHECK: %[[SCF:.*]]:3 = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C1024]] step %[[C128]] + // CHECK-SAME: iter_args(%[[ARG4:.*]] = {{.*}}, %[[ARG5:.*]] = {{.*}}, %[[ARG6:.*]] = {{.*}}) -> // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>) - // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16> - // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16> - // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32> - // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16> - // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16> - // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32> + // CHECK: %[[A:.*]] = xegpu.load_nd %[[ARG4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16> + // CHECK: %[[B:.*]] = xegpu.load_nd %[[ARG5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16> + // CHECK: %[[C:.*]] = xegpu.dpas %[[A]], %[[B]], %[[ARG6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32> + // CHECK: %[[AT:.*]] = xegpu.update_nd_offset %[[ARG4]], [%[[C0]], %[[C128]]] : !xegpu.tensor_desc<16x128xf16> + // CHECK: %[[BT:.*]] = xegpu.update_nd_offset %[[ARG5]], [%[[C128]], 
%[[C0]]] : !xegpu.tensor_desc<128x16xf16> + // CHECK: scf.yield %[[AT]], %[[BT]], %[[C]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32> %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) { @@ -252,7 +239,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { // CHECK: scf.condition{{.*}} : vector<16xf32>, i32 scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32 } do { - // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32) + // CHECK: (%[[ARG2:.*]]: vector<16xf32>, %[[ARG3:.*]]: i32) ^bb0(%arg2: vector<256xf32>, %arg3: i32): xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> %4 = arith.addi %arg3, %c1_i32 : i32 @@ -344,9 +331,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %cond4 = arith.cmpi slt, %sg_id, %c31 : index %cond5 = arith.andi %cond3, %cond4 : i1 scf.if %cond5 { - // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]] + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]] %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32> -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py index 8f60088..e09720a 100644 --- a/mlir/test/Examples/NVGPU/Ch0.py +++ b/mlir/test/Examples/NVGPU/Ch0.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 0 : Hello World @@ -33,7 +37,7 @@ def main(alpha): # + operator generates arith.addi myValue = alpha + tidx # Print from a GPU thread - gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue]) + gpu.printf("GPU thread %llu has %llu\n", tidx, myValue) # 3. Call the GPU kernel kernel() @@ -43,8 +47,24 @@ alpha = 100 # 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function. 
main(alpha) - # CHECK: GPU thread 0 has 100 # CHECK: GPU thread 1 has 101 # CHECK: GPU thread 2 has 102 # CHECK: GPU thread 3 has 103 + +# DUMPIR: func.func @main(%arg0: index) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[C0_I32:.*]] = arith.constant 0 : i32 +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_0:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C4:.*]] = arith.constant 4 : index +# DUMPIR: %[[C1_2:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index +# DUMPIR: gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] { +# DUMPIR: %[[TIDX:.*]] = gpu.thread_id x +# DUMPIR: %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index +# DUMPIR: gpu.printf "GPU thread %llu has %llu\0A", %[[TIDX]], %[[MYVAL]] : index, index +# DUMPIR: gpu.terminator +# DUMPIR: } +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py index cfb48d5..6e44e4d 100644 --- a/mlir/test/Examples/NVGPU/Ch1.py +++ b/mlir/test/Examples/NVGPU/Ch1.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 1 : 2D Saxpy @@ -24,12 +28,12 @@ import numpy as np def saxpy(x, y, alpha): # 1. Use MLIR GPU dialect to allocate and copy memory token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], []) y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], []) t4 = gpu.memcpy(token_ty, [t3], x_dev, x) t5 = gpu.memcpy(token_ty, [t4], y_dev, y) - t6 = gpu.wait(token_ty, [t5]) + t6 = gpu.wait([t5]) # 2. Compute 2D SAXPY kernel @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1)) @@ -47,7 +51,7 @@ def saxpy(x, y, alpha): saxpy_kernel() t7 = gpu.memcpy(token_ty, [t6], y, y_dev) - gpu.wait(token_ty, [t7]) + gpu.wait([t7]) # 3. Pass numpy arrays to MLIR @@ -56,11 +60,32 @@ N = 32 alpha = 2.0 x = np.random.randn(M, N).astype(np.float32) y = np.ones((M, N), np.float32) + saxpy(x, y, alpha) -# 4. Verify MLIR with reference computation -ref = np.ones((M, N), np.float32) -ref += x * alpha -np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # 4. 
Verify MLIR with reference computation + ref = np.ones((M, N), np.float32) + ref += x * alpha + np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @saxpy(%[[ARG0:.*]]: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[WAIT0:.*]] = gpu.wait async +# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32> +# DUMPIR: %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32> +# DUMPIR: %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %[[ARG0]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %[[ARG1]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]] +# DUMPIR: %[[LD0:.*]] = memref.load %[[MEMREF]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[LD1:.*]] = memref.load %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[MUL:.*]] = arith.mulf %[[LD0]], %[[ARG2]] : f32 +# DUMPIR: %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32 +# DUMPIR: memref.store %[[ADD]], %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: gpu.terminator +# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %[[ARG1]], %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py index 729913c..aba610c 100644 --- a/mlir/test/Examples/NVGPU/Ch2.py +++ b/mlir/test/Examples/NVGPU/Ch2.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 2 : 2D Saxpy with TMA @@ -28,12 +32,12 @@ import numpy as np @NVDSL.mlir_func def saxpy(x, y, alpha): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], []) y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], []) t4 = gpu.memcpy(token_ty, [t3], x_dev, x) t5 = gpu.memcpy(token_ty, [t4], y_dev, y) - t6 = gpu.wait(token_ty, [t5]) + t6 = gpu.wait([t5]) x_tma = TMA([1, N], x.type) y_tma = TMA([1, N], y.type) @@ -74,7 +78,7 @@ def saxpy(x, y, alpha): saxpy_tma_kernel() t7 = gpu.memcpy(token_ty, [t6], y, y_dev) - gpu.wait(token_ty, [t7]) + gpu.wait([t7]) # 3. Pass numpy arrays to MLIR @@ -85,9 +89,46 @@ x = np.random.randn(M, N).astype(np.float32) y = np.ones((M, N), np.float32) saxpy(x, y, alpha) -# 4. Verify MLIR with reference computation -ref = np.ones((M, N), np.float32) -ref += x * alpha -np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # 4. 
Verify MLIR with reference computation + ref = np.ones((M, N), np.float32) + ref += x * alpha + np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @saxpy(%{{.*}}: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[WAIT0:.*]] = gpu.wait async +# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32> +# DUMPIR: %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32> +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C32:.*]] = arith.constant 32 : index +# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index +# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_10:.*]] = arith.constant 0 : index +# DUMPIR: %[[C1_11:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_10]]], %[[C1_11]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_12]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%{{.*}}, %{{.*}}], %[[MB]][%{{.*}}] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%{{.*}}], %{{.*}}, predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_20:.*]] = arith.constant 0 : index +# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index +# DUMPIR: %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index +# DUMPIR: %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG1]], %{{.*}} : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %{{.*}} = gpu.wait async [%[[MEMCPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py index eb96b11..fe11575 100644 --- a/mlir/test/Examples/NVGPU/Ch3.py +++ b/mlir/test/Examples/NVGPU/Ch3.py @@ -1,5 +1,9 @@ # RUN: env 
SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 3 : GEMM 128x128x64 with Tensor Core @@ -60,13 +64,13 @@ def tma_load( @NVDSL.mlir_func def gemm_128_128_64(a, b, d): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -111,7 +115,7 @@ def gemm_128_128_64(a, b, d): gemm_tma_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -123,7 +127,73 @@ b = np.random.randn(K, N).astype(np.float16) d = np.zeros((M, N), np.float32) gemm_128_128_64(a, b, d) -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR program with reference computation in python + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @gemm_128_128_64(%{{.*}}: memref<128x64xf16>, %{{.*}}: memref<64x128xf16>, %[[ARG2:.*]]: memref<128x128xf32>) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[C64:.*]] = arith.constant 64 : index +# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %{{.*}} box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[CAST1:.*]] = memref.cast %{{.*}} : memref<64x128xf16> to memref<*xf16> +# DUMPIR: %[[C64_5:.*]] = arith.constant 64 : index +# DUMPIR: %[[C64_6:.*]] = arith.constant 64 : index +# DUMPIR: %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x +# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index +# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index +# DUMPIR: %[[C1_13:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_12]]], %[[C1_13]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA0]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA1]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_14:.*]] = 
arith.constant 0 : index +# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_14]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index +# DUMPIR: %[[VIEW_15:.*]] = memref.view %[[DSM1]][%[[C16384]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_16:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_17:.*]] = memref.view %[[DSM2]][%[[C0_16]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C16384_18:.*]] = arith.constant 16384 : index +# DUMPIR: %[[VIEW_19:.*]] = memref.view %[[DSM3]][%[[C16384_18]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM4:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C24576:.*]] = arith.constant 24576 : index +# DUMPIR: %[[VIEW_20:.*]] = memref.view %[[DSM4]][%[[C24576]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index +# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_21]]], %[[C32768]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_23:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_24:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%[[C0_23]], %[[C0_24]]], %[[MB]][%[[C0_22]]] to %[[VIEW_17]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_25:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_26:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_27:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C0_26]], %[[C0_27]]], %[[MB]][%[[C0_25]]] to %[[VIEW_19]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_28:.*]] = arith.constant 0 : index +# DUMPIR: %[[C64_29:.*]] = arith.constant 64 : index +# DUMPIR: %[[C0_30:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C64_29]], %[[C0_30]]], %[[MB]][%[[C0_28]]] to %[[VIEW_20]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_31:.*]] = arith.constant 0 : index +# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_31]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[WG_ACC:.*]] = nvgpu.warpgroup.mma.init.accumulator -> 
<fragmented = vector<128x128xf32>> +# DUMPIR: %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[MMA]], %{{.*}} : <fragmented = vector<128x128xf32>> to memref<128x128xf32> +# DUMPIR: gpu.terminator +# DUMPIR: } +# DUMPIR: %[[CPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG2]], %{{.*}} : memref<128x128xf32>, memref<128x128xf32> +# DUMPIR: gpu.wait async [%[[CPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py index 0e3460f..dffafda 100644 --- a/mlir/test/Examples/NVGPU/Ch4.py +++ b/mlir/test/Examples/NVGPU/Ch4.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 4 : Multistage GEMM with Tensor Core @@ -259,13 +263,13 @@ def epilogue(D: WGMMAMatrix, d_dev): @NVDSL.mlir_func def gemm_multistage(a, b, d, num_stages): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -297,7 +301,7 @@ def gemm_multistage(a, b, d, num_stages): gemm_multistage_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -313,11 +317,153 @@ d = np.zeros((M, N), np.float32) gemm_multistage(a, b, d, num_stages=7) +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR with reference computation + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -# Verify MLIR with reference computation -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) - - -print("PASS") + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @gemm_multistage(%{{.*}}: memref<512x1024xf16>, %{{.*}}: memref<1024x256xf16>, %{{.*}}: memref<512x256xf32>) attributes {llvm.emit_c_interface} { +# DUMPIR: scf.if %{{.*}} { +# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index +# DUMPIR: %[[C7:.*]] = arith.constant 7 : index +# DUMPIR: %[[C1_INIT:.*]] = 
arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] { +# DUMPIR: %[[C1_MBAR:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %{{.*}}[%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: } +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: } +# DUMPIR: %[[C0_PROLOGUE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C6:.*]] = arith.constant 6 : index +# DUMPIR: %[[C1_PROLOGUE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_PROLOGUE]] to %[[C6]] step %[[C1_PROLOGUE]] { +# DUMPIR: %[[BID_X_P:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_P:.*]] = gpu.block_id y +# DUMPIR: %[[C128_P1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_P:.*]] = arith.muli %[[BID_X_P]], %[[C128_P1]] : index +# DUMPIR: %[[C128_P2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_P:.*]] = arith.muli %[[BID_Y_P]], %[[C128_P2]] : index +# DUMPIR: %{{.*}} = gpu.thread_id x +# DUMPIR: %[[TID_X_P:.*]] = gpu.thread_id x +# DUMPIR: %[[C0_P:.*]] = arith.constant 0 : index +# DUMPIR: %[[PRED_P:.*]] = arith.cmpi eq, %[[TID_X_P]], %[[C0_P]] : index +# DUMPIR: %[[C16384_P1:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_P:.*]] = arith.muli %arg15, %[[C16384_P1]] : index +# DUMPIR: %[[C16384_P2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE_P:.*]] = arith.muli %arg15, %[[C16384_P2]] : index +# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1_P:.*]] = arith.addi %[[OFF_B_BASE_P]], %[[C114688]] : index +# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2_P:.*]] = arith.addi %[[OFF_B1_P]], %[[C8192]] : index +# DUMPIR: %[[SMEM_A_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_P:.*]] = memref.view %[[SMEM_A_P]][%[[OFF_A_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B1_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1_P:.*]] = memref.view %[[SMEM_B1_P]][%[[OFF_B1_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B2_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2_P:.*]] = memref.view %[[SMEM_B2_P]][%[[OFF_B2_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C64_K_P:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD_P:.*]] = arith.muli %arg15, %[[C64_K_P]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_P]], %[[DIMX_P]]], %{{.*}}[%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P]], %[[K_COORD_P]]], 
%{{.*}}[%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIMY_P_OFF:.*]] = arith.addi %[[DIMY_P]], %[[C64_OFF]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P_OFF]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: } +# DUMPIR: %[[TID_X_LOOP:.*]] = gpu.thread_id x +# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[FALSE_LOOP:.*]] = arith.constant false +# DUMPIR: %[[C0_LOOP:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16_LOOP:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_LOOP:.*]] = arith.constant 1 : index +# DUMPIR: %[[LOOP_RES:.*]]:2 = scf.for %arg15 = %[[C0_LOOP]] to %[[C16_LOOP]] step %[[C1_LOOP]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE_LOOP]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) { +# DUMPIR: %[[C7_L:.*]] = arith.constant 7 : index +# DUMPIR: %[[STAGE_L:.*]] = arith.remui %arg15, %[[C7_L]] : index +# DUMPIR: %[[C10M:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %{{.*}}[%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C16384_L:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_L:.*]] = arith.muli %[[STAGE_L]], %[[C16384_L]] : index +# DUMPIR: %[[C114688_L:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B_L:.*]] = arith.addi %[[OFF_A_L]], %[[C114688_L]] : index +# DUMPIR: %[[SMEM_A_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_L:.*]] = memref.view %[[SMEM_A_L]][%[[OFF_A_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B_L:.*]] = memref.view %[[SMEM_B_L]][%[[OFF_B_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[ACC_L:.*]] = nvgpu.warpgroup.mma %[[DESC_A_L]], %[[DESC_B_L]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C6_NEXT:.*]] = arith.constant 6 : index +# DUMPIR: %[[ITER_NEXT:.*]] = arith.addi %arg15, 
%[[C6_NEXT]] : index +# DUMPIR: %[[C16_CMP:.*]] = arith.constant 16 : index +# DUMPIR: %[[IN_RANGE:.*]] = arith.cmpi ult, %[[ITER_NEXT]], %[[C16_CMP]] : index +# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_THREAD0_L:.*]] = arith.cmpi eq, %[[TID_X_LOOP]], %[[C0_CMP]] : index +# DUMPIR: %[[DO_LOAD:.*]] = arith.andi %[[IN_RANGE]], %[[IS_THREAD0_L]] : i1 +# DUMPIR: %[[C6_STAGE:.*]] = arith.constant 6 : index +# DUMPIR: %[[STAGE_NEXT_L:.*]] = arith.addi %arg15, %[[C6_STAGE]] : index +# DUMPIR: %[[C7_MOD:.*]] = arith.constant 7 : index +# DUMPIR: %[[STAGE_LOAD:.*]] = arith.remui %[[STAGE_NEXT_L]], %[[C7_MOD]] : index +# DUMPIR: %[[BID_X_L:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_L:.*]] = gpu.block_id y +# DUMPIR: %[[C128_L1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_L:.*]] = arith.muli %[[BID_X_L]], %[[C128_L1]] : index +# DUMPIR: %[[C128_L2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_L:.*]] = arith.muli %[[BID_Y_L]], %[[C128_L2]] : index +# DUMPIR: %[[TID_X_L1:.*]] = gpu.thread_id x +# DUMPIR: %[[TID_X_L2:.*]] = gpu.thread_id x +# DUMPIR: %[[C16384_LA1:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA1]] : index +# DUMPIR: %[[C16384_LA2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA2]] : index +# DUMPIR: %[[C114688_LOAD:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1_LOAD:.*]] = arith.addi %[[OFF_B_BASE_LOAD]], %[[C114688_LOAD]] : index +# DUMPIR: %[[C8192_LOAD:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2_LOAD:.*]] = arith.addi %[[OFF_B1_LOAD]], %[[C8192_LOAD]] : index +# DUMPIR: %[[SMEM_A_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_LOAD:.*]] = memref.view %[[SMEM_A_LOAD]][%[[OFF_A_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B1_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1_LOAD:.*]] = memref.view %[[SMEM_B1_LOAD]][%[[OFF_B1_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B2_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2_LOAD:.*]] = memref.view %[[SMEM_B2_LOAD]][%[[OFF_B2_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C32768_LOAD:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C64_K_LOAD:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD_LOAD:.*]] = arith.muli %[[STAGE_NEXT_L]], %[[C64_K_LOAD]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_LOAD]], %[[DIMX_L]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, 
<memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF_LOAD:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIMY_L_OFF:.*]] = arith.addi %[[DIMY_L]], %[[C64_OFF_LOAD]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C6_FLIP:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_STAGE6:.*]] = arith.cmpi eq, %[[STAGE_L]], %[[C6_FLIP]] : index +# DUMPIR: %[[TRUE:.*]] = arith.constant true +# DUMPIR: %[[PARITY_FLIP:.*]] = arith.xori %arg17, %[[TRUE]] : i1 +# DUMPIR: %[[NEW_PARITY:.*]] = arith.select %[[IS_STAGE6]], %[[PARITY_FLIP]], %arg17 : i1 +# DUMPIR: scf.yield %[[ACC_L]], %[[NEW_PARITY]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1 +# DUMPIR: } +# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0 +# DUMPIR: %[[TID_X_EPI:.*]] = gpu.thread_id x +# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y +# DUMPIR: %[[C128_EPI1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI1]] : index +# DUMPIR: %[[C128_EPI2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI2]] : index +# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_VIEW:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_VIEW]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[SUBVIEW_EPI:.*]] = memref.subview %{{.*}}[%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[LOOP_RES]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: gpu.barrier +# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index +# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] { +# DUMPIR: %[[VAL_LOAD:.*]] = memref.load %[[VIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %[[VAL_LOAD]], %[[SUBVIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>> diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py index f98cfd7..b725e50 100644 --- a/mlir/test/Examples/NVGPU/Ch5.py +++ b/mlir/test/Examples/NVGPU/Ch5.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 5 : Warp Specialized GEMM with Tensor Core @@ -156,7 +160,7 @@ def producer_loop( ): phase = const(True, ty=T.bool()) - for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]): + for iv, phase, _ in 
scf.for_(0, (K // TILE_K), 1, [phase]): stage = iv % num_stages # Wait MMA to be done mbar_mma[stage].try_wait(phase) @@ -253,13 +257,13 @@ def epilogue(D: WGMMAMatrix, d_dev): @NVDSL.mlir_func def gemm_warp_specialized(a, b, d, num_stages): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -295,7 +299,7 @@ def gemm_warp_specialized(a, b, d, num_stages): gemm_warp_specialized_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -311,11 +315,166 @@ d = np.zeros((M, N), np.float32) gemm_warp_specialized(a, b, d, num_stages=7) +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR with reference computation + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -# Verify MLIR with reference computation -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) - - -print("PASS") + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: %[[TID_X:.*]] = gpu.thread_id x +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[REM1:.*]] = arith.remui %[[TID_X]], %[[C128]] : index +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_PRIMARY:.*]] = arith.cmpi eq, %[[REM1]], %[[C0]] : index +# DUMPIR: %[[C128_1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIV1:.*]] = arith.divui %[[TID_X]], %[[C128_1]] : index +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[IS_PRODUCER:.*]] = arith.cmpi eq, %[[DIV1]], %[[C1]] : index +# DUMPIR: %[[TID_X_2:.*]] = gpu.thread_id x +# DUMPIR: %[[C128_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[REM2:.*]] = arith.remui %[[TID_X_2]], %[[C128_2]] : index +# DUMPIR: %[[C0_2:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_PRIMARY_2:.*]] = arith.cmpi eq, %[[REM2]], %[[C0_2]] : index +# DUMPIR: %[[C128_3:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIV2:.*]] = arith.divui %[[TID_X_2]], %[[C128_3]] : index +# DUMPIR: %[[C0_3:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_CONSUMER:.*]] = arith.cmpi eq, %[[DIV2]], %[[C0_3]] : index +# DUMPIR: %[[TID_X_3:.*]] = gpu.thread_id x +# DUMPIR: %[[MBAR_MMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[MBAR_TMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C0_4:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X_3]], %[[C0_4]] : index +# DUMPIR: scf.if %[[IS_THREAD0]] { +# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index +# DUMPIR: %[[C7:.*]] = arith.constant 7 : index +# DUMPIR: %[[C1_INIT:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] { +# DUMPIR: %[[C1_INIT_VAL:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MBAR_MMA]][%arg15], %[[C1_INIT_VAL]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C1_INIT_VAL_2:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init 
%[[MBAR_TMA]][%arg15], %[[C1_INIT_VAL_2]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: } +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: } +# DUMPIR: scf.if %[[IS_PRODUCER]] { +# DUMPIR: nvvm.setmaxregister decrease 40 +# DUMPIR: %[[TRUE:.*]] = arith.constant true +# DUMPIR: %[[C0_PROD:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_PROD:.*]] = arith.constant 1 : index +# DUMPIR: %[[PROD_LOOP:.*]] = scf.for %arg15 = %[[C0_PROD]] to %[[C16]] step %[[C1_PROD]] iter_args(%arg16 = %[[TRUE]]) -> (i1) { +# DUMPIR: %[[C7_PROD:.*]] = arith.constant 7 : index +# DUMPIR: %[[SLOT:.*]] = arith.remui %arg15, %[[C7_PROD]] : index +# DUMPIR: %[[TIMEOUT:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_MMA]][%[[SLOT]]], %arg16, %[[TIMEOUT]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C6:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_LAST:.*]] = arith.cmpi eq, %[[SLOT]], %[[C6]] : index +# DUMPIR: %[[TRUE_2:.*]] = arith.constant true +# DUMPIR: %[[FLIP:.*]] = arith.xori %arg16, %[[TRUE_2]] : i1 +# DUMPIR: %[[PHASE:.*]] = arith.select %[[IS_LAST]], %[[FLIP]], %arg16 : i1 +# DUMPIR: %[[BID_X:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y:.*]] = gpu.block_id y +# DUMPIR: %[[C128_TILE:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_X:.*]] = arith.muli %[[BID_X]], %[[C128_TILE]] : index +# DUMPIR: %[[C128_TILE_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_Y:.*]] = arith.muli %[[BID_Y]], %[[C128_TILE_2]] : index +# DUMPIR: %[[TID_PROD:.*]] = gpu.thread_id x +# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A:.*]] = arith.muli %[[SLOT]], %[[C16384]] : index +# DUMPIR: %[[C16384_2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE:.*]] = arith.muli %[[SLOT]], %[[C16384_2]] : index +# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1:.*]] = arith.addi %[[OFF_B_BASE]], %[[C114688]] : index +# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2:.*]] = arith.addi %[[OFF_B1]], %[[C8192]] : index +# DUMPIR: %[[SMEM:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A:.*]] = memref.view %[[SMEM]][%[[OFF_A]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1:.*]] = memref.view %[[SMEM_2]][%[[OFF_B1]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2:.*]] = memref.view %[[SMEM_3]][%[[OFF_B2]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[TX_COUNT:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MBAR_TMA]][%[[SLOT]]], %[[TX_COUNT]], predicate = %[[IS_PRIMARY]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C128_WG:.*]] = arith.constant 128 : index +# DUMPIR: %[[TID_MOD:.*]] = arith.remui 
%[[TID_PROD]], %[[C128_WG]] : index +# DUMPIR: %[[C0_TMA:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_TMA_THREAD:.*]] = arith.cmpi eq, %[[TID_MOD]], %[[C0_TMA]] : index +# DUMPIR: %[[C64:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD:.*]] = arith.muli %arg15, %[[C64]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD]], %[[DIM_X]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_A]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B1]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIM_Y_OFF:.*]] = arith.addi %[[DIM_Y]], %[[C64_OFF]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y_OFF]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B2]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: scf.yield %[[PHASE]] : i1 +# DUMPIR: } +# DUMPIR: } +# DUMPIR: scf.if %[[IS_CONSUMER]] { +# DUMPIR: nvvm.setmaxregister increase 232 +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C0_CONS:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16_CONS:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_CONS:.*]] = arith.constant 1 : index +# DUMPIR: %[[CONS_LOOP:.*]]:2 = scf.for %arg15 = %[[C0_CONS]] to %[[C16_CONS]] step %[[C1_CONS]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) { +# DUMPIR: %[[C7_CONS:.*]] = arith.constant 7 : index +# DUMPIR: %[[SLOT_CONS:.*]] = arith.remui %arg15, %[[C7_CONS]] : index +# DUMPIR: %[[TIMEOUT_CONS:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_TMA]][%[[SLOT_CONS]]], %arg17, %[[TIMEOUT_CONS]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C16384_CONS:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_CONS:.*]] = arith.muli %[[SLOT_CONS]], %[[C16384_CONS]] : index +# DUMPIR: %[[C114688_CONS:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B_CONS:.*]] = arith.addi %[[OFF_A_CONS]], %[[C114688_CONS]] : index +# DUMPIR: %[[SMEM_CONS:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_CONS:.*]] = memref.view %[[SMEM_CONS]][%[[OFF_A_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_CONS_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B_CONS:.*]] = memref.view %[[SMEM_CONS_2]][%[[OFF_B_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DESC_A:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_CONS]], %{{.*}} : memref<128x64xf16, 
#gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[DESC_B:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_CONS]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[ACC:.*]] = nvgpu.warpgroup.mma %[[DESC_A]], %[[DESC_B]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_NOT_FIRST:.*]] = arith.cmpi ugt, %arg15, %[[C0_CMP]] : index +# DUMPIR: %[[ARRIVE_PRED:.*]] = arith.andi %[[IS_NOT_FIRST]], %[[IS_PRIMARY_2]] : i1 +# DUMPIR: scf.if %[[ARRIVE_PRED]] { +# DUMPIR: %[[C0_ARR:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_ZERO:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C0_ARR]] : index +# DUMPIR: %[[C6_WRAP:.*]] = arith.constant 6 : index +# DUMPIR: %[[C1_SUB:.*]] = arith.constant 1 : index +# DUMPIR: %[[PREV_SLOT:.*]] = arith.subi %[[SLOT_CONS]], %[[C1_SUB]] : index +# DUMPIR: %[[BARR_ID:.*]] = arith.select %[[IS_ZERO]], %[[C6_WRAP]], %[[PREV_SLOT]] : index +# DUMPIR: %{{.*}} = nvgpu.mbarrier.arrive %[[MBAR_MMA]][%[[BARR_ID]]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token +# DUMPIR: } +# DUMPIR: %[[C6_LAST:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_LAST_CONS:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C6_LAST]] : index +# DUMPIR: %[[TRUE_CONS:.*]] = arith.constant true +# DUMPIR: %[[FLIP_CONS:.*]] = arith.xori %arg17, %[[TRUE_CONS]] : i1 +# DUMPIR: %[[PHASE_CONS:.*]] = arith.select %[[IS_LAST_CONS]], %[[FLIP_CONS]], %arg17 : i1 +# DUMPIR: scf.yield %[[ACC]], %[[PHASE_CONS]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1 +# DUMPIR: } +# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0 +# DUMPIR: %[[TID_EPI:.*]] = gpu.thread_id x +# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y +# DUMPIR: %[[C128_EPI:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_X_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI]] : index +# DUMPIR: %[[C128_EPI_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_Y_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI_2]] : index +# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_EPI:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_EPI]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%[[DIM_X_EPI]], %[[DIM_Y_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[CONS_LOOP]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: gpu.barrier +# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index +# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] { +# DUMPIR: %{{.*}} = memref.load 
%[[VIEW_EPI]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %{{.*}}, %[[SUBVIEW]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: } +# DUMPIR: } +# DUMPIR: gpu.terminator diff --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg index 689cd25..af44b2e 100644 --- a/mlir/test/Examples/NVGPU/lit.local.cfg +++ b/mlir/test/Examples/NVGPU/lit.local.cfg @@ -1,4 +1,4 @@ config.unsupported = False -if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests: +if not config.enable_cuda_runner or not config.enable_bindings_python: config.unsupported = True
\ No newline at end of file diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py index 90dbb23..8561072 100644 --- a/mlir/test/Examples/NVGPU/tools/nvdsl.py +++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py @@ -9,6 +9,7 @@ from mlir import runtime as rt from tools import nvgpucompiler MLIR_DYNAMIC = -9223372036854775808 +DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1" def const(value: int, ty=None): @@ -84,9 +85,7 @@ class Mbarriers: self.mbar_group_op, txcount_op, self.id_op, predicate=predicate ) else: - nvgpu.mbarrier_arrive( - ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op - ) + nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op) def try_wait(self, phase: bool = False, ticks: int = 10000000): ticks_op = const(ticks) @@ -144,7 +143,9 @@ class TMA: device_ptr, ) self.tma_descriptor = nvgpu.TmaCreateDescriptorOp( - tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape) + tma_descriptor_ty, + device_unranked_memref, + list(map(const, self.tma_box_shape)), ) return self.tma_descriptor.result @@ -156,7 +157,7 @@ class TMA: dest, mbarrier.mbar_group_op, self.tma_descriptor, - coordinates=map(const, coords), + coordinates=list(map(const, coords)), mbarId=mbarrier.id_op, predicate=predicate, ) @@ -310,13 +311,10 @@ class NVDSL: @functools.wraps(func) def wrapper(*args, **kwargs): launch_op = gpu.LaunchOp( - None, - [], - *map(const, grid), - *map(const, block), - dynamicSharedMemorySize=arith.constant(T.i32(), smem), + grid_size=grid, + block_size=block, + dynamic_shared_memory_size=arith.constant(T.i32(), smem), ) - launch_op.body.blocks.append(*([T.index()] * 12)) with ir.InsertionPoint(launch_op.body.blocks[0]): result = func(*args, **kwargs) gpu.terminator() @@ -334,13 +332,11 @@ class NVDSL: def saveIR(module): """Save generated IR""" - if True: # self.saveIR: - # print(mlir_nvgpu_module) - original_stdout = sys.stdout - with open("nvdsl.mlir", "w") as f: - sys.stdout = f - print(module) - sys.stdout = original_stdout + original_stdout = sys.stdout + with open("nvdsl.mlir", "w") as f: + sys.stdout = f + print(module) + sys.stdout = original_stdout def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue": """Generate MLIR's Arith dialects binary operations.""" @@ -429,6 +425,9 @@ class NVDSL: # Save IR in a file # saveIR(module) + if DUMP_ONLY: + print(module) + return 0 # Verify the module module.operation.verify() diff --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py index 1c9cc74..4b661f8 100644 --- a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py +++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py @@ -35,9 +35,11 @@ class NvgpuCompiler: def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: """Wraps the module in a JIT execution engine.""" - return execution_engine.ExecutionEngine( + ee = execution_engine.ExecutionEngine( module, opt_level=self.opt_level, shared_libs=self.shared_libs ) + ee.initialize() + return ee def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: """Compiles and jits the module.""" diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 0c5fec8c..2f5dd28 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -145,3 +145,11 @@ func.func @verify_fail_3() { %r = "arith.constant"() {value = -3 : si32} : () -> si32 return } + +// ----- + +// Verify that symbols with results are rejected +module { + // expected-error@+1 
{{'test.symbol_with_result' op symbols must not have results}} + %0 = "test.symbol_with_result"() <{sym_name = "test_symbol"}> : () -> i32 +} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index b725307..2e23746 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -105,3 +105,10 @@ func.func @dialect_location() { test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>)) return } + +// CHECK-LABEL: @location_attr +// CHECK: test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)} loc({{.*}}locations.mlir":[[# @LINE+2]]:3) +func.func @location_attr() { + test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)} + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir new file mode 100644 index 0000000..5f8b2f4 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir @@ -0,0 +1,26 @@ +// REQUIRES: system-linux +// TODO: Run only on Linux until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// All floating-point arithmetic is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \ +// RUN: --convert-scf-to-cf --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Put rhs into a separate function so that it won't be constant-folded. +func.func @foo_vec() -> (vector<4xf8E4M3FN>, vector<4xf32>) { + %cst1 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf8E4M3FN> + %cst2 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf32> + return %cst1, %cst2 : vector<4xf8E4M3FN>, vector<4xf32> +} + +func.func @entry() { + // CHECK: ( 3.5, 3.5, 3.5, 3.5 ) + %a1_vec = arith.constant dense<[1.4, 1.4, 1.4, 1.4]> : vector<4xf8E4M3FN> + %b1_vec, %b2_vec = func.call @foo_vec() : () -> (vector<4xf8E4M3FN>, vector<4xf32>) + %c1_vec = arith.addf %a1_vec, %b1_vec : vector<4xf8E4M3FN> // not supported by LLVM + vector.print %c1_vec : vector<4xf8E4M3FN> + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir new file mode 100644 index 0000000..7f72dd5 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,82 @@ +// REQUIRES: system-linux +// TODO: Run only on Linux until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// Case 1: All floating-point arithmetic is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Case 2: Only unsupported arithmetic (f8E4M3FN) is lowered through APFloat. +// Arithmetic on f32 is lowered directly to LLVM. +// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \ +// RUN: --convert-to-llvm --reconcile-unrealized-casts | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Put rhs into a separate function so that it won't be constant-folded. 
+func.func @foo() -> (f8E4M3FN, f32) { + %cst1 = arith.constant 2.2 : f8E4M3FN + %cst2 = arith.constant 2.2 : f32 + return %cst1, %cst2 : f8E4M3FN, f32 +} + +func.func @entry() { + %a1 = arith.constant 1.4 : f8E4M3FN + %a2 = arith.constant 1.4 : f32 + %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32) + + // CHECK: 2.2 + vector.print %b2 : f32 + + // CHECK-NEXT: 3.5 + %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM + vector.print %c1 : f8E4M3FN + + // CHECK-NEXT: 3.6 + %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM + vector.print %c2 : f32 + + // CHECK-NEXT: 2.25 + %cvt = arith.truncf %b2 : f32 to f8E4M3FN + vector.print %cvt : f8E4M3FN + + // CHECK-NEXT: -2.25 + %negated = arith.negf %cvt : f8E4M3FN + vector.print %negated : f8E4M3FN + + // CHECK-NEXT: -2.25 + %min = arith.minimumf %cvt, %negated : f8E4M3FN + vector.print %min : f8E4M3FN + + // CHECK-NEXT: 1 + %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN + vector.print %cmp1 : i1 + + // CHECK-NEXT: 1 + // Bit pattern: 01, interpreted as signed integer: 1 + %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2 + vector.print %cvt_int_signed : i2 + + // CHECK-NEXT: -2 + // Bit pattern: 10, interpreted as signed integer: -2 + %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2 + vector.print %cvt_int_unsigned : i2 + + // CHECK-NEXT: -6 + // Bit pattern: 1...11110111, interpreted as signed: -9 + // Closest f4E2M1FN value: -6.0 + %c9 = arith.constant -9 : i16 + %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_signed_int : f4E2M1FN + + // CHECK-NEXT: 6 + // Bit pattern: 1...11110111, interpreted as unsigned: 65527 + // Closest f4E2M1FN value: 6.0 + %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_unsigned_int : f4E2M1FN + + return +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir index 9d04357..d26853d 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir @@ -22,7 +22,7 @@ func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : t } func.func @main() { - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -44,7 +44,7 @@ func.func @main() { %A_dyn = tensor.cast %A : tensor<13x7xf32> to tensor<?x?xf32> %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir index ad7dbb9..e2c0f1d2 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir @@ -16,7 +16,7 @@ func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf3 } func.func @main() { - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -37,7 +37,7 @@ func.func @main() { %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32> %C_init = 
bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir index 243f9e5..007189a 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir @@ -29,7 +29,7 @@ func.func @main() { %c128 = arith.constant 128 : i32 func.call @setArmSVLBits(%c128) : (i32) -> () - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -50,7 +50,7 @@ func.func @main() { %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32> %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir index 127ab70..c90476e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir @@ -24,17 +24,14 @@ func.func @main() { %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32> %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32> @@ -48,72 +45,82 @@ func.func @main() { %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32> %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + // CHECK: 
ERROR: Runtime op verification failed + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.matmul - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.matmul + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) %c64x57 = arith.constant dense<0.0> : tensor<16x29xf32> %c3x4 = arith.constant dense<0.0> : tensor<3x4xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: unexpected negative result on dimension #0 of input/output operand #0 + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: unexpected negative result on dimension #0 of input/output operand #0 func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>) %c0x = arith.constant dense<1.0> : tensor<0xf32> %d0x = tensor.cast %c0x : tensor<0xf32> 
to tensor<?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>) %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32> %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32> // CHECK-NOT: ERROR: Runtime op verification failed + func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>) + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + return } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 8fa32d7..bbda8d4e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -27,8 +27,8 @@ func.func @main() { %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32> %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32> - %c0_i32 = arith.constant 0 : i32 - %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32> + %c0_f32 = arith.constant 0.0 : f32 + %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32> %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 8487567..09cfee1 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], return } +func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, + %dim_0: index, + %dim_1: index, + %dim_2: index, + %offset: index) { + %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : + memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to + memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -127,5 +138,9 @@ func.func @main() { func.call 
@subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2) : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset) + : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> () return } diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index a77fa31..745eea3 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, return } +func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) { + tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32> + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -115,5 +120,9 @@ func.func @main() { %dim_2 = arith.constant 1 : index func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> () + return } diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir index a374d9a..e3fee91 100644 --- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir +++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir @@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te } func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> { - %cst = arith.constant 0.0 : f64 + %cst = arith.constant 0.0 : f32 %empty = tensor.empty() : tensor<10x15xf32> // expected-remark @below {{fill}} - %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32> %real_lhs = linalg.mul ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32> diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir new file mode 100644 index 0000000..c4608ac --- /dev/null +++ b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +module @subview attributes {gpu.container_module} { + gpu.module @kernel { + gpu.func @subview(%src: memref<256xf32>, %dst: memref<256xf32>) kernel { + %src_subview = memref.subview %src[5] [251] [1] : memref<256xf32> to memref<251xf32, strided<[1], offset: 5>> + %dst_subview = memref.subview %dst[10] [246] [1] : memref<256xf32> to memref<246xf32, 
strided<[1], offset: 10>> + %lane_id = gpu.lane_id + %mask = arith.constant 1 : i1 + %loaded = xegpu.load %src_subview[%lane_id], %mask : memref<251xf32, strided<[1], offset: 5>>, index, i1 -> f32 + xegpu.store %loaded, %dst_subview[%lane_id], %mask : f32, memref<246xf32, strided<[1], offset: 10>>, index, i1 + gpu.return + } + } + func.func @test(%src: memref<256xf32>, %dst: memref<256xf32>) -> memref<256xf32> { + %memref_src = gpu.alloc () : memref<256xf32> + gpu.memcpy %memref_src, %src : memref<256xf32>, memref<256xf32> + %memref_dst = gpu.alloc () : memref<256xf32> + gpu.memcpy %memref_dst, %dst : memref<256xf32>, memref<256xf32> + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + gpu.launch_func @kernel::@subview blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%memref_src : memref<256xf32>, %memref_dst : memref<256xf32>) + gpu.wait // Wait for the kernel to finish. + gpu.memcpy %dst, %memref_dst : memref<256xf32>, memref<256xf32> + gpu.dealloc %memref_src : memref<256xf32> + gpu.dealloc %memref_dst : memref<256xf32> + return %dst : memref<256xf32> + } + func.func @main() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %memref_src = memref.alloc() : memref<256xf32> + %memref_dst = memref.alloc() : memref<256xf32> + // Initialize source memref + scf.for %i = %c0 to %c256 step %c1 { + %val = arith.index_cast %i : index to i32 + %val_float = arith.sitofp %val : i32 to f32 + memref.store %val_float, %memref_src[%i] : memref<256xf32> + } + // Initialize destination memref to zero + scf.for %i = %c0 to %c256 step %c1 { + %zero = arith.constant 0.0 : f32 + memref.store %zero, %memref_dst[%i] : memref<256xf32> + } + // Call test function + %gpu_result = call @test(%memref_src, %memref_dst) : (memref<256xf32>, memref<256xf32>) -> memref<256xf32> + %gpu_result_casted = memref.cast %gpu_result : memref<256xf32> to memref<*xf32> + // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}} + // CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + call @printMemrefF32(%gpu_result_casted) : (memref<*xf32>) -> () + // Deallocate memrefs + memref.dealloc %memref_src : memref<256xf32> + memref.dealloc %memref_dst : memref<256xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir index edf8775..5ed2148 100644 --- a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir +++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir @@ -3,7 +3,7 @@ // RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ // RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse 
-gpu-module-to-binary \ // RUN: | mlir-runner \ -// RUN: --shared-libs=%mlir_sycl_runtime \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ // RUN: --shared-libs=%mlir_c_runner_utils \ // RUN: --entry-point-result=void \ diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir new file mode 100644 index 0000000..c3dd35b --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_cuda_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d1, d0)> + +func.func @main() { + %a = memref.alloc() : memref<8x4xf64> + %b = memref.alloc() : memref<4x8xf64> + %c = memref.alloc() : memref<8x8xf64> + %d = memref.alloc() : memref<8x8xf64> + + %f1 = arith.constant 1.0e+00 : f64 + %fcst = arith.constant 3.14e+00 : f64 + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + // Initialize the Input matrixes with ones. + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c4 step %c1 { + memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64> + memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64> + } + } + // Initialize the accumulator matrix with a constant. + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64> + } + } + + %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64> + %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64> + %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64> + %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64> + + gpu.host_register %2 : memref<*xf64> + gpu.host_register %20 : memref<*xf64> + gpu.host_register %33 : memref<*xf64> + gpu.host_register %34 : memref<*xf64> + + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) { + %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp"> + %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp"> + %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp"> + + %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp"> + + gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64> + gpu.terminator + } + // Print the memref after computation. 
+ call @printMemrefF64(%34) : (memref<*xf64>) -> () + // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14] + return +} + +func.func private @printMemrefF64(memref<*xf64>) diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir index 5585d98..d0001f6 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir index cd90ce3..fcff5f4 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir index fec2567..4718ac9 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir index d5633b0..5e3a7e7e 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir index db297b0..f1a48ae 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ 
// RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir index 65cbc79..f0a46ce 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir index a0c955e..ddbabd4 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir index f041df8..5c56e2d 100644 --- a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir +++ b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir index 71a21cf..83cf70c 100644 --- a/mlir/test/Integration/GPU/CUDA/assert.mlir +++ b/mlir/test/Integration/GPU/CUDA/assert.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir index 5acadd6..3e45b5a 100644 --- a/mlir/test/Integration/GPU/CUDA/async.mlir +++ b/mlir/test/Integration/GPU/CUDA/async.mlir @@ -8,8 +8,11 @@ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_async_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ -// RUN: --entry-point-result=void -O0 \ -// RUN: | FileCheck %s +// RUN: --entry-point-result=void -O0 +// RUN: +// This test is overly flaky right now and needs investigation, skipping FileCheck. 
+// See: https://github.com/llvm/llvm-project/issues/170833 +// DISABLED: | FileCheck %s func.func @main() { %c0 = arith.constant 0 : index diff --git a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir index 34dde6e..77a4fa0 100644 --- a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir +++ b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8'" -debug-only=serialize-to-binary \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8' allow-pattern-rollback=0" -debug-only=serialize-to-binary \ // RUN: 2>&1 | FileCheck %s func.func @host_function(%arg0 : f32, %arg1 : memref<?xf32>) { diff --git a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir index ed01416..51f6e36 100644 --- a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir +++ b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir @@ -2,7 +2,7 @@ // increment a global atomic counter and wait for the counter to reach 2. // // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | env CUDA_MODULE_LOADING=EAGER mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir index 27ec1ec..efffcaa 100644 --- a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir +++ b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=serialize-to-isa \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=serialize-to-isa \ // RUN: 2>&1 | FileCheck %s // CHECK-LABEL: Generated by LLVM NVPTX Back-End diff --git a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir index d32f5ef..f810678 100644 --- a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir +++ b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=dump-sass \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=dump-sass \ // RUN: 2>&1 | FileCheck %s // CHECK: MOV diff --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir index 07f3218..fe3c2b1 100644 --- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir +++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir index b2ac90a..f8f1aa8 100644 --- a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir +++ b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// 
RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir index fd664f2..ef11676 100644 --- a/mlir/test/Integration/GPU/CUDA/printf.mlir +++ b/mlir/test/Integration/GPU/CUDA/printf.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir index a6207d6..a4be5223 100644 --- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir +++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/two-modules.mlir b/mlir/test/Integration/GPU/CUDA/two-modules.mlir index c3cee2f..3490003 100644 --- a/mlir/test/Integration/GPU/CUDA/two-modules.mlir +++ b/mlir/test/Integration/GPU/CUDA/two-modules.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir index b98e8b0..c634444 100644 --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) { } return } + +// CHECK-LABEL: func @multiple_loop_ivs +func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) { + %ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index, + smin = 1 : index, smax = 32 : index } : index + %c0_i32 = arith.constant 0 : i32 + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%ub1, 64) { + // CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index} + %1 = test.reflect_bounds %arg1 : index + // CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index} + %2 = test.reflect_bounds %arg2 : index + memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32> + } + return +} diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir index 4fa7406..624e099 100644 --- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir +++ 
b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s +// See @test_unreifiable_result_shapes below for why `error-on-pattern-iteration-limit` is set to false. func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) -> (index, index, index, index, index) { @@ -27,12 +28,14 @@ func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) // ----- -func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) +// Test result shape reification for an operation that implements only +// `reifyResultShapes` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_result_shapes(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) -> (index, index, index, index, index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1) + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> @@ -41,7 +44,7 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3 %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> return %1, %2, %3, %4, %5 : index, index, index, index, index } -// CHECK-LABEL: func @result_shape_per_dim( +// CHECK-LABEL: func @reify_shaped_type_using_reify_result_shapes( // CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> // CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index @@ -51,3 +54,127 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3 // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// Test result shape reification for an operation that implements only +// `reifyShapeOfResult` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_shape_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> + %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_shape_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// Test result shape reification for an operation that implements only +// `reifyDimOfResult` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> + %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_dim_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// This test also indicates a problem with the approach of just using `reifyShapes` +// without being specific about the {result, dim} that needs to be resolved. The +// `reifyShapes` implementation introduces `dim` operations that are effectively +// dead, and this creates an infinite loop on pattern application (which eventually +// bails on hitting the iteration limit). This is the pitfall of this legacy +// mechanism.
+ +func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shapes( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] +// ----- + +func.func @test_unreifiable_result_shape(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] + +// ----- + +func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_dim_of_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_dim_of_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir new file mode 100644 index 0000000..d7b0528 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics + +func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + return %result : tensor<100x200xf32> +} + 
+module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + return %result : tensor<100x200xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @fusable_with_consumer_extract_slice(%arg0: tensor<100x200xf32>, %arg1: tensor<100x200xf32>, %dest: tensor<100x200xf32>) -> tensor<10x20xf32> { + // expected-remark @+1 {{can be fused with consumer tensor.extract_slice op}} + %add = linalg.add ins(%arg0, %arg1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + %c0 = arith.constant 0 : index + %slice = tensor.extract_slice %add[%c0, %c0] [10, 20] [1, 1] : tensor<100x200xf32> to tensor<10x20xf32> + + return %slice : tensor<10x20xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_consumer_fusability %add : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000..62dd7fa --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = 
arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_operand_2 = tensor.empty() : tensor<64x64xf32> + 
%out_operand_3 = tensor.empty() : tensor<64x64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %2 : tensor<64x64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : 
tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %out_operand_4 = tensor.empty() : tensor<64xf32> + %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.subf %out_0, %13 : f32 + %15 = arith.addf %out_1, %in : f32 + linalg.yield %14, %15 : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + return %2#1 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#3 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) { + %extracted_slice = 
tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<64x64xf32> + %2 = tensor.empty() : tensor<64x64xf32> + %3 = tensor.empty() : tensor<64x64xf32> + %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): + %6 = arith.mulf %in, %in_0 : f32 + %7 = arith.subf %out, %6 : f32 + %8 = arith.addf %out_1, %in : f32 + linalg.yield %7, %8 : f32, f32 + } -> (tensor<64x64xf32>, tensor<64x64xf32>) + %5 = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> + return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// 
CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32> +// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> + return %unpack : tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> +// CHECK: func.func @fuse_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 
%[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2047xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32> + return %unpack : tensor<2047xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> +// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], 
%[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<4x32x16xf32> + %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> + return %pack : tensor<4x32x16xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// 
CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] + +// ----- + +#map = affine_map<(d0) -> (-d0 + 4, 16)> +func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> { + %0 = tensor.empty() : tensor<1x4x16x1xf32> + %1 = tensor.empty() : tensor<4x4xf32> + %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) { + %3 = affine.min #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32> + } + } + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32> + return %pack : tensor<1x4x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> +// CHECK: func.func @fuse_pack_consumer_if_single_iteration( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> +// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = 
tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] + +// ----- + +func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> { + %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + } + } + %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32> + return %pack : tensor<2x64x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] + +// ----- + +// It is valid to fuse the pack op in 
perfect tiling scenario when the dimension +// is dynamic and padding is not needed. + +func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> { + %c1 = arith.constant 1 : index + %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32> + %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32> + } + } + %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32> + return %pack : tensor<64x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] + +// ----- + +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. 
+ +func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<22x2x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32> + return %pack : tensor<22x2x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_padding_semantics( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into 
%[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] + +// ----- + +// Imperfect tiling is not supported in pack op consumer fusion. + +#map = affine_map<(d0) -> (d0 * 5)> +#map1 = affine_map<(d0) -> (d0)> +func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> { + %0 = tensor.empty() : tensor<30xf32> + %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) { + %3 = affine.apply #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %in : f32 + linalg.yield %5 : f32 + } -> tensor<5xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> + } + } + %2 = tensor.empty() : tensor<5x6xf32> + %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> + return %pack : tensor<5x6xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +module { + func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 
: !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 2 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_add_multiple_tilable_consumers( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) +// CHECK-SAME: { +// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : +// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : +// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul +// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : +// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : +// CHECK: } +// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : + +// ----- + +module { + func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : 
tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %dest1 = tensor.empty() : tensor<258x258xf32> + %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> + %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @no_fuse_only_dps_consumer( +// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { +// CHECK: linalg.add +// CHECK: linalg.mul +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice +// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32> + %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %inserted_slice : tensor<256x256xf32> + } + %2 = tensor.empty() : tensor<256x256x24xf32> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<256x256x24xf32> + return %3 : tensor<256x256x24xf32> + } +} + +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : 
index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> +// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { +// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> +// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: linalg.yield %[[VAL_23]] : f32 +// CHECK: } -> tensor<64x256x24xf32> +// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> +// CHECK: } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> 
(d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.mulf %b0, %b1 : f32 + %1 = arith.addf %b0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 + } -> (tensor<?xf32>, tensor<?xf32>) + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0) : tensor<?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + return %result : tensor<?xf32> +} +// CHECK-LABEL: func @multi_slice_fusion1( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// Check that when the given operand tiles are inconsistent, tiling fails. 
+ +func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0) : tensor<?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + return %result : tensor<?xf32> +} +// CHECK-LABEL: func @multi_slice_fusion2( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = 
transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>, + %arg3 : index, %arg4 : index) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> + %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor<?x?x?xf32> to tensor<?x?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %result : tensor<?x?xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + 
transform.yield + } +} +// CHECK-LABEL: func @multi_slice_fusion_with_broadcast( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]]) +// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]]) +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: return %[[RESULT]]#2 + +// ----- + +func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, + %arg3 : index, %arg4 : index) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> + %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor<?x?x?xf32> to tensor<?x?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?x?xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + %result = linalg.generic { + indexing_maps = 
[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %result : tensor<?x?xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 7888462..0137e2a 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics --mlir-print-local-scope %s | FileCheck %s #map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + func.func @fuse_tilable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -28,14 +28,14 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %add = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield in (%loop) + %a, %new_loop = transform.test.fuse_consumer %add into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK: func.func @fuse_tilable_consumer_scf_for( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) @@ -60,8 +60,61 @@ module attributes {transform.with_named_sequence} { // ----- +#map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + func.func @fuse_tilable_consumer_nested_scf_for(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, + %lb0 : index, %ub0 : index, %step0 : index, + %lb1 : index, %ub1 : index, %step1 : index) -> tensor<?x?xf32> { + %0 = scf.for 
%arg3 = %lb0 to %ub0 step %step0 iter_args(%init0 = %arg0) -> tensor<?x?xf32> { + %1 = scf.for %arg4 = %lb1 to %ub1 step %step1 iter_args(%init1 = %init0) -> tensor<?x?xf32> { + %extracted_slice = tensor.extract_slice %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %2 = tensor.insert_slice %extracted_slice into %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> + scf.yield %2 : tensor<?x?xf32> + } + scf.yield %1 : tensor<?x?xf32> + } + %2 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> + return %2 : tensor<?x?xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loops = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop0, %loop1 = transform.split_handle %loops + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %add = transform.structured.match ops{["linalg.add"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %new_loop0, %new_loop1 = transform.test.fuse_consumer %add into (%loop0, %loop1) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func @fuse_tilable_consumer_nested_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK: %[[OUTER_RESULT:.+]]:2 = scf.for +// CHECK-SAME: iter_args(%[[INIT00:[a-zA-Z0-9_]+]] = %[[ARG0]], %[[INIT01:[a-zA-Z0-9_]+]] = %[[ARG2]]) +// CHECK: %[[INNER_RESULT:.+]]:2 = scf.for +// CHECK-SAME: iter_args(%[[INIT10:[a-zA-Z0-9_]+]] = %[[INIT00]], %[[INIT11:[a-zA-Z0-9_]+]] = %[[INIT01]]) +// CHECK-DAG: %[[OPERAND1:.+]] = tensor.extract_slice %[[INIT10]] +// CHECK-DAG: %[[OLD_INSERT_SLICE:.+]] = tensor.insert_slice %[[OPERAND1]] into %[[INIT10]] +// CHECK-DAG: %[[OPERAND2:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-DAG: %[[INIT:.+]] = tensor.extract_slice %[[INIT11]] +// CHECK: %[[ADD:.+]] = linalg.add +// CHECK-SAME: ins(%[[OPERAND1]], %[[OPERAND2]] : +// CHECK-SAME: outs(%[[INIT]] : +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[ADD]] into %[[INIT11]] +// CHECK: scf.yield %[[OLD_INSERT_SLICE]], %[[INSERT_SLICE]] +// CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1 +// CHECK: return %[[OUTER_RESULT]]#1 + +// ----- + +module { + func.func @fuse_tilable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -83,19 +136,16 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %add = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %first_slice_op, %second_slice_op = transform.split_handle %slice_ops - : (!transform.any_op) - -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer 
%add into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK: func.func @fuse_tilable_consumer_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) @@ -124,7 +174,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -155,16 +205,18 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %producer, %consumer = transform.split_handle %generics + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) @@ -193,7 +245,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -224,19 +276,16 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %first_slice_op, %second_slice_op = transform.split_handle %slice_ops - : (!transform.any_op) - -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + %a, %new_loops = transform.test.fuse_consumer %generic into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, 
!transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> @@ -293,17 +342,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> // CHECK: func.func @fuse_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -315,8 +362,8 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2048)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -356,17 +403,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> // CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -378,8 +423,8 @@ module attributes {transform.with_named_sequence} { // CHECK: 
%[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2047)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -419,16 +464,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -440,7 +484,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV1]]) // CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] @@ -471,13 +515,12 @@ func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> ten module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> // CHECK: 
func.func @fuse_pack_consumer_if_single_iteration( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> @@ -485,7 +528,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) // CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) -// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]]) // CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp @@ -517,13 +560,12 @@ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -535,7 +577,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] @@ -566,13 +608,12 @@ func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, % module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: 
#[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -584,7 +625,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] @@ -616,16 +657,12 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, % module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -633,7 +670,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) // CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%[[I]]) // CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] // CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] @@ -641,9 +678,9 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) -// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply affine_map<(d0) -> (d0 ceildiv 3)>(%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[J]]) // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] // CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] 
[1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] @@ -674,20 +711,21 @@ func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x linalg.yield %5 : f32 } -> tensor<5xf32> scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> } } %2 = tensor.empty() : tensor<5x6xf32> + // expected-error @below {{failed to fuse consumer of slice}} %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> return %pack : tensor<5x6xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -717,11 +755,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %mulop = transform.structured.match ops{["linalg.mul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2 + %fused_consumer, %new_loop = transform.test.fuse_consumer %mulop into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %expop = transform.structured.match ops{["linalg.exp"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %fused_consumer_2, %new_loop_2 = transform.test.fuse_consumer %expop into (%new_loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -741,64 +783,20 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp -// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : -// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : // CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] // CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul // CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : // CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : -// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into 
%[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : -// CHECK: } -// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : - -// ----- - -module { - func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %dest0 = tensor.empty() : tensor<256x256xf32> - %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { - %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> - %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> - scf.yield %insert_slice : tensor<256x256xf32> - } - %dest1 = tensor.empty() : tensor<258x258xf32> - %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> - %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> - } -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: func.func @no_fuse_only_dps_consumer( -// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { -// CHECK: linalg.add -// CHECK: linalg.mul -// CHECK: scf.yield +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_MUL]], %[[INSERT_EXP]] : // CHECK: } -// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice -// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] +// CHECK: return %[[LOOP_RESULT]]#1, %[[LOOP_RESULT]]#2 : // ----- @@ -829,40 +827,41 @@ module { } } -// CHECK: func.func 
@fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> -// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { -// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> -// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> -// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { -// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): -// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 -// CHECK: linalg.yield %[[VAL_23]] : f32 -// CHECK: } -> tensor<64x256x24xf32> -// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> -// CHECK: } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + %a, %b = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> +// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] 
= %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { +// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> +// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: %[[VAL_19:.*]] = linalg.generic +// CHECK-SAME: ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: linalg.yield %[[VAL_23]] : f32 +// CHECK: } -> tensor<64x256x24xf32> +// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> +// CHECK: } // ----- @@ -878,12 +877,12 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): %0 = arith.mulf %b0, %b1 : f32 - %1 = arith.addf %b0, %b2 : f32 - linalg.yield %0, %1 : f32, f32 + %1 = arith.addf %b0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 } -> (tensor<?xf32>, tensor<?xf32>) scf.forall.in_parallel { tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> @@ -901,6 +900,19 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % } -> tensor<?xf32> return %result : tensor<?xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %producer, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @multi_slice_fusion1( // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> // CHECK: %[[C0:.+]] = arith.constant 0 @@ -916,23 +928,9 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : 
tensor<?xf32>, % // CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] // CHECK: return %[[RESULT]]#2 -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} // ----- -// Check that when the given operand tiles are inconsistent, tiling fails. - func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -944,20 +942,20 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32): %0 = arith.mulf %b0, %b1 : f32 - linalg.yield %0 : f32 + linalg.yield %0 : f32 } -> tensor<?xf32> %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic1 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32): - %0 = arith.addf %b0, %b1 : f32 - linalg.yield %0: f32 + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 } -> tensor<?xf32> scf.forall.in_parallel { tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> @@ -975,6 +973,19 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % } -> tensor<?xf32> return %result : tensor<?xf32> } +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %producer1, %producer2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + // CHECK-LABEL: func @multi_slice_fusion2( // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> // CHECK: %[[C0:.+]] = arith.constant 0 @@ -991,19 +1002,6 @@ 
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % // CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] // CHECK: return %[[RESULT]]#2 -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} - // ----- func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>, @@ -1060,11 +1058,11 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -1124,7 +1122,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor< linalg.yield %0: f32 } -> tensor<?x?xf32> scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] @@ -1132,6 +1129,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor< } } %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + // expected-error @below {{failed to fuse consumer of slice}} %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -1146,11 +1144,11 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> 
(!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/Pass/invalid-unsupported-operation.mlir b/mlir/test/Pass/invalid-unsupported-operation.mlir new file mode 100644 index 0000000..1ee4584 --- /dev/null +++ b/mlir/test/Pass/invalid-unsupported-operation.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -test-print-liveness -split-input-file -verify-diagnostics + +// Unnamed modules do not implement SymbolOpInterface. +// expected-error-re @+1 {{trying to schedule pass '{{.*}}TestLivenessPass' on an unsupported operation}} +module {} + +// ----- + +// Named modules implement SymbolOpInterface. +module @named_module {} diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir index 948a133..bff2b1c 100644 --- a/mlir/test/Pass/pipeline-invalid.mlir +++ b/mlir/test/Pass/pipeline-invalid.mlir @@ -15,5 +15,5 @@ arith.constant 0 // ----- -// expected-error@below {{trying to schedule a pass on an unsupported operation}} +// expected-error-re@below {{trying to schedule pass '{{.*}}TestFunctionPass' on an unsupported operation}} module {} diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index 294e6af6..f397a4a 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -105,6 +105,25 @@ func.func @apply() -> !emitc.ptr<i32> { return %1 : !emitc.ptr<i32> } + +// CHECK-LABEL: void address_of() { +func.func @address_of() { + // CHECK-NEXT: int32_t [[V1:[^ ]*]]; + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32> + // CHECK-NEXT: int32_t* [[V2:[^ ]*]] = &[[V1]]; + %1 = emitc.address_of %0 : !emitc.lvalue<i32> + return +} + +// CHECK-LABEL: void dereference +// CHECK-SAME: (int32_t* [[ARG0:[^ ]*]]) { +func.func @dereference(%arg0: !emitc.ptr<i32>) { + // CHECK-NEXT: int32_t [[V1:[^ ]*]] = *[[ARG0]]; + %2 = emitc.dereference %arg0 : !emitc.ptr<i32> + emitc.load %2 : !emitc.lvalue<i32> + return +} + // CHECK: void array_type(int32_t v1[3], float v2[10][20]) func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { return diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 9f1c816..2de94d0 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -314,14 +314,14 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) return %v_load : i32 } -// CPP-DEFAULT: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { +// CPP-DEFAULT: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { // CPP-DEFAULT-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DEFAULT-NEXT: } -// CPP-DECLTOP: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { +// CPP-DECLTOP: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { // CPP-DECLTOP-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DECLTOP-NEXT: } -emitc.func 
@expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { +emitc.func @expression_with_dereference_apply(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { %c = emitc.expression %arg1, %arg2 : (i32, !emitc.ptr<i32>) -> i32 { %e = emitc.sub %arg2, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> %d = emitc.apply "*"(%e) : (!emitc.ptr<i32>) -> i32 @@ -330,6 +330,28 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i return %c : i32 } +// CPP-DEFAULT: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 42; +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_address_taken_apply(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { + %a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32> + %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 { + %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> + %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1 + emitc.yield %f : i1 + } + return %c : i1 +} + // CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; // CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; @@ -344,7 +366,7 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { %a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32> %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 { - %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + %d = emitc.address_of %a : !emitc.lvalue<i32> %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1 emitc.yield %f : i1 diff --git a/mlir/test/Target/LLVMIR/Import/debug-info-records.ll b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll new file mode 100644 index 0000000..077871e --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll @@ -0,0 +1,87 @@ +; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -convert-debug-rec-to-intrinsics -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s +; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s + +; CHECK: #[[LOCAL_VAR0:.*]] = #llvm.di_local_variable<scope = #di_lexical_block> +; CHECK: #[[LOCAL_VAR1:.*]] = #llvm.di_local_variable<scope = #di_lexical_block_file, name = "arg" +; CHECK: #[[LOCAL_VAR2:.*]] = #llvm.di_local_variable<scope = #di_lexical_block, name = "alloc" + +; CHECK: @callee() +define void @callee() { + ret void +} + +define void @func_with_empty_named_info() { + call void @callee() + ret void +} + +define void @func_no_debug() { + ret void +} + +; CHECK: llvm.func @func_with_debug(%[[ARG0:.*]]: i64 
+define void @func_with_debug(i64 %0) !dbg !3 { + + ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR0]] = %[[ARG0]] : i64 + ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR1]] #llvm.di_expression<[DW_OP_LLVM_fragment(0, 1)]> = %[[ARG0]] : i64 + ; CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 + ; CHECK: %[[ADDR:.*]] = llvm.alloca %[[CST]] x i64 + ; CHECK: llvm.intr.dbg.declare #[[LOCAL_VAR2]] #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_convert(4, DW_ATE_signed)]> = %[[ADDR]] : !llvm.ptr + %2 = alloca i64, align 8, !dbg !19 + #dbg_value(i64 %0, !20, !DIExpression(DW_OP_LLVM_fragment, 0, 1), !22) + #dbg_declare(ptr %2, !23, !DIExpression(DW_OP_deref, DW_OP_LLVM_convert, 4, DW_ATE_signed), !25) + #dbg_value(i64 %0, !26, !DIExpression(), !27) + call void @func_no_debug(), !dbg !28 + %3 = add i64 %0, %0, !dbg !32 + ret void, !dbg !37 +} + +define void @empty_types() !dbg !38 { + ret void, !dbg !44 +} + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "MLIR", isOptimized: true, runtimeVersion: 0, splitDebugFilename: "test.dwo", emissionKind: FullDebug, nameTableKind: None) +!1 = !DIFile(filename: "foo.mlir", directory: "/test/") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = distinct !DISubprogram(name: "func_with_debug", linkageName: "func_with_debug", scope: !4, file: !1, line: 3, type: !6, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!4 = !DINamespace(name: "nested", scope: !5) +!5 = !DINamespace(name: "toplevel", scope: null, exportSymbols: true) +!6 = !DISubroutineType(cc: DW_CC_normal, types: !7) +!7 = !{null, !8, !9, !11, !12, !13, !16} +!8 = !DIBasicType(name: "si64") +!9 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, extraData: !10) +!10 = !DIBasicType(name: "si32", size: 32, encoding: DW_ATE_signed) +!11 = !DIDerivedType(tag: DW_TAG_pointer_type, name: "named", baseType: !10) +!12 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, dwarfAddressSpace: 3) +!13 = distinct !DICompositeType(tag: DW_TAG_structure_type, name: "composite", file: !1, line: 42, size: 64, align: 32, elements: !14) +!14 = !{!15} +!15 = !DISubrange(count: 4) +!16 = !DICompositeType(tag: DW_TAG_array_type, name: "array", file: !1, baseType: !8, flags: DIFlagVector, elements: !17) +!17 = !{!18} +!18 = !DISubrange(lowerBound: 0, upperBound: 4, stride: 1) +!19 = !DILocation(line: 100, column: 12, scope: !3) +!20 = !DILocalVariable(name: "arg", arg: 1, scope: !21, file: !1, line: 6, type: !8, align: 32) +!21 = distinct !DILexicalBlockFile(scope: !3, file: !1, discriminator: 0) +!22 = !DILocation(line: 103, column: 3, scope: !3) +!23 = !DILocalVariable(name: "alloc", scope: !24) +!24 = distinct !DILexicalBlock(scope: !3) +!25 = !DILocation(line: 106, column: 3, scope: !3) +!26 = !DILocalVariable(scope: !24) +!27 = !DILocation(line: 109, column: 3, scope: !3) +!28 = !DILocation(line: 1, column: 2, scope: !3) +!32 = !DILocation(line: 2, column: 4, scope: !33, inlinedAt: !36) +!33 = distinct !DISubprogram(name: "callee", scope: !13, file: !1, type: !34, spFlags: DISPFlagDefinition, unit: !0) +!34 = !DISubroutineType(types: !35) +!35 = !{!8, !8} +!36 = !DILocation(line: 28, column: 5, scope: !3) +!37 = !DILocation(line: 135, column: 3, scope: !3) +!38 = distinct !DISubprogram(name: "empty_types", scope: !39, file: !1, type: !40, spFlags: DISPFlagDefinition, unit: !0, annotations: !42) +!39 = !DIModule(scope: !1, name: "module", configMacros: 
"bar", includePath: "/", apinotes: "/", file: !1, line: 42, isDecl: true) +!40 = !DISubroutineType(cc: DW_CC_normal, types: !41) +!41 = !{} +!42 = !{!43} +!43 = !{!"foo", !"bar"} +!44 = !DILocation(line: 140, column: 3, scope: !38) diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll index 83c0438..023b012 100644 --- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll @@ -22,14 +22,14 @@ define dso_local void @dsolocal_func() { ; // ----- ; CHECK-LABEL: @func_readnone -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} ; CHECK: llvm.return define void @func_readnone() readnone { ret void } ; CHECK-LABEL: @func_readnone_indirect -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} declare void @func_readnone_indirect() #0 attributes #0 = { readnone } @@ -169,7 +169,7 @@ define void @entry_count() !prof !1 { ; // ----- ; CHECK-LABEL: @func_memory -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>} ; CHECK: llvm.return define void @func_memory() memory(readwrite, argmem: none) { ret void diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index d48be66..32f730b 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -1,16 +1,14 @@ ; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s -; Check that debug intrinsics with an unsupported argument are dropped. - -declare void @llvm.dbg.value(metadata, metadata, metadata) +; Check that debug records with an unsupported argument are dropped. 
; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !DIArgList(i64 %{{.*}}, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)) +; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!DIArgList(i64 %{{.*}}, i64 undef), !{{.*}}, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !{{.*}}) ; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()) +; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!{{.*}}, !{{.*}}, !DIExpression(), !{{.*}}) define void @unsupported_argument(i64 %arg1) { - tail call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5 - tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5 + #dbg_value(!DIArgList(i64 %arg1, i64 undef), !3, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !5) + #dbg_value(!6, !3, !DIExpression(), !5) ret void } diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index be245e3..7f9c511 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -703,13 +703,13 @@ declare void @f() ; CHECK-LABEL: @call_memory_effects define void @call_memory_effects() { -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} call void @f() memory(none) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} call void @f() memory(none, argmem: write, inaccessiblemem: read) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write, errnoMem = write, targetMem0 = write, targetMem1 = write>} call void @f() memory(write, argmem: none) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>} call void @f() memory(readwrite, inaccessiblemem: read) ; CHECK: llvm.call @f() ; CHECK-NOT: #llvm.memory_effects diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index d2bb809..2381d7a 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -1128,6 +1128,34 @@ define void @experimental_constrained_fpext(float %s, <4 x float> %v) { ret void } +; CHECK-LABEL: llvm.func @ucmp +define 
i2 @ucmp(i32 %a, i32 %b) { + ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2 + %r = call i2 @llvm.ucmp.i2.i32(i32 %a, i32 %b) + ret i2 %r +} + +; CHECK-LABEL: llvm.func @vector_ucmp +define <4 x i32> @vector_ucmp(<4 x i32> %a, <4 x i32> %b) { + ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + %r = call <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %r +} + +; CHECK-LABEL: llvm.func @scmp +define i2 @scmp(i32 %a, i32 %b) { + ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2 + %r = call i2 @llvm.scmp.i2.i32(i32 %a, i32 %b) + ret i2 %r +} + +; CHECK-LABEL: llvm.func @vector_scmp +define <4 x i32> @vector_scmp(<4 x i32> %a, <4 x i32> %b) { + ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + %r = call <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %r +} + declare float @llvm.fmuladd.f32(float, float, float) declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) declare float @llvm.fma.f32(float, float, float) @@ -1382,3 +1410,7 @@ declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x doubl declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata) declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata) declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata) +declare i2 @llvm.ucmp.i2.i32(i32, i32) +declare <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>) +declare i2 @llvm.scmp.i2.i32(i32, i32) +declare <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>) diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index c623df0..3280625 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -16,6 +16,22 @@ bb2: ; // ----- +; CHECK-LABEL: @cond_br_expected +define i64 @cond_br_expected(i1 %arg1, i64 %arg2) { +entry: + ; CHECK: llvm.cond_br + ; CHECK-SAME: weights([1, 2000]) + br i1 %arg1, label %bb1, label %bb2, !prof !0 +bb1: + ret i64 %arg2 +bb2: + ret i64 %arg2 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 2000} + +; // ----- + ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch @@ -36,6 +52,26 @@ bbd: ; // ----- +; CHECK-LABEL: @simple_switch_expected( +define i32 @simple_switch_expected(i32 %arg1) { + ; CHECK: llvm.switch + ; CHECK: {branch_weights = array<i32: 1, 1, 2000>} + switch i32 %arg1, label %bbd [ + i32 0, label %bb1 + i32 9, label %bb2 + ], !prof !0 +bb1: + ret i32 %arg1 +bb2: + ret i32 %arg1 +bbd: + ret i32 %arg1 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000} + +; // ----- + ; Verify that a single weight attached to a call is not translated. ; The MLIR WeightedBranchOpInterface does not support this case. diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir new file mode 100644 index 0000000..95d12f3 --- /dev/null +++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir @@ -0,0 +1,99 @@ +// Tests single-team by-ref GPU reductions. 
+ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + omp.private {type = private} @_QFfooEi_private_i32 : i32 + omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + omp.yield(%2 : !llvm.ptr) + } init { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg1 : !llvm.ptr) + } combiner { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %6 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %7 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr + %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr + %12 = llvm.load %9 : !llvm.ptr -> f32 + %13 = llvm.load %11 : !llvm.ptr -> f32 + %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32 + llvm.store %14, %9 : f32, !llvm.ptr + omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) + } + + llvm.func @foo_() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""} + %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, implicit, descriptor, to) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"} + omp.target map_entries(%10 -> %arg0 : !llvm.ptr) { + %13 = llvm.mlir.constant(1000 : i32) : i32 + %14 = llvm.mlir.constant(1 : i32) : i32 + omp.parallel { + omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg4 : !llvm.ptr) { + omp.loop_nest (%arg5) : i32 = (%14) to (%13) inclusive step (%14) { + omp.yield + } + } + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK: define {{.*}} @_omp_reduction_shuffle_and_reduce_func({{.*}}) {{.*}} { +// CHECK: %[[REMOTE_RED_LIST:.omp.reduction.remote_reduce_list]] = alloca [1 x ptr], align 8, addrspace(5) +// CHECK: %[[RED_ELEM:.omp.reduction.element]] = alloca { 
ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[RED_ELEM_1:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr + +// CHECK: %[[SHUFFLE_ELEM:.*]] = alloca float, align 4, addrspace(5) +// CHECK: %[[REMOTE_RED_LIST_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[REMOTE_RED_LIST]] to ptr + +// CHECK: %[[REMOTE_RED_LIST_ELEM0:.*]] = getelementptr inbounds [1 x ptr], ptr %[[REMOTE_RED_LIST_ASCAST]], i64 0, i64 0 + +// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr +// CHECK: %[[SHUFFLE_RES:.*]] = call i32 @__kmpc_shuffle_int32({{.*}}) +// CHECK: store i32 %[[SHUFFLE_RES]], ptr %[[SHUFFLE_ELEM_ASCAST]], align 4 + +// CHECK: %[[RED_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr +// CHECK: %[[RED_ALLOC_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_ASCAST]], i32 0, i32 0 +// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr +// CHECK: store ptr %[[SHUFFLE_ELEM_ASCAST]], ptr %[[RED_ALLOC_PTR]], align 8 +// CHECK: store ptr %[[RED_ELEM_1]], ptr %[[REMOTE_RED_LIST_ELEM0]], align 8 +// CHECK: } + +// CHECK: define {{.*}} @_omp_reduction_inter_warp_copy_func({{.*}}) {{.*}} { +// CHECK: %[[WARP_MASTER_CMP:.*]] = icmp eq i32 %nvptx_lane_id, 0 +// CHECK: br i1 %[[WARP_MASTER_CMP]], label %[[WARP_MASTER_BB:.*]], label %{{.*}} + +// CHECK: [[WARP_MASTER_BB]]: +// CHECK: %[[WARP_RESULT_PTR:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[WARP_RESULT:.*]] = load ptr, ptr %[[WARP_RESULT_PTR]], align 8 +// CHECK: %[[ALLOC_MEM_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[WARP_RESULT]], i32 0, i32 0 +// CHECK: %[[ALLOC_MEM:.*]] = load ptr, ptr %[[ALLOC_MEM_PTR]], align 8 +// CHECK: %[[WARP_TRANSFER_SLOT:.*]] = getelementptr inbounds [32 x i32], ptr addrspace(3) @__openmp_nvptx_data_transfer_temporary_storage, i64 0, i32 %nvptx_warp_id +// CHECK: %[[WARP_RED_RES:.*]] = load i32, ptr %[[ALLOC_MEM]], align 4 +// CHECK: store volatile i32 %[[WARP_RED_RES]], ptr addrspace(3) %[[WARP_TRANSFER_SLOT]], align 4 +// CHECK: } diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir new file mode 100644 index 0000000..1c73a49 --- /dev/null +++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir @@ -0,0 +1,121 @@ +// Tests cross-teams by-ref GPU reductions. 
+ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + omp.private {type = private} @_QFfooEi_private_i32 : i32 + omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + omp.yield(%2 : !llvm.ptr) + } init { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg1 : !llvm.ptr) + } combiner { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %6 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %7 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr + %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr + %12 = llvm.load %9 : !llvm.ptr -> f32 + %13 = llvm.load %11 : !llvm.ptr -> f32 + %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32 + llvm.store %14, %9 : f32, !llvm.ptr + omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) + } + + llvm.func @foo_() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""} + %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"} + omp.target map_entries(%10 -> %arg0 : !llvm.ptr) { + %14 = llvm.mlir.constant(1000000 : i32) : i32 + %15 = llvm.mlir.constant(1 : i32) : i32 + omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) { + omp.parallel { + omp.distribute { + omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) { + omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float } + +// CHECK: define 
internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} { +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0 +// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[ALLOC_PTR]], align 4 +// Verify that the actual value managed by the descriptor is stored in the globalized +// locals array, rather than a pointer to the descriptor or a pointer to the value. +// CHECK: store float %[[ALLOC_VAL]], ptr %[[GLOB_ELEM_PTR]], align 4 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_list_to_global_reduce_func({{.*}}) {{.*}} { +// Allocate a descriptor to manage the element retrieved from the globalized local array. +// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr + +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0 +// Store the pointer to the globalized local element into the locally allocated descriptor. +// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_global_to_list_copy_func({{.*}}) {{.*}} { +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0 +// Similar to _omp_reduction_list_to_global_copy_func(...) but in the reverse direction; i.e., +// the globalized local array is copied from rather than copied to. +// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[GLOB_ELEM_PTR]], align 4 +// CHECK: store float %[[ALLOC_VAL]], ptr %[[ALLOC_PTR]], align 4 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_global_to_list_reduce_func({{.*}}) {{.*}} { +// Allocate a descriptor to manage the element retrieved from the globalized local array. +// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr + +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0 +// Store the pointer to the globalized local element into the locally allocated descriptor.
+// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8 +// CHECK: } diff --git a/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir new file mode 100644 index 0000000..b54bfe4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +#tbaa_root_0 = #llvm.tbaa_root<> +#tbaa_type_desc_1 = #llvm.tbaa_type_desc<id = "omnipotent char", members = {<#tbaa_root_0, 0>}> +#tbaa_type_desc_2 = #llvm.tbaa_type_desc<id = "long long", members = {<#tbaa_type_desc_1, 0>}> +#tbaa_tag_3 = #llvm.tbaa_tag<access_type = #tbaa_type_desc_2, base_type = #tbaa_type_desc_2, offset = 0> + +// CHECK: define void @tbaa_anonymous_root(ptr %{{.*}}) { +// CHECK: %{{.*}} = load i64, ptr %{{.*}}, align 4, !tbaa ![[TAG:[0-9]+]] +// CHECK: ret void +// CHECK: } +// CHECK: !llvm.module.flags = !{![[FLAGS:[0-9]+]]} +// CHECK: ![[FLAGS]] = !{i32 2, !"Debug Info Version", i32 3} +// CHECK: ![[TAG]] = !{![[TYPE:[0-9]+]], ![[TYPE]], i64 0} +// CHECK: ![[TYPE]] = !{!"long long", ![[BASE:[0-9]+]], i64 0} +// CHECK: ![[BASE]] = !{!"omnipotent char", ![[ROOT:[0-9]+]], i64 0} +// CHECK: ![[ROOT]] = distinct !{![[ROOT]]} +llvm.func @tbaa_anonymous_root(%arg0: !llvm.ptr) { + %0 = llvm.load %arg0 {tbaa = [#tbaa_tag_3]} : !llvm.ptr -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 60bd24a..403c73f 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -1276,6 +1276,34 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) { llvm.return } +// CHECK-LABEL: @ucmp +llvm.func @ucmp(%a: i32, %b: i32) -> i2 { + // CHECK: call i2 @llvm.ucmp.i2.i32 + %r = llvm.intr.ucmp(%a, %b) : (i32, i32) -> i2 + llvm.return %r : i2 +} + +// CHECK-LABEL: @vector_ucmp +llvm.func @vector_ucmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> { + // CHECK: call <4 x i32> @llvm.ucmp.v4i32.v4i32 + %0 = llvm.intr.ucmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32> + llvm.return %0 : vector<4 x i32> +} + +// CHECK-LABEL: @scmp +llvm.func @scmp(%a: i32, %b: i32) -> i2 { + // CHECK: call i2 @llvm.scmp.i2.i32 + %r = llvm.intr.scmp(%a, %b) : (i32, i32) -> i2 + llvm.return %r : i2 +} + +// CHECK-LABEL: @vector_scmp +llvm.func @vector_scmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> { + // CHECK: call <4 x i32> @llvm.scmp.v4i32.v4i32 + %0 = llvm.intr.scmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32> + llvm.return %0 : vector<4 x i32> +} + // Check that intrinsics are declared with appropriate types. 
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -1308,7 +1336,7 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) { // CHECK-DAG: declare float @llvm.cos.f32(float) // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 // CHECK-DAG: declare { float, float } @llvm.sincos.f32(float) -// CHECK-DAG: declare { <8 x float>, <8 x float> } @llvm.sincos.v8f32(<8 x float>) #0 +// CHECK-DAG: declare { <8 x float>, <8 x float> } @llvm.sincos.v8f32(<8 x float>) // CHECK-DAG: declare float @llvm.copysign.f32(float, float) // CHECK-DAG: declare float @llvm.rint.f32(float) // CHECK-DAG: declare double @llvm.rint.f64(double) @@ -1464,3 +1492,7 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) { // CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata) // CHECK-DAG: declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata) // CHECK-DAG: declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata) +// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.ucmp.i2.i32(i32, i32) +// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>) +// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.scmp.i2.i32(i32, i32) +// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>) diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index cc243c8..819a514 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -78,6 +78,9 @@ llvm.mlir.global internal @f8E8M0FNU_global_as_i8(1.0 : f8E8M0FNU) : i8 // CHECK: @bf16_global_as_i16 = internal global i16 16320 llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16 +// CHECK: @bool_global_as_i8 = internal global i8 1 +llvm.mlir.global internal @bool_global_as_i8(true) : i8 + // CHECK: @explicit_undef = global i32 undef llvm.mlir.global external @explicit_undef() : i32 { %0 = llvm.mlir.undef : i32 @@ -2371,17 +2374,17 @@ llvm.func @readonly_function(%arg0: !llvm.ptr {llvm.readonly}) // CHECK: declare void @arg_mem_none_func() #[[ATTR:[0-9]+]] llvm.func @arg_mem_none_func() attributes { - memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>} + memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} -// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none) } +// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none, target_mem0: none, target_mem1: none) } // ----- // CHECK: declare void @readwrite_func() #[[ATTR:[0-9]+]] llvm.func @readwrite_func() attributes { - memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite>} + memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} -// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none) } +// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none, target_mem0: none, target_mem1: none) } // ----- @@ -2723,10 +2726,10 @@ llvm.func @fd() // CHECK: call void @fc() #[[ATTRS_2:[0-9]+]] // CHECK: call void @fd() #[[ATTRS_3:[0-9]+]] llvm.func @mem_effects_call() { - 
llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} : () -> () - llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write>} : () -> () - llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write>} : () -> () - llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} : () -> () + llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () llvm.return } @@ -2734,11 +2737,11 @@ llvm.func @mem_effects_call() { // CHECK: #[[ATTRS_0]] // CHECK-SAME: memory(none) // CHECK: #[[ATTRS_1]] -// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none) +// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none) // CHECK: #[[ATTRS_2]] -// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none) +// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none) // CHECK: #[[ATTRS_3]] -// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none) +// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none, target_mem0: none, target_mem1: none) // ----- diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir new file mode 100644 index 0000000..a18633e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --check-prefix=LLVM +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// LLVM-LABEL: @llvm_nvvm_barrier( +// LLVM-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]]) +llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) { + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) + // CHECK: nvvm.barrier + nvvm.barrier + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]]) + // CHECK: nvvm.barrier id = %{{.*}} + nvvm.barrier id = %barID + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]]) + // CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}} + nvvm.barrier id = %barID number_of_threads = %numberOfThreads + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %{{.*}} -> i32 + %0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32 + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %{{.*}} -> i32 + %1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32 + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %{{.*}} -> 
i32 + %2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32 + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir new file mode 100644 index 0000000..a4bece8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rn +llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rz +llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic +llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// ----- + +// CHECK-LABEL: 
@convert_f32x2_to_bf16x2_rn +llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz +llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic +llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir new file mode 100644 index 0000000..03abcdd --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir @@ -0,0 +1,118 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// ----- + +// Test valid 
architectures work + +// Valid case on sm_100a +gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] { + func.func @convert_rs() { + %f1 = llvm.mlir.constant(1.0 : f32) : f32 + %f2 = llvm.mlir.constant(2.0 : f32) : f32 + %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 + %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + return + } +} + +// Valid case on sm_103a +gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] { + func.func @convert_rs() { + %f1 = llvm.mlir.constant(1.0 : f32) : f32 + %f2 = llvm.mlir.constant(2.0 : f32) : f32 + %rbits = llvm.mlir.constant(0 : i32) : i32 + %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + return + } +} + +// ----- + +// Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs +llvm.func @convert_f32x4_to_f8x4_e4m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_relu +llvm.func @convert_f32x4_to_f8x4_e4m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F8x4 (E5M2) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs +llvm.func @convert_f32x4_to_f8x4_e5m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E5M2) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_relu +llvm.func @convert_f32x4_to_f8x4_e5m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E5M2) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F6x4 (E2M3) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs +llvm.func @convert_f32x4_to_f6x4_e2m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E2M3FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_relu +llvm.func @convert_f32x4_to_f6x4_e2m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E2M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F6x4 (E3M2) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs +llvm.func 
@convert_f32x4_to_f6x4_e3m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E3M2FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_relu +llvm.func @convert_f32x4_to_f6x4_e3m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E3M2FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F4x4 (E2M1) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs +llvm.func @convert_f32x4_to_f4x4_e2m1_rs(%src : vector<4xf32>, %rbits : i32) -> i16 { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f4E2M1FN) + llvm.return %res : i16 +} + +// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_relu +llvm.func @convert_f32x4_to_f4x4_e2m1_rs_relu(%src : vector<4xf32>, %rbits : i32) -> i16 { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {relu = true} : vector<4xf32> -> i16 (f4E2M1FN) + llvm.return %res : i16 +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir new file mode 100644 index 0000000..22578b5 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s + +llvm.func @fence_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.sync_restrict {order = #nvvm.mem_order<weak>} + llvm.return +} + +// ----- + +llvm.func @fence_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.sync_restrict {order = #nvvm.mem_order<mmio>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{tensormap proxy is not a supported proxy kind}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<tensormap>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{generic proxy not a supported proxy kind}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<generic>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{async_shared fence requires space attribute}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{only async_shared fence can have space attribute}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>, space = #nvvm.shared_space<cta>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_release() { + // expected-error @below {{uni-directional proxies only support generic for from_proxy attribute}} + nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<alias> to_proxy = #nvvm.proxy_kind<tensormap> + llvm.return +} + +// ----- + +llvm.func @fence_proxy_release() { + // expected-error @below {{uni-directional proxies only support tensormap for to_proxy attribute}} + nvvm.fence.proxy.release 
#nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<generic> to_proxy = #nvvm.proxy_kind<async> + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<mmio>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only async is supported for to_proxy attribute}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<alias>, + fromProxy = #nvvm.proxy_kind<generic>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only generic is support for from_proxy attribute}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<async>, + fromProxy = #nvvm.proxy_kind<tensormap>} + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/fence.mlir b/mlir/test/Target/LLVMIR/nvvm/fence.mlir new file mode 100644 index 0000000..0ab4cb7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fence.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster +llvm.func @llvm_nvvm_fence_sc_cluster() { + // CHECK: nvvm.fence.sc.cluster + nvvm.fence.sc.cluster + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_sync_restrict +llvm.func @nvvm_fence_sync_restrict() { + // CHECK: call void @llvm.nvvm.fence.acquire.sync_restrict.space.cluster.scope.cluster() + nvvm.fence.sync_restrict {order = #nvvm.mem_order<acquire>} + // CHECK: call void @llvm.nvvm.fence.release.sync_restrict.space.cta.scope.cluster() + nvvm.fence.sync_restrict {order = #nvvm.mem_order<release>} + llvm.return +} + +// CHECK-LABEL: @fence_mbarrier_init +llvm.func @fence_mbarrier_init() { + // CHECK: call void @llvm.nvvm.fence.mbarrier_init.release.cluster() + nvvm.fence.mbarrier.init + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy +llvm.func @nvvm_fence_proxy() { + // CHECK: call void @llvm.nvvm.fence.proxy.alias() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.global() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.global>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cta() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cluster() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>} + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_sync_restrict +llvm.func @nvvm_fence_proxy_sync_restrict() { + // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.acquire.sync_restrict.space.cluster.scope.cluster() + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>} + // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.release.sync_restrict.space.cta.scope.cluster() + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<release>} + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release +llvm.func @nvvm_fence_proxy_tensormap_generic_release() { + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta() + nvvm.fence.proxy.release #nvvm.mem_scope<cta> + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster() + nvvm.fence.proxy.release #nvvm.mem_scope<cluster> 
+ + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu() + nvvm.fence.proxy.release #nvvm.mem_scope<gpu> + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys() + nvvm.fence.proxy.release #nvvm.mem_scope<sys> + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire +llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) { + %c128 = llvm.mlir.constant(128) : i32 + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir new file mode 100644 index 0000000..37756c8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{random_bits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{random_bits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + llvm.return +} + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir new file mode 100644 index 0000000..4b3cafe --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate 
-mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_drop_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64 + %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64 + + %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64 + %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64 + %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64 + %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64 + + %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : 
!llvm.ptr<3>, i32 -> i64 + %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir new file mode 100644 index 0000000..b5389bd --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 
@llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64 + %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64 + + %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64 + %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64 + %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64 + %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64 + + %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32 + 
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir new file mode 100644 index 0000000..6e7e163 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir @@ -0,0 +1,103 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_generic(%barrier: !llvm.ptr, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %3, i32 1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1) + // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64 + %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr -> i64 + %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64 + %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64 + + %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr -> i64 + %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr -> i64 + %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64 + %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // 
CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64 + %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64 + %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64 + %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64 + + %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<3> -> i64 + %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64 + %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64 + %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive %barrier : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7> + + nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7> + llvm.return +} + +llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = 
nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir new file mode 100644 index 0000000..c345c5d --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir @@ -0,0 +1,103 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_drop_generic(%barrier: !llvm.ptr, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %3, i32 1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1) + // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr -> i64 + %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr -> i64 + %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64 + %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64 + + %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr -> i64 + %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr -> i64 + %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64 + %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: 
%3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64 + %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<3> -> i64 + %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64 + %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64 + + %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<3> -> i64 + %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64 + %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64 + %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7> + + nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = 
true} : !llvm.ptr<7> + llvm.return +} + +llvm.func @mbarrier_arrive_drop_nocomplete(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_nocomplete_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir new file mode 100644 index 0000000..99289fa --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_complete_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_complete_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<3>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 + + llvm.return +} + +llvm.func @mbarrier_complete_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_complete_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<7>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + llvm.return +}
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir new file mode 100644 index 0000000..dad7237 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_expect_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<3>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 + + llvm.return +} + +llvm.func @mbarrier_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<7>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + llvm.return +}
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir new file mode 100644 index 0000000..9c1d1cc --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { + // CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) { + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr + nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr + nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3> + nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3> + llvm.return +} + +llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32 + llvm.return +} + +llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32 + llvm.return +} + +llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.inval %barrier : !llvm.ptr + llvm.return +} + +llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.inval %barrier : !llvm.ptr<3> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir new file mode 100644 index 0000000..4a7776d --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir @@ -0,0 +1,138 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @mbarrier_arrive_ret_check(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_invalid_scope(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %0 = nvvm.mbarrier.arrive %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_drop_ret_check(%barrier: !llvm.ptr<7>) { + // expected-error 
@below {{mbarrier in shared_cluster space cannot return any value}} + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_drop_invalid_scope(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %0 = nvvm.mbarrier.arrive_drop %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_expect_tx_scope(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7>, i32 + llvm.return +} + +// ----- + +llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_ret(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{return-value is not supported when using predicate}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 -> i64 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_relaxed(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier with relaxed semantics is not supported when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {relaxed = true} : !llvm.ptr<3>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_cta(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier scope must be CTA when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_cluster(%barrier : !llvm.ptr<7>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier in shared_cluster space is not supported when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<7>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_drop_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = 
#nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_try_wait(%barrier: !llvm.ptr<3>, %phase: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_try_wait_with_timelimit(%barrier: !llvm.ptr<3>, %phase: i32, %ticks: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32, i32 -> i1 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir new file mode 100644 index 0000000..21ab72e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1 + %3 = 
nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir new file mode 100644 index 0000000..18aaf0e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir @@ -0,0 +1,147 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_try_wait_state(%barrier: !llvm.ptr, %state : i64) { + // CHECK-LABEL: define void @mbarrier_try_wait_state(ptr %0, i64 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr 
addrspace(3) %7, i64 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr, i64 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + llvm.return +} + +llvm.func @mbarrier_try_wait_state_with_timelimit(%barrier: !llvm.ptr, %state : i64, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_state_with_timelimit(ptr %0, i64 %1, i32 %2) { + // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %4, i64 %1, i32 %2) + // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i64 %1, i32 %2) + // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i64 %1, i32 %2) + // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i64 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr, i64, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr, i64, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1 + + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_state(ptr addrspace(3) %0, i64 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr<3>, i64 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_state_with_timelimit(%barrier: !llvm.ptr<3>, %state : i64, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_state_with_timelimit(ptr addrspace(3) %0, i64 %1, i32 %2) { + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 
%2) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr<3>, i64, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr<3>, i64, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_phase(%barrier: !llvm.ptr, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_phase(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_phase_with_timelimit(%barrier: !llvm.ptr, %phase : i32, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_phase_with_timelimit(ptr %0, i32 %1, i32 %2) { + // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %4, i32 %1, i32 %2) + // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i32 %1, i32 %2) + // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i32 %1, i32 %2) + // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i32 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr, i32, i32 
-> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_phase_with_timelimit(%barrier: !llvm.ptr<3>, %phase : i32, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase_with_timelimit(ptr addrspace(3) %0, i32 %1, i32 %2) { + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr<3>, i32, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/membar.mlir b/mlir/test/Target/LLVMIR/nvvm/membar.mlir new file mode 100644 index 0000000..1b794f6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/membar.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @memorybarrier() +llvm.func @memorybarrier() { + // CHECK: call void @llvm.nvvm.membar.cta() + nvvm.memory.barrier #nvvm.mem_scope<cta> + // CHECK: call void @llvm.nvvm.fence.sc.cluster() + nvvm.memory.barrier #nvvm.mem_scope<cluster> + // CHECK: call void @llvm.nvvm.membar.gl() + nvvm.memory.barrier #nvvm.mem_scope<gpu> + // CHECK: call void @llvm.nvvm.membar.sys() + nvvm.memory.barrier #nvvm.mem_scope<sys> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir new file mode 100644 index 0000000..1d6c23c --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir @@ -0,0 +1,43
@@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'default' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'f4e' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'b4e' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'rc8' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'ecl' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'ecr' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'rc16' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir new file mode 100644 index 0000000..d2baae7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @test_prmt_default +llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_f4e +llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_b4e +llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_rc8 +llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_ecl +llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_ecr +llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32 + 
llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_rc16 +llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_mixed +llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32 + + // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}}) + %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32 + + // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32 + + llvm.return %r3 : i32 +} diff --git a/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir new file mode 100644 index 0000000..a8a7430 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @redux_sync_i32_with_abs(%value: i32, %offset: i32) { + // expected-error@+1 {{abs attribute is supported only for f32 type}} + %res = nvvm.redux.sync add %value, %offset {abs = true}: i32 -> i32 + llvm.return +} + +// ----- + +llvm.func @redux_sync_i32_with_nan(%value: i32, %offset: i32) { + // expected-error@+1 {{nan attribute is supported only for f32 type}} + %res = nvvm.redux.sync add %value, %offset {nan = true}: i32 -> i32 + llvm.return +} + +// ----- + +llvm.func @redux_sync_f32_with_invalid_kind_add(%value: f32, %offset: i32) { + // expected-error@+1 {{'add' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}} + %res = nvvm.redux.sync add %value, %offset: f32 -> f32 + llvm.return +} + +// ----- + +llvm.func @redux_sync_f32_with_invalid_kind_and(%value: f32, %offset: i32) { + // expected-error@+1 {{'and' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}} + %res = nvvm.redux.sync and %value, %offset: f32 -> f32 + llvm.return +} + +// ----- + +llvm.func @redux_sync_i32_with_invalid_kind_fmin(%value: i32, %offset: i32) { + // expected-error@+1 {{'fmin' redux kind unsupported with 'i32' type. 
Only supported type is 'f32'.}} + %res = nvvm.redux.sync fmin %value, %offset: i32 -> i32 + llvm.return +} + +// ----- + +llvm.func @redux_sync_non_matching_types(%value: i32, %offset: i32) { + // expected-error@+1 {{failed to verify that all of {res, val} have same type}} + %res = nvvm.redux.sync add %value, %offset: i32 -> f32 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir new file mode 100644 index 0000000..f2ccfe7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)> +} + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32 +} + +// ----- + +func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) { + // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}} + %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)> +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir new file mode 100644 index 0000000..1b93f20c --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () { + // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}} + %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir new file mode 100644 index 0000000..db4574b --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 
{{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) 
{{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + 
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, 
collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, 
ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = 
#nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, 
%enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir new file mode 100644 index 0000000..a15c3fb --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + 
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, 
!llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, 
collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + 
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 
1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr 
addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir new file mode 100644 index 0000000..f46b35a --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, 
%enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_shared_ashift +llvm.func @nvvm_tcgen05_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}} + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_ashift +llvm.func @nvvm_tcgen05_mma_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) { + // expected-error @below {{mxf4nvf4 requires block scale attribute}} + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) { + // expected-error @below {{mxf4 kind does not support block16 attribute}} + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : 
(!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_sp_mma_shared_ashift +llvm.func @nvvm_tcgen05_sp_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}} + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_ashift +llvm.func @nvvm_tcgen05_mma_sp_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{mxf4nvf4 requires block scale attribute}} + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: 
@nvvm_tcgen05_mma_sp_mxf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{mxf4 kind does not support block16 attribute}} + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir new file mode 100644 index 0000000..286df36 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir @@ -0,0 +1,442 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1 +llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2 +llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind 
= #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, 
%b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, 
i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale 
= %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = 
#nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: 
call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr 
addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + 
nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 
x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir new file mode 100644 index 0000000..5c7eabe --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + 
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, 
%enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 
{{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr 
addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // 
CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, 
i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = 
#nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir new file mode 100644 index 0000000..3200411 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + 
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = 
#nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = 
#nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, 
%scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: 
call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, 
i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, 
collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir new file mode 100644 index 0000000..96044cf --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir @@ -0,0 +1,442 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, 
%idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr 
addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = 
#nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 
{{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call 
void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm 
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = 
#nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, 
i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + 
nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : 
(!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = 
%scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = 
#nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, 
i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+    {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+  // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+    {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+  // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+    {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+  // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+    {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir
new file mode 100644
index 0000000..709beb0
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir
@@ -0,0 +1,634 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+  // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+    {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+  // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+  nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+    {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup =
#nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, 
ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, 
ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, 
!llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr 
addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, 
%spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 
{{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call 
void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, 
i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: 
@nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + 
nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, 
%enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr 
addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = 
#nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr 
addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, 
i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = 
#nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + 
%scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) 
+ + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = 
#nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir new file mode 100644 index 0000000..798e311 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir @@ -0,0 +1,633 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1 +llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ 
i32 2, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = 
#nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) 
{{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2 +llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, 
%b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ 
i32 1, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, 
%b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2 +llvm.func 
@nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, 
%a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + 
nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = 
#nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, 
ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = 
#nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + 
nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = 
#nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, 
%idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, 
collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 
{{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 
{{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = 
#nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir new file mode 100644 index 0000000..5f1aeb0 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws +llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + 
nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, 
%b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, 
i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir new file mode 100644 index 0000000..e390e35 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp +llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) { 
+ + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, 
%spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir new file mode 100644 index 0000000..f7ce548 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp +llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 
0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, 
ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void 
@llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = 
#nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir new file mode 100644 index 0000000..cecbb3f --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws +llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) 
{{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, 
!llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + 
collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir index 0daf245..240fab5 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir @@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>, llvm.return } +// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta +llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) { + // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true) + nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1> + + nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1> + + llvm.return +} + // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) { // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3) diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir new file mode 100644 index 0000000..d762ff3 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) { + // expected-error @below {{Multicast is not supported with shared::cta mode.}} + nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1> + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 42aa221..d5868ee 100644 --- 
a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -578,14 +578,6 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // ----- -llvm.func @nanosleep() { - // expected-error@+1 {{integer constant out of range for attribute}} - nvvm.nanosleep 100000000000000 - llvm.return -} - -// ----- - llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) { // expected-error@+1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}} %res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32 diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 9115de6..c4a6909 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 { llvm.return %1 : f32 } -// CHECK-LABEL: @llvm_nvvm_barrier0 -llvm.func @llvm_nvvm_barrier0() { - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) - nvvm.barrier0 - llvm.return -} - -// CHECK-LABEL: @llvm_nvvm_barrier( -// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]]) -llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) { - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) - nvvm.barrier - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]]) - nvvm.barrier id = %barID - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]]) - nvvm.barrier id = %barID number_of_threads = %numberOfThreads - llvm.return -} - // CHECK-LABEL: @llvm_nvvm_cluster_arrive llvm.func @llvm_nvvm_cluster_arrive() { // CHECK: call void @llvm.nvvm.barrier.cluster.arrive() @@ -531,19 +512,6 @@ llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) llvm.return } -// CHECK-LABEL: @cp_async_mbarrier_arrive -llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { - // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}}) - nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr - // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %{{.*}}) - nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr - // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %{{.*}}) - nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %{{.*}}) - nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3> - llvm.return -} - // CHECK-LABEL: @llvm_nvvm_setmaxregister llvm.func @llvm_nvvm_setmaxregister() { // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256) @@ -731,42 +699,6 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, llvm.return } - -// ----- -// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release -llvm.func @nvvm_fence_proxy_tensormap_generic_release() { - %c128 = llvm.mlir.constant(128) : i32 - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta() - nvvm.fence.proxy.release #nvvm.mem_scope<cta> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster() - nvvm.fence.proxy.release #nvvm.mem_scope<cluster> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu() - nvvm.fence.proxy.release #nvvm.mem_scope<gpu> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys() - 
nvvm.fence.proxy.release #nvvm.mem_scope<sys> - llvm.return -} - -// ----- -// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire -llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) { - %c128 = llvm.mlir.constant(128) : i32 - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128 - llvm.return -} // ----- // CHECK-LABEL: @nvvm_exit @@ -983,8 +915,8 @@ llvm.func @nvvm_pmevent() { // ----- // CHECK-LABEL: @nanosleep -llvm.func @nanosleep() { - // CHECK: call void @llvm.nvvm.nanosleep(i32 4000) - nvvm.nanosleep 4000 +llvm.func @nanosleep(%duration: i32) { + // CHECK: call void @llvm.nvvm.nanosleep(i32 %{{.*}}) + nvvm.nanosleep %duration llvm.return } diff --git a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir index f6860e5..d9be6d1 100644 --- a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir @@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) { // CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8 -// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 // CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8 -// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 -// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: call void @__tgt_target_data_begin_mapper({{.*}}) // CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8 // store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8 // CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0 // CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0 // CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4 @@ -93,17 +93,17 @@ module attributes 
{omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) { // CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8 -// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 // CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8 -// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 -// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: call void @__tgt_target_data_begin_mapper({{.*}}) // CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8 // store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8 // CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0 // CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0 // CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4 diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir new file mode 100644 index 0000000..fa330b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// This tests that operations referencing `declare target to` variables inside +// target op regions are replaced with the generated `declare target to` global +// variable when lowering to IR for the device. Unfortunately, as the host file is +// not passed as a module attribute, we miss out on the metadata and entry info.
+ +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer + llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> { + %0 = llvm.mlir.zero : !llvm.array<11 x f32> + llvm.return %0 : !llvm.array<11 x f32> + } + + // CHECK-LABEL: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %{{.*}}) {{.*}} { + // CHECK-DAG: omp.target: + // CHECK-DAG: store float 1.000000e+00, ptr @_QMtest_0Ezii, align 4 + // CHECK-DAG: br label %omp.region.cont + llvm.func @_QQmain() { + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(11 : index) : i64 + %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr + %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true} + %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr + omp.target map_entries(%5 -> %arg0 : !llvm.ptr) { + %6 = llvm.mlir.constant(1.0 : f32) : f32 + %7 = llvm.mlir.constant(0 : i64) : i64 + %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %6, %8 : f32, !llvm.ptr + omp.terminator + } + llvm.return + } +} diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir new file mode 100644 index 0000000..4202421 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer + // CHECK-DAG: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 48] + // CHECK-DAG: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3] + // CHECK-DAG: @.offloading.entry._QMtest_0Ezii = weak constant %struct.__tgt_offload_entry {{.*}} ptr @_QMtest_0Ezii, {{.*}}, i64 44,{{.*}} + llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> { + %0 = llvm.mlir.zero : !llvm.array<11 x f32> + llvm.return %0 : !llvm.array<11 x f32> + } + + // CHECK-DAG: %[[BASEPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 + // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[BASEPTR]], align 8 + // CHECK-DAG: %[[OFFLOADPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 + // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[OFFLOADPTR]], align 8 + llvm.func @_QQmain() { + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(11 : index) : i64 + %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr + %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true} + %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr + omp.target map_entries(%5 -> %arg0 : !llvm.ptr) { + %6 = llvm.mlir.constant(1.0 : f32) : f32 + %7 = 
llvm.mlir.constant(0 : i64) : i64 + %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %6, %8 : f32, !llvm.ptr + omp.terminator + } + llvm.return + } + // CHECK-DAG: !{{.*}} = !{i32 {{.*}}, !"_QMtest_0Ezii", i32 {{.*}}, i32 {{.*}}} +} diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index e6ea3aa..e289d5d 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { // CHECK: br label %[[VAL_40]] // CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]] // CHECK: ret void + +// ----- + +module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) { + %map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr) + map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288] +// CHECK-LABEL: define void @_QPomp_target_is_device_ptr diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir index 87ff0ba..fac61e05 100644 --- a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir @@ -7,7 +7,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : llvm.func @bar() {} llvm.func @baz() {} - omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr alloc { + omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr attributes {byref_element_type = !llvm.array<5 x f32>} alloc { %0 = llvm.mlir.constant(1 : i64) : i64 %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr<5> %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr @@ -23,7 +23,12 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ^bb3: // pred: ^bb1 llvm.call @baz() : () -> () omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) } + llvm.func @foo_() { %c1 = llvm.mlir.constant(1 : i64) : i64 %10 = llvm.alloca %c1 x !llvm.array<5 x f32> {bindc_name = "x"} : (i64) -> !llvm.ptr<5> @@ -51,8 +56,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : } } -// CHECK: call void @__kmpc_parallel_51({{.*}}, i32 1, i32 -1, i32 -1, -// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1) +// CHECK: call void @__kmpc_parallel_60({{.*}}, i32 1, i32 -1, i32 -1, +// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1, i32 0) // CHECK: define internal void @[[PAR_OUTLINED]]{{.*}} { // CHECK: .omp.reduction.then: @@ -67,9 +72,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : // CHECK: br label %[[CONT_BB:.*]] // CHECK: [[CONT_BB]]: -// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %final.rhs, %{{.*}} ] -// CHECK-NEXT: store ptr %[[RED_RHS]], ptr %{{.*}}, align 8 -// CHECK-NEXT: br label %.omp.reduction.done +// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %{{.*}}, %{{.*}} ] // CHECK: } //
CHECK: define internal void @"{{.*}}$reduction$reduction_func"(ptr noundef %0, ptr noundef %1) #0 { diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir index b8b7c78..8950db3 100644 --- a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir @@ -109,19 +109,19 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: icmp eq i32 %[[MASTER]], 1 // CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]] // CHECK: [[THEN]]: -// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_LHS0:[A-Za-z0-9_.]*]] = load double +// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_RESULT0:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS0]], %[[FINAL_RHS0]] // CHECK-NEXT: store double %[[FINAL_RESULT0]] -// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_LHS1:[A-Za-z0-9_.]*]] = load double +// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_RESULT1:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS1]], %[[FINAL_RHS1]] // CHECK-NEXT: store double %[[FINAL_RESULT1]] -// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_LHS2:[A-Za-z0-9_.]*]] = load float +// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_RESULT2:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS2]], %[[FINAL_RHS2]] // CHECK-NEXT: store float %[[FINAL_RESULT2]] -// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_LHS3:[A-Za-z0-9_.]*]] = load float +// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_RESULT3:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS3]], %[[FINAL_RHS3]] // CHECK-NEXT: store float %[[FINAL_RESULT3]] diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir index 19333c4..a96756f46 100644 --- a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir @@ -25,34 +25,33 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { // CHECK: %struct.[[TSK_WTH_PRVTS:.*]] = type { %struct.kmp_task_ompbuilder_t, %struct.[[PRVTS:.*]] } // CHECK: %struct.kmp_task_ompbuilder_t = type { ptr, ptr, i32, ptr, ptr } -// CHECK: %struct.[[PRVTS]] = type { [5 x ptr], [5 x ptr], [5 x i64] } +// CHECK: %struct.[[PRVTS]] = type { [6 x ptr], [6 x ptr], [6 x i64] } // CHECK: define void @launch_(ptr captures(none) %0) // CHECK: %[[STRUCTARG:.*]] = alloca { ptr, ptr }, align 8 -// CHECK: %[[BASEPTRS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[PTRS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[MAPPERS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[SIZES:.*]] = alloca [5 x i64], align 4 +// CHECK: %[[BASEPTRS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[PTRS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[MAPPERS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[SIZES:.*]] = alloca [6 x i64], align 4 - -// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 -// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 -// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[PTRS]], i32 0, i32 0 -// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [5 x i64], ptr %[[SIZES]], i32 0, i32 0 +// CHECK: %[[VAL_20:.*]] = 
getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 +// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 +// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[PTRS]], i32 0, i32 0 +// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [6 x i64], ptr %[[SIZES]], i32 0, i32 0 // CHECK: %[[GL_THRD_NUM:.*]] = call i32 @__kmpc_global_thread_num -// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 160, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1) +// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 184, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1) // CHECK: %[[TSK_PTR:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 0 // CHECK: %[[SHAREDS:.*]] = getelementptr inbounds nuw %struct.kmp_task_ompbuilder_t, ptr %[[TSK_PTR]], i32 0, i32 0 // CHECK: %[[SHAREDS_PTR:.*]] = load ptr, ptr %[[SHAREDS]], align 8 // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[SHAREDS_PTR]], ptr align 1 %[[STRUCTARG]], i64 16, i1 false) // CHECK: %[[VAL_50:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 1 // CHECK: %[[VAL_51:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 0 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 48, i1 false) // CHECK: %[[VAL_53:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 1 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 48, i1 false) // CHECK: %[[VAL_54:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 2 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 48, i1 false) // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_omp_task(ptr @4, i32 %[[GL_THRD_NUM]], ptr %[[TASK_DESC]]) // CHECK: define internal void @[[WORKER:.*]](i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}) { diff --git a/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir new file mode 100644 index 0000000..1e9369f --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @_QQmain() attributes {fir.bindc_name = "main"} { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QFTdtype", (f32, i32)> {bindc_name = "dtypev"} : (i64) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)> + %3 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtypev%value2"} + %4 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)>) map_clauses(to) capture(ByRef) members(%3 
: [1] : !llvm.ptr) -> !llvm.ptr {name = "dtypev"} + omp.target map_entries(%4 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 0, i64 0, i64 0, i64 4] +// CHECK: @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710657, i64 281474976710657, i64 281474976710659] + +// CHECK: %[[ALLOCA:.*]] = alloca %_QFTdtype, i64 1, align 8 +// CHECK: %[[ELEMENT_ACC:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 0, i32 1 + +// CHECK: %[[SIZE1_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1 +// CHECK: %[[SIZE1_CALC_2:.*]] = ptrtoint ptr %[[SIZE1_CALC_1]] to i64 +// CHECK: %[[SIZE1_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64 +// CHECK: %[[SIZE1_CALC_4:.*]] = sub i64 %[[SIZE1_CALC_2]], %[[SIZE1_CALC_3]] +// CHECK: %[[SIZE1_CALC_5:.*]] = sdiv exact i64 %[[SIZE1_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[SIZE2_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1 +// CHECK: %[[SIZE2_CALC_2:.*]] = ptrtoint ptr %[[ELEMENT_ACC]] to i64 +// CHECK: %[[SIZE2_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64 +// CHECK: %[[SIZE2_CALC_4:.*]] = sub i64 %[[SIZE2_CALC_2]], %[[SIZE2_CALC_3]] +// CHECK: %[[SIZE2_CALC_5:.*]] = sdiv exact i64 %[[SIZE2_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[SIZE3_CALC_1:.*]] = getelementptr i32, ptr %[[ELEMENT_ACC]], i32 1 +// CHECK: %[[SIZE3_CALC_2:.*]] = ptrtoint ptr %[[SIZE2_CALC_1]] to i64 +// CHECK: %[[SIZE3_CALC_3:.*]] = ptrtoint ptr %[[SIZE3_CALC_1]] to i64 +// CHECK: %[[SIZE3_CALC_4:.*]] = sub i64 %[[SIZE3_CALC_2]], %[[SIZE3_CALC_3]] +// CHECK: %[[SIZE3_CALC_5:.*]] = sdiv exact i64 %[[SIZE3_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK: store i64 %[[SIZE1_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1 +// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 1 +// CHECK: store i64 %[[SIZE2_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: store ptr %13, ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 2 +// CHECK: store i64 %[[SIZE3_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3 +// CHECK: store ptr %[[ELEMENT_ACC]], ptr 
%[[PTRS]], align 8 diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir index 60c6fa4..cdb8dbb 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir @@ -70,31 +70,31 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8 // CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0 // CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1) +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0) // CHECK: call void @__kmpc_target_deinit() // CHECK: define internal void @[[FUNC1]]( // CHECK-SAME: ptr noalias noundef {{.*}}, ptr noalias noundef {{.*}}, ptr {{.*}}) #{{[0-9]+}} { // Test if num_threads OpenMP clause for target region is correctly lowered -// and passed as a param to kmpc_parallel_51 function +// and passed as a param to kmpc_parallel_60 function // CHECK: define weak_odr protected amdgpu_kernel void [[FUNC_NUM_THREADS0:@.*]]( // CHECK-NOT: call void @__kmpc_push_num_threads( -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156, -// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1) +// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0) -// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause +// One of the arguments of kmpc_parallel_60 function is responsible for handling if clause // of omp parallel construct for target region. If this argument is nonzero, -// then kmpc_parallel_51 launches multiple threads for parallel region. +// then kmpc_parallel_60 launches multiple threads for parallel region. 
// // This test checks if MLIR expression: // %7 = llvm.icmp "ne" %5, %6 : i32 // omp.parallel if(%7) // is correctly lowered to LLVM IR code and the if condition variable -// is passed as a param to kmpc_parallel_51 function +// is passed as a param to kmpc_parallel_60 function // CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}( // CHECK-SAME: ptr {{.*}}, ptr {{.*}}, ptr %[[IFCOND_ARG2:.*]]) #{{[0-9]+}} { @@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: %[[IFCOND_TMP2:.*]] = load i32, ptr %[[IFCOND_TMP1]], align 4 // CHECK: %[[IFCOND_TMP3:.*]] = icmp ne i32 %[[IFCOND_TMP2]], 0 // CHECK: %[[IFCOND_TMP4:.*]] = sext i1 %[[IFCOND_TMP3]] to i32 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) {{.*}} to ptr), // CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1, -// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1) +// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1, i32 0) diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir index 5d2861a..917eaa0 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir @@ -26,10 +26,10 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast // CHECK-SAME: (ptr addrspace(1) @[[GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 %[[THREAD_NUM:.*]], i32 1, i32 -1, i32 -1, -// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1) +// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1, i32 0) // CHECK: define internal void @[[PARALLEL_FUNC]] // CHECK-SAME: (ptr noalias noundef %[[TID_ADDR:.*]], ptr noalias noundef %[[ZERO_ADDR:.*]], diff --git a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir index 9640f03..711b50a 100644 --- a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir @@ -59,9 +59,9 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: @[[FULL_ARR_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef // CHECK: @[[ARR_SECT_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef -// CHECK: @.offload_sizes = private unnamed_addr constant [12 x i64] [i64 0, i64 48, i64 8, i64 0, i64 0, i64 48, i64 8, i64 0, i64 0, i64 24, i64 8, i64 0] -// CHECK: @.offload_maptypes = private unnamed_addr constant [12 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1407374883553283, i64 1407374883553283, i64 1407374883553299, i64 32, i64 2533274790395907, i64 2533274790395907, i64 2533274790395923] -// CHECK: @.offload_mapnames = private constant [12 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}] +// CHECK: @.offload_sizes = private unnamed_addr constant [15 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0] +// CHECK: 
@.offload_maptypes = private unnamed_addr constant [15 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1688849860263939, i64 1688849860263939, i64 1688849860263939, i64 1688849860263955, i64 32, i64 3096224743817219, i64 3096224743817219, i64 3096224743817219, i64 3096224743817235] +// CHECK: @.offload_mapnames = private constant [15 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}] // CHECK: define void @main() // CHECK: %[[SCALAR_ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8 @@ -85,74 +85,97 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: %[[ARR_SECT_PTR:.*]] = getelementptr inbounds i32, ptr %[[LARR_SECT]], i64 %[[ARR_SECT_OFFSET2]] // CHECK: %[[SCALAR_PTR_LOAD:.*]] = load ptr, ptr %[[SCALAR_BASE]], align 8 // CHECK: %[[FULL_ARR_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) -// CHECK: %[[FULL_ARR_SIZE_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null -// CHECK: %[[FULL_ARR_SIZE_SEL:.*]] = select i1 %[[FULL_ARR_SIZE_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]] +// CHECK: %[[FULL_ARR_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[NULL_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null +// CHECK: %[[IS_NULL:.*]] = select i1 %[[NULL_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]] // CHECK: %[[ARR_SECT_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) -// CHECK: %[[ARR_SECT_SIZE_CMP:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null -// CHECK: %[[ARR_SECT_SIZE_SEL:.*]] = select i1 %[[ARR_SECT_SIZE_CMP]], i64 0, i64 %[[ARR_SECT_SIZE]] +// CHECK: %[[ARR_SECT_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[NULL_CMP2:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null +// CHECK: %[[IS_NULL2:.*]] = select i1 %[[NULL_CMP2]], i64 0, i64 %[[ARR_SECT_SIZE]] // CHECK: %[[SCALAR_DESC_SZ4:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1 // CHECK: %[[SCALAR_DESC_SZ3:.*]] = ptrtoint ptr %[[SCALAR_DESC_SZ4]] to i64 // CHECK: %[[SCALAR_DESC_SZ2:.*]] = ptrtoint ptr %[[SCALAR_ALLOCA]] to i64 // CHECK: %[[SCALAR_DESC_SZ1:.*]] = sub i64 %[[SCALAR_DESC_SZ3]], %[[SCALAR_DESC_SZ2]] // CHECK: %[[SCALAR_DESC_SZ:.*]] = sdiv exact i64 %[[SCALAR_DESC_SZ1]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: %[[SCALAR_BASE_2:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1 +// CHECK: %[[SCALAR_BASE_OFF:.*]] = getelementptr ptr, ptr %[[SCALAR_BASE]], i32 1 +// CHECK: %[[SCALAR_BASE_OFF_SZ1:.*]] = ptrtoint ptr %[[SCALAR_BASE_2]] to i64 +// CHECK: %[[SCALAR_BASE_OFF_SZ2:.*]] = ptrtoint ptr %[[SCALAR_BASE_OFF]] to i64 +// CHECK: %[[SCALAR_BASE_OFF_SZ3:.*]] = sub i64 %[[SCALAR_BASE_OFF_SZ1]], %[[SCALAR_BASE_OFF_SZ2]] +// CHECK: %[[SCALAR_BASE_OFF_SZ4:.*]] = sdiv exact i64 %[[SCALAR_BASE_OFF_SZ3]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x 
ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 0 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 0 // CHECK: store i64 %[[FULL_ARR_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 1 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 1 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @full_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 2 +// CHECK: store i64 %[[FULL_ARR_SZ]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 3 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 3 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 4 // CHECK: store ptr %[[FULL_ARR_PTR]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 3 -// CHECK: store i64 %[[FULL_ARR_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8 - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 4 +// CHECK: store i64 %[[IS_NULL]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 5 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 4 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 5 // CHECK: store ptr @sect_arr, ptr 
%[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 4 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 5 // CHECK: store i64 %[[ARR_SECT_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 5 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 5 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 6 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 7 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 6 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 7 +// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @sect_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 7 +// CHECK: store i64 %[[ARR_SECT_SZ]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 8 +// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 8 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 7 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 7 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 9 // CHECK: store ptr %[[ARR_SECT_PTR]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 7 -// CHECK: store i64 %[[ARR_SECT_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8 - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 9 +// CHECK: store i64 %[[IS_NULL2]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 10 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 10 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 8 +// CHECK: 
%[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 10 // CHECK: store i64 %[[SCALAR_DESC_SZ]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 11 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 9 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 11 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 10 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 12 +// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 12 +// CHECK: store ptr %[[SCALAR_BASE_OFF]], ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 12 +// CHECK: store i64 %[[SCALAR_BASE_OFF_SZ4]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 13 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 10 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 13 // CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 11 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 14 // CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 11 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 14 // CHECK: store ptr %[[SCALAR_PTR_LOAD]], ptr %[[OFFLOADPTRS]], align 8 diff --git a/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir new file mode 100644 index 0000000..a232bd7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = true, omp.is_gpu = true, omp.target_triples = ["spirv64-intel"], llvm.target_triple = "spirv64-intel"} { +// CHECK: call spir_func i32 @__kmpc_target_init +// CHECK: call spir_func void @__kmpc_target_deinit + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + } diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir index 9aba72d..b7cb102 100644 --- a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir @@ -59,8 +59,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void @__kmpc_barrier // CHECK: [[THEN]]: -// 
CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32 +// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]] // CHECK-NEXT: store i32 %[[FINAL_RESULT]] diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir index dc22fe1..36eb280 100644 --- a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir @@ -62,8 +62,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: icmp eq i32 %[[MASTER]], 1 // CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]] // CHECK: [[THEN]]: -// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32 +// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]] // CHECK-NEXT: store i32 %[[FINAL_RESULT]] diff --git a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir index c4b2456..6585549 100644 --- a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir @@ -29,22 +29,24 @@ llvm.func @test() { // CHECK: %[[VAL_14:.*]] = icmp eq i32 %[[VAL_13]], 0 // CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_17:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_17]]) -// CHECK: br label %[[VAL_19:.*]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CNCL_BARRIER:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[TID]]) +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_11]] // CHECK: %[[VAL_20:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_21:.*]] = call i32 @__kmpc_cancel_barrier(ptr @3, i32 %[[VAL_20]]) // CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_21]], 0 // CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_24:.*]] // CHECK: omp.par.region1.split.cncl: ; preds = %[[VAL_15]] -// CHECK: br label %[[VAL_19]] +// CHECK: br label %[[FINI]] // CHECK: omp.par.region1.split.cont: ; preds = %[[VAL_15]] // CHECK: br label %[[VAL_25:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_23]] // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] -// CHECK: br label %[[VAL_19]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_26]], %[[VAL_24]], %[[VAL_16]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir index 2124170..a6911f8 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir @@ -24,16 +24,18 @@ llvm.func @cancel_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[VAL_20:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 
@__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] // CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_parallel_if(%arg0 : i1) { @@ -58,27 +60,36 @@ llvm.func @cancel_parallel_if(%arg0 : i1) { // CHECK: omp.par.region: ; preds = %[[VAL_17]] // CHECK: br label %[[VAL_20:.*]] // CHECK: omp.par.region1: ; preds = %[[VAL_19]] -// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] +// CHECK: br i1 %[[VAL_16]], label %[[SPLIT:.*]], label %[[VAL_22:.*]] // CHECK: 3: ; preds = %[[VAL_20]] -// CHECK: br label %[[VAL_23:.*]] -// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]] +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[NOT_CANCELLED:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 1) +// CHECK: %[[COND:.*]] = icmp eq i32 %[[NOT_CANCELLED]], 0 +// CHECK: br i1 %[[COND]], label %[[VAL_23:.*]], label %[[CNCL:.*]] +// CHECK: .cncl: +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) +// CHECK: br label %[[EXIT_STUB:.*]] +// CHECK: .split: +// CHECK: br label %[[SEVEN:.*]] +// CHECK: 7: // CHECK: br label %[[VAL_25:.*]] -// CHECK: omp.region.cont: ; preds = %[[VAL_23]] +// CHECK: omp.region.cont: // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] // CHECK: br label %[[VAL_27:.*]] -// CHECK: 5: ; preds = %[[VAL_20]] +// CHECK: 8: ; preds = %[[VAL_20]] // CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1) // CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0 -// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]] -// CHECK: .cncl: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) -// CHECK: br label %[[VAL_27]] -// CHECK: .split: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_23]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]] +// CHECK: br i1 %[[VAL_30]], label %[[SPLIT5:.*]], label %[[VAL_31:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI]] +// CHECK: .split{{.*}}: +// CHECK: br label %[[SEVEN]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_sections_if(%cond : i1) { @@ -132,11 +143,16 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 3) // CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 // CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: .split: ; preds = %[[VAL_27]] +// CHECK: .split{{.*}}: ; preds = %[[VAL_27]] // CHECK: br label %[[VAL_34:.*]] // CHECK: 12: ; preds = %[[VAL_25]] +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 3) +// CHECK: %[[COND:.*]] = icmp eq i32 %13, 0 +// CHECK: br i1 %[[COND]], label 
%[[SPLIT:.*]], label %[[CNCL:.*]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_34]] -// CHECK: 13: ; preds = %[[VAL_28]], %[[VAL_32]] +// CHECK: 15: // CHECK: br label %[[VAL_35:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_34]] // CHECK: br label %[[VAL_23]] @@ -145,17 +161,17 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_23]] // CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1 // CHECK: br label %[[VAL_12]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_33]], %[[VAL_16]] +// CHECK: omp_section_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]]) // CHECK: %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]]) // CHECK: br label %[[VAL_37:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_19]] -// CHECK: br label %[[VAL_38:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]] // CHECK: ret void -// CHECK: .cncl: ; preds = %[[VAL_27]] -// CHECK: br label %[[VAL_19]] +// CHECK: .cncl: +// CHECK: br label %[[OMP_SECTION_LOOP_EXIT:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[OMP_SECTION_LOOP_EXIT:.*]] llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { omp.wsloop { @@ -221,18 +237,23 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { // CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2) // CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 // CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] -// CHECK: .split: ; preds = %[[VAL_44]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_51:.*]] -// CHECK: 28: ; preds = %[[VAL_42]] +// CHECK: 28: +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 2) +// CHECK: %[[COND:.*]] = icmp eq i32 %[[CANCEL_POINT]], 0 +// CHECK: br i1 %[[COND]], label %[[SPLIT3:.*]], label %[[CNCL4:.*]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_51]] -// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]] +// CHECK: 31: // CHECK: br label %[[VAL_52:.*]] // CHECK: omp.region.cont1: ; preds = %[[VAL_51]] // CHECK: br label %[[VAL_32]] // CHECK: omp_loop.inc: ; preds = %[[VAL_52]] // CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1 // CHECK: br label %[[VAL_31]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]]) // CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]]) @@ -241,8 +262,12 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { // CHECK: br label %[[VAL_55:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_54]] // CHECK: ret void -// CHECK: .cncl: ; preds = %[[VAL_44]] -// CHECK: br label %[[VAL_38]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: br label %[[OMP_LOOP_EXIT:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI:.*]] omp.private {type = firstprivate} @i32_priv : i32 copy { ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir index 5e0d3f9..93fa2064 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir @@ -24,16 +24,18 @@ llvm.func 
@cancellation_point_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancellation_point_sections() { @@ -94,14 +96,12 @@ llvm.func @cancellation_point_sections() { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_46]] // CHECK: %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1 // CHECK: br label %[[VAL_35]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_53]], %[[VAL_39]] +// CHECK: omp_section_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]]) // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) // CHECK: br label %[[VAL_56:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_42]] -// CHECK: br label %[[VAL_57:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_56]] // CHECK: ret void // CHECK: omp.section.region.cncl: ; preds = %[[VAL_48]] // CHECK: br label %[[VAL_42]] @@ -175,7 +175,7 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) { // CHECK: omp_loop.inc: ; preds = %[[VAL_106]] // CHECK: %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1 // CHECK: br label %[[VAL_89]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_105]], %[[VAL_93]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]]) // CHECK: %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]]) diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir new file mode 100644 index 0000000..a0dd556 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir @@ -0,0 +1,34 @@ +// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate + +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) { + // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024) + // We want to make sure that the next call is not another init builder. 
+ // CHECK-NOT: call void @[[RUNTIME_FUNC]] + %1 = llvm.mlir.constant(1024: i32) : i32 + omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } + llvm.return +} + +// When a chunk size is present, we need to make sure the correct parallel accesses metadata is added +// CHECK: !2 = !{!"llvm.loop.parallel_accesses", !3} +// CHECK-NEXT: !3 = distinct !{} + +// ----- + +llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) { + // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) + // We want to make sure that the next call is not another init builder. + // CHECK-NOT: call void @[[RUNTIME_FUNC]] + omp.distribute dist_schedule_static { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir new file mode 100644 index 0000000..dad32b4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir @@ -0,0 +1,205 @@ +// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate while using workshare loops. + +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32, %dcs: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %scs = llvm.mlir.constant(64 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) { + omp.wsloop schedule(static = %scs : i32) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 %3) + +llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %dcs = llvm.mlir.constant(1024 : i64) : i64 + %scs = llvm.mlir.constant(64 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) { + omp.wsloop schedule(static = %scs : i64) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// 
CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64) +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024) + +// ----- + +llvm.func @distribute_wsloop_dist_schedule_chunked(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %dcs = llvm.mlir.constant(1024 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024) + +llvm.func @distribute_wsloop_dist_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %dcs = llvm.mlir.constant(1024 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0) +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024) + +// ----- + +llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %scs = llvm.mlir.constant(64 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static = %scs : i32) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64) +// CHECK: call void 
@__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) + +llvm.func @distribute_wsloop_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %scs = llvm.mlir.constant(64 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static = %scs : i64) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} + +// CHECK: define internal void @distribute_wsloop_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64) +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0) + +// ----- + +llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_no_chunks..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0) +// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0) + +llvm.func @distribute_wsloop_no_chunks_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_no_chunks_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i64 1, i64 0) +// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i64 1, i64 0)
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 8bd33a3..1eb501c 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -328,6 +328,52 @@ llvm.func @test_omp_masked(%arg0: i32)-> () { // ----- +llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { +// CHECK-LABEL: @wsloop_linear + +// CHECK: %p.lastiter = alloca i32, align 4 +// CHECK: %p.lowerbound = alloca i32, align 4 +// CHECK: %p.upperbound = alloca i32, align 4 +// CHECK: %p.stride = alloca i32, align 4 +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = add i32 %omp_loop.iv, {{.*}} +// CHECK: %[[LINEAR_VAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4 +// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_VAR_LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4 + +// CHECK: omp_loop.linear_finalization: +// CHECK: %[[ITER:.*]] = load i32, ptr %p.lastiter, align 4 +// CHECK: %[[CMP:.*]] = icmp ne i32 %[[ITER]], 0 +// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_lastiter_exit: +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4 +// CHECK: store i32 %[[LOAD]], ptr {{.*}}, align 4 +// CHECK: br label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_exit: +// CHECK: %[[THREAD_ID:.*]] = call i32 @__kmpc_global_thread_num(ptr {{.*}}) +// CHECK: call void @__kmpc_barrier(ptr {{.*}}, i32 %[[THREAD_ID]]) +// CHECK: br label %omp_loop.after + + omp.wsloop linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK: %struct.ident_t = type // CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00" // CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}} @@ -695,6 +741,34 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // ----- +llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + +// CHECK-LABEL: @simd_linear + +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4, !llvm.access.group !1 +// CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4, !llvm.access.group !1 + omp.simd linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, 
%step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { diff --git a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir index faccfc6..99f37c7 100644 --- a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir +++ b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir @@ -21,9 +21,11 @@ llvm.func @parallel_infinite_loop() -> () { // CHECK: omp.region.cont: ; No predecessors! // CHECK: br label %[[VAL_4:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_5:.*]] -// CHECK: br label %[[VAL_6:.*]] -// CHECK: omp.par.exit: ; preds = %[[VAL_4]] +// CHECK: br label %[[FINI:.*]] +// CHECK: [[OMP_PAR_EXIT:omp.par.exit]]: ; preds = %[[FINI]] // CHECK: ret void +// CHECK: [[FINI]]: +// CHECK: br label %[[OMP_PAR_EXIT]] // CHECK: } // CHECK-LABEL: define internal void @parallel_infinite_loop..omp_par( diff --git a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir index 887d297..c79c369 100644 --- a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir +++ b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir @@ -108,6 +108,8 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: ! // CHECK: reduce.finalize: ; preds = %[[VAL_49]], %[[VAL_43]] // CHECK: br label %[[VAL_53:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_48]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: // CHECK: %[[VAL_54:.*]] = load ptr, ptr %[[VAL_20]], align 8 // CHECK: %[[VAL_55:.*]] = load ptr, ptr %[[VAL_21]], align 8 // CHECK: br label %[[VAL_56:.*]] @@ -115,5 +117,5 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: ! // CHECK: br label %[[VAL_38]] // CHECK: omp.reduction.neutral1: ; preds = %[[VAL_25]] // CHECK: br label %[[VAL_30]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_53]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir index b302b4b..13f52f0 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir @@ -127,8 +127,6 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]]) // CHECK: br label %[[VAL_37:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_35]] -// CHECK: br label %[[VAL_38:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]] // CHECK: %[[VAL_39:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_14]], i64 0, i64 0 // CHECK: store ptr %[[VAL_21]], ptr %[[VAL_39]], align 8 // CHECK: %[[VAL_40:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) @@ -137,9 +135,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: i32 1, label %[[VAL_43:.*]] // CHECK: i32 2, label %[[VAL_44:.*]] // CHECK: ] -// CHECK: reduce.switch.atomic: ; preds = %[[VAL_38]] +// CHECK: reduce.switch.atomic: ; preds = %[[VAL_37]] // CHECK: unreachable -// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_38]] +// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_37]] // CHECK: %[[VAL_45:.*]] = load ptr, ptr %[[VAL_21]], align 8 // CHECK: br label %[[VAL_46:.*]] // CHECK: omp.reduction.nonatomic.body: ; preds = %[[VAL_43]] @@ -157,7 +155,7 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr 
{fir.bindc_name = "x"}) attribute // CHECK: omp.reduction.nonatomic.body17: ; preds = %[[VAL_47]] // CHECK: %[[VAL_50]] = sub i64 %[[VAL_49]], 1 // CHECK: br label %[[VAL_47]] -// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_38]] +// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_37]] // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) // CHECK: %[[VAL_56:.*]] = load ptr, ptr %[[VAL_21]], align 8 @@ -173,7 +171,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: omp.region.cont: ; preds = %[[VAL_62]] // CHECK: br label %[[VAL_64:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_63]] -// CHECK: br label %[[VAL_65:.*]] +// CHECK: br label %[[FINI:.fini.*]] +// CHECK: [[FINI]]: +// CHECK: br label %[[EXIT:.*]] // CHECK: omp.reduction.cleanup21: ; preds = %[[VAL_57]] // CHECK: br label %[[VAL_61]] // CHECK: omp_section_loop.body: ; preds = %[[VAL_32]] @@ -219,5 +219,5 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: omp_section_loop.inc: ; preds = %[[VAL_69]] // CHECK: %[[VAL_31]] = add nuw i32 %[[VAL_30]], 1 // CHECK: br label %[[VAL_28]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_64]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir index a714ca6..cb30d3b 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir @@ -96,8 +96,10 @@ module { // CHECK: reduce.finalize: ; preds = %[[VAL_34]], %[[VAL_28]] // CHECK: br label %[[VAL_38:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_33]] +// CHECK: br label %[[FINI:.*]] +// CHECK: [[FINI]]: // CHECK: br label %[[VAL_39:.*]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_38]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void // CHECK: %[[VAL_40:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_41:.*]], i64 0, i64 0 // CHECK: %[[VAL_42:.*]] = load ptr, ptr %[[VAL_40]], align 8 diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir index 19da6f8..00f6c1b 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir @@ -86,8 +86,6 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_40]]) // CHECK: br label %[[VAL_41:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_39]] -// CHECK: br label %[[VAL_42:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_41]] // CHECK: %[[VAL_43:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_21]], i64 0, i64 0 // CHECK: store ptr %[[VAL_20]], ptr %[[VAL_43]], align 8 // CHECK: %[[VAL_44:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) @@ -96,23 +94,25 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: i32 1, label %[[VAL_47:.*]] // CHECK: i32 2, label %[[VAL_48:.*]] // CHECK: ] -// CHECK: reduce.switch.atomic: ; preds = %[[VAL_42]] +// CHECK: reduce.switch.atomic: ; preds = %[[VAL_41]] // CHECK: unreachable -// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_42]] +// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_41]] // CHECK: %[[VAL_49:.*]] = load float, ptr %[[VAL_11]], align 4 // 
CHECK: %[[VAL_50:.*]] = load float, ptr %[[VAL_20]], align 4 // CHECK: %[[VAL_51:.*]] = fadd contract float %[[VAL_49]], %[[VAL_50]] // CHECK: store float %[[VAL_51]], ptr %[[VAL_11]], align 4 // CHECK: call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_44]], ptr @.gomp_critical_user_.reduction.var) // CHECK: br label %[[VAL_46]] -// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_42]] +// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_41]] // CHECK: %[[VAL_52:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_52]]) // CHECK: br label %[[VAL_53:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_46]] // CHECK: br label %[[VAL_54:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_53]] -// CHECK: br label %[[VAL_55:.*]] +// CHECK: br label %[[FINI:.fini.*]] +// CHECK: [[FINI]]: +// CHECK: br label %[[EXIT:.*]] // CHECK: omp_section_loop.body: ; preds = %[[VAL_36]] // CHECK: %[[VAL_56:.*]] = add i32 %[[VAL_34]], %[[VAL_28]] // CHECK: %[[VAL_57:.*]] = mul i32 %[[VAL_56]], 1 @@ -144,8 +144,10 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: omp_section_loop.inc: ; preds = %[[VAL_59]] // CHECK: %[[VAL_35]] = add nuw i32 %[[VAL_34]], 1 // CHECK: br label %[[VAL_32]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_54]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void + +// CHECK-LABEL: define internal void @.omp.reduction.func // CHECK: %[[VAL_70:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_71:.*]], i64 0, i64 0 // CHECK: %[[VAL_72:.*]] = load ptr, ptr %[[VAL_70]], align 8 // CHECK: %[[VAL_73:.*]] = load float, ptr %[[VAL_72]], align 4 diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 504d91b..5c37817 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}}) // DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}}) -// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) +// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}}) // DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}}) // DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}}) diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir index 20202fc..dae80ba 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // DEVICE: call void @__kmpc_target_deinit() // DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}}) -// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) +// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}}) // DEVICE: define internal void 
@[[PARALLEL_OUTLINE]]({{.*}}) // DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}}) diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index af6d254..396c57a 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr // ----- -llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) { - // expected-error@below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.distribute}} - omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- - llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.distribute operation}} // expected-error@below {{LLVM Translation failed for operation: omp.distribute}} @@ -116,19 +103,6 @@ llvm.func @sections_private(%x : !llvm.ptr) { // ----- -llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.simd operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.simd}} - omp.simd linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32): @@ -238,17 +212,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) { // ----- -llvm.func @target_is_device_ptr(%x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target is_device_ptr(%x : !llvm.ptr) { - omp.terminator - } - llvm.return -} - -// ----- - llvm.func @target_enter_data_depend(%x: !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}} // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}} @@ -448,19 +411,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { } // ----- - -llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 8a848221..2c748ad5 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -14,30 +14,36 @@ llvm.func @rocdl_special_regs() -> i32 { %5 = rocdl.workgroup.id.y : i32 // CHECK: call i32 @llvm.amdgcn.workgroup.id.z() %6 = rocdl.workgroup.id.z : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.x() + %7 = 
rocdl.cluster.id.x : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.y() + %8 = rocdl.cluster.id.y : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.z() + %9 = rocdl.cluster.id.z : i32 // CHECK: call i64 @__ockl_get_local_size(i32 0) - %7 = rocdl.workgroup.dim.x : i64 + %10 = rocdl.workgroup.dim.x : i64 // CHECK: call i64 @__ockl_get_local_size(i32 1) - %8 = rocdl.workgroup.dim.y : i64 + %11 = rocdl.workgroup.dim.y : i64 // CHECK: call i64 @__ockl_get_local_size(i32 2) - %9 = rocdl.workgroup.dim.z : i64 + %12 = rocdl.workgroup.dim.z : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 0) - %10 = rocdl.grid.dim.x : i64 + %13 = rocdl.grid.dim.x : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 1) - %11 = rocdl.grid.dim.y : i64 + %14 = rocdl.grid.dim.y : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 2) - %12 = rocdl.grid.dim.z : i64 + %15 = rocdl.grid.dim.z : i64 // CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x() - %13 = rocdl.workitem.id.x range <i32, 0, 64> : i32 + %16 = rocdl.workitem.id.x range <i32, 0, 64> : i32 // CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0) - %14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64 + %17 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64 // CHECK: call i32 @llvm.amdgcn.wavefrontsize() - %15 = rocdl.wavefrontsize : i32 + %18 = rocdl.wavefrontsize : i32 // CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize() - %16 = rocdl.wavefrontsize range <i32, 32, 65> : i32 + %19 = rocdl.wavefrontsize range <i32, 32, 65> : i32 llvm.return %1 : i32 } @@ -55,6 +61,59 @@ llvm.func @kernel_func_workgroups() llvm.return } +llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) { + // CHECK-LABEL: kernel_math_ops + // CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}}) + %tanh0 = rocdl.tanh %a f32 -> f32 + %tanh1 = rocdl.tanh %b f16 -> f16 + %tanh2 = rocdl.tanh %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}}) + %sin0 = rocdl.sin %a f32 -> f32 + %sin1 = rocdl.sin %b f16 -> f16 + %sin2 = rocdl.sin %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}}) + %cos0 = rocdl.cos %a f32 -> f32 + %cos1 = rocdl.cos %b f16 -> f16 + %cos2 = rocdl.cos %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}}) + %rcp0 = rocdl.rcp %a f32 -> f32 + %rcp1 = rocdl.rcp %b f16 -> f16 + %rcp2 = rocdl.rcp %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}}) + %exp2_0 = rocdl.exp2 %a f32 -> f32 + %exp2_1 = rocdl.exp2 %b f16 -> f16 + %exp2_2 = rocdl.exp2 %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}}) + %log0 = rocdl.log %a f32 -> f32 + %log1 = rocdl.log %b f16 -> f16 + %log2 = rocdl.log %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}}) + // 
CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}}) + %sqrt0 = rocdl.sqrt %a f32 -> f32 + %sqrt1 = rocdl.sqrt %b f16 -> f16 + %sqrt2 = rocdl.sqrt %c bf16 -> bf16 + llvm.return +} + llvm.func @known_block_sizes() attributes {rocdl.kernel, rocdl.flat_work_group_size = "128,128", @@ -248,6 +307,13 @@ llvm.func @rocdl.s.get.barrier.state() { llvm.return } +llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.s.get.named.barrier.state + // CHECK: %[[STATE:.+]] = call i32 @llvm.amdgcn.s.get.named.barrier.state(ptr addrspace(3) %[[PTR:.+]]) + %0 = rocdl.s.get.named.barrier.state %ptr : i32 + llvm.return +} + llvm.func @rocdl.s.wait.dscnt() { // CHECK-LABEL: rocdl.s.wait.dscnt // CHECK-NEXT: call void @llvm.amdgcn.s.wait.dscnt(i16 0) @@ -875,140 +941,182 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>, %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>, %arg13 : vector<64xf32>, %arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> { - %zero = llvm.mlir.constant(false) : i1 - %zero_i16 = llvm.mlir.constant(0 : i16) : i16 - // ---- Wave32 ----- + // ---- Wave32 ----- // f16 -> f32 - // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}}) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}}) %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> // bf16 -> f32 - // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}}) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x float> %{{.*}}) %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> // f16 -> f16 (OPSEL = {0,1}) - // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}}) - %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <16 x half> %{{.*}} i1 false) + %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16> // bf16 -> bf16 (OPSEL = {0,1}) - // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}}) - %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <16 x i16> %{{.*}} i1 false) + %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16> // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> 
@llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=false for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5a = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=false, signB=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5b = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = true, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true, clamp=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r5c = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = true, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=false for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6a = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = true, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=false, 
signB=true, clamp=true for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r6b = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = true, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true for iu4 gfx12 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6c = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = true, signB = true, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> // f32 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32> + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false) + %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> // f16 -> f32 - // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32> + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32> // bf16 -> f32 - // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32> + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32> // f16 -> f16 - // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16> + // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> 
%{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x half> %{{.*}} i1 false, i1 false) + %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %arg1, %arg1, %arg9 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf16>) -> vector<32xf16> // bf16 -> bf16 - // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16> + // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x bfloat> %{{.*}} i1 false, i1 false) + %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %arg16, %arg16, %arg17 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xbf16>) -> vector<32xbf16> // bf16 -> bf16 / f32 - // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16> + // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xbf16> // f8/bf8 -> f16/f32 - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, 
%zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, 
i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> 
@llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> // iu8 -> i32 - // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32> + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) + %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=true for iu8 gfx1250 + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) + %r23a.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=false, reuseA=true, 
reuseB=true for iu8 gfx1250 + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 true, i1 true) + %r23b.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = false, reuseA = true, reuseB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=true with modC=1 for f32 gfx1250 + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 true, <16 x float> %{{.*}} i1 true, <16 x float> %{{.*}} i16 1, <4 x float> %{{.*}} i1 false, i1 false) + %r1a.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = true, signB = true, modC = 1 : i16, reuseA = false, reuseB = false} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> + + // Test with modC=2 and signA=false, signB=true, reuseA=true for f16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 true, <16 x half> %{{.*}} i16 2, <32 x float> %{{.*}} i1 true, i1 false) + %r2a.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = true, modC = 2 : i16, reuseA = true, reuseB = false} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32> + + // Test with modC=3 and signA=true, signB=true, reuseB=true for bf16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 true, <16 x bfloat> %{{.*}} i1 true, <16 x bfloat> %{{.*}} i16 3, <32 x float> %{{.*}} i1 false, i1 true) + %r3a.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = true, signB = true, modC = 3 : i16, reuseA = false, reuseB = true} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32> // ---- Wave64 ----- // f16 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <4 x float> %{{.*}}) %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> // bf16 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}}) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <4 x float> %{{.*}}) %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> // f16 -> f16 (OPSEL = {0,1}) - // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}}) - %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x half> %{{.*}} i1 false) + %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16> // bf16 -> bf16 (OPSEL = {0,1}) - // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}}) - %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) 
-> vector<8xi16> + // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x i16> %{{.*}} i1 false) + %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16> // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}) - %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true) + %r12 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg5 {signA = false, signB = false, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}) - %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true) + %r13 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg5 {signA = false, signB = false, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<4xi32>) -> vector<4xi32> llvm.return %r0 : vector<8xf32> } @@ -1028,6 +1136,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.load.tr.ops + // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]]) + + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]]) + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> 
vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { //CHECK: call void @llvm.amdgcn.load.to.lds.p7 rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7> @@ -1040,6 +1181,32 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: rocdl.global.load.async.to.lds +llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b8 + rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b32 + rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b64 + rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b128 + rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + +// CHECK-LABEL: rocdl.cluster.load.async.to.lds +llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b8 + rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b32 + rocdl.cluster.load.async.to.lds.b32 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b64 + rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b128 + rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + // CHECK-LABEL: rocdl.tensor.load.to.lds llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { @@ -1174,6 +1341,113 @@ llvm.func @rocdl.raw.ptr.buffer.load.lds(%rsrc : !llvm.ptr<8>, %dstLds : !llvm.p llvm.return } +llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>, + %arg3: vector<12xi32>, %arg5: vector<16xi32>, + %arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> { + // CHECK-LABEL: rocdl.wmma.scale + + // Test with default attributes (all zeros/false) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 0 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with different matrix formats (FP8 x BF8) + // CHECK: call <4 x float> 
@llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false) + %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 2 : i32, modC = 1 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with BF8 x BF6 and modC = 2 (abs) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 1 : i32, fmtB = 3 : i32, modC = 2 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP4 and modC = 3 (negate(abs)) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false) + %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 4 : i32, modC = 3 : i16, + scaleAType = 3 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 3 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseA = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false) + %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 2 : i32, fmtB = 2 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = true, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseB = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true) + %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 3 : i32, fmtB = 3 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + 
scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with both reuseA and reuseB = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true) + %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 4 : i32, fmtB = 4 : i32, modC = 1 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test scale16 variant with i64 scale exponents + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false) + %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 2 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // Test f4 variant (no matrix format parameters) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0 + {modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32> + + // Test f4 scale16 variant with varied attributes + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true) + %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8 + {modC = 3 : i16, + scaleAType = 2 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 2 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32> + + llvm.return %r00 : vector<4xf32> +} + llvm.func @rocdl.raw.ptr.buffer.atomic.f32(%rsrc : !llvm.ptr<8>, %offset : i32, %soffset : i32, %vdata1 : f32) { diff --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir index 6b2d2ea..cee6301 100644 --- a/mlir/test/Target/LLVMIR/target-ext-type.mlir +++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir @@ -6,6 +6,12 @@ llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv llvm.return %0 : !llvm.target<"spirv.DeviceEvent"> } +// CHECK: @amdgcn_named_barrier = internal addrspace(3) global target("amdgcn.named.barrier", 0) poison +llvm.mlir.global internal @amdgcn_named_barrier() {addr_space = 3 : i32} : !llvm.target<"amdgcn.named.barrier", 0> { + %0 = llvm.mlir.poison : !llvm.target<"amdgcn.named.barrier", 0> + llvm.return %0 : !llvm.target<"amdgcn.named.barrier", 0> +} + // CHECK-LABEL: define target("spirv.Event") @func2() 
{ // CHECK-NEXT: %1 = alloca target("spirv.Event"), align 8 // CHECK-NEXT: %2 = load target("spirv.Event"), ptr %1, align 8 diff --git a/mlir/test/Target/SPIRV/consecutive-selection.spv b/mlir/test/Target/SPIRV/consecutive-selection.spvasm index 3752058..3752058 100644 --- a/mlir/test/Target/SPIRV/consecutive-selection.spv +++ b/mlir/test/Target/SPIRV/consecutive-selection.spvasm diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir index 712fd17..29b5d4f 100644 --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -78,6 +78,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // ----- spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { + // CHECK: coherent + spirv.GlobalVariable @var {coherent} : !spirv.ptr<vector<2xf32>, Output> +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outSideGlobalVar1", linkage_type = <Import>> spirv.GlobalVariable @var1 { linkage_attributes=#spirv.linkage_attributes< diff --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir index cf519cb..6f19b35 100644 --- a/mlir/test/Target/SPIRV/group-ops.mlir +++ b/mlir/test/Target/SPIRV/group-ops.mlir @@ -1,11 +1,13 @@ -// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s +// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s // RUN: %if spirv-tools %{ rm -rf %t %} // RUN: %if spirv-tools %{ mkdir %t %} // RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} // RUN: %if spirv-tools %{ spirv-val %t %} -spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, SubgroupBallotKHR, Groups, SubgroupBufferBlockIOINTEL, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_shader_ballot, SPV_INTEL_subgroups, SPV_KHR_uniform_group_instructions]> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.3, + [Shader, Linkage, SubgroupBallotKHR, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], + [SPV_KHR_storage_buffer_storage_class, SPV_KHR_shader_ballot, SPV_KHR_uniform_group_instructions]> { // CHECK-LABEL: @subgroup_ballot spirv.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> "None" { // CHECK: %{{.*}} = spirv.KHR.SubgroupBallot %{{.*}}: vector<4xi32> @@ -24,30 +26,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, Subgrou %0 = spirv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32> spirv.ReturnValue %0: f32 } - // CHECK-LABEL: @subgroup_block_read_intel - spirv.func @subgroup_block_read_intel(%ptr : !spirv.ptr<i32, StorageBuffer>) -> i32 "None" { - // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> i32 - %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> i32 - spirv.ReturnValue %0: i32 - } - // CHECK-LABEL: @subgroup_block_read_intel_vector - spirv.func @subgroup_block_read_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" { - // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32> - %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32> - spirv.ReturnValue %0: 
vector<3xi32> - } - // CHECK-LABEL: @subgroup_block_write_intel - spirv.func @subgroup_block_write_intel(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" { - // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32 - spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32 - spirv.Return - } - // CHECK-LABEL: @subgroup_block_write_intel_vector - spirv.func @subgroup_block_write_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" { - // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32> - spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32> - spirv.Return - } // CHECK-LABEL: @group_iadd spirv.func @group_iadd(%value: i32) -> i32 "None" { // CHECK: spirv.GroupIAdd <Workgroup> <Reduce> %{{.*}} : i32 diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir index 95b87b3..b9a4295 100644 --- a/mlir/test/Target/SPIRV/loop.mlir +++ b/mlir/test/Target/SPIRV/loop.mlir @@ -1,5 +1,10 @@ // RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + // Single loop spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { @@ -62,7 +67,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // Single loop with block arguments -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> { spirv.GlobalVariable @GV1 bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer> spirv.GlobalVariable @GV2 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer> // CHECK-LABEL: @loop_kernel diff --git a/mlir/test/Target/SPIRV/mlir-translate.mlir b/mlir/test/Target/SPIRV/mlir-translate.mlir index cbce351..b1966fe 100644 --- a/mlir/test/Target/SPIRV/mlir-translate.mlir +++ b/mlir/test/Target/SPIRV/mlir-translate.mlir @@ -1,7 +1,6 @@ // Check that `--spirv-save-validation-files-with-prefix` generates // a correct number of files. 
-// REQUIRES: shell // RUN: rm -rf %t // RUN: mkdir %t && mlir-translate --serialize-spirv --no-implicit-module \ // RUN: --split-input-file --spirv-save-validation-files-with-prefix=%t/foo %s \ diff --git a/mlir/test/Target/SPIRV/module.mlir b/mlir/test/Target/SPIRV/module.mlir index 7e52e54..fb4d9bc 100644 --- a/mlir/test/Target/SPIRV/module.mlir +++ b/mlir/test/Target/SPIRV/module.mlir @@ -1,6 +1,5 @@ // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s -// REQUIRES: shell // RUN: %if spirv-tools %{ rm -rf %t %} // RUN: %if spirv-tools %{ mkdir %t %} // RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir index ca635a4..92a3387 100644 --- a/mlir/test/Target/SPIRV/phi.mlir +++ b/mlir/test/Target/SPIRV/phi.mlir @@ -1,5 +1,10 @@ // RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + // Test branch with one block argument spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { @@ -295,15 +300,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %true = spirv.Constant true %zero = spirv.Constant 0 : i32 %one = spirv.Constant 1 : i32 + spirv.mlir.selection { // CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]] - spirv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1 + spirv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1 // CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32) - ^true1(%arg0: i32, %arg1: i32): - spirv.Return + ^true1(%arg0: i32, %arg1: i32): + spirv.Return // CHECK: [[false1]]: - ^false1: + ^false1: + spirv.Return + ^merge: + spirv.mlir.merge + } + + spirv.Return + } + + spirv.func @main() -> () "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @main } // ----- @@ -314,15 +330,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %true = spirv.Constant true %zero = spirv.Constant 0 : i32 %one = spirv.Constant 1 : i32 + spirv.mlir.selection { // CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) - spirv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32) + spirv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32) // CHECK: [[true1]]: - ^true1: - spirv.Return + ^true1: + spirv.Return // CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): - ^false1(%arg0: i32, %arg1: i32): + ^false1(%arg0: i32, %arg1: i32): + spirv.Return + ^merge: + spirv.mlir.merge + } + + spirv.Return + } + + spirv.func @main() -> () "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @main } // ----- @@ -333,13 +360,24 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %true = spirv.Constant true %zero = spirv.Constant 0 : i32 %one = spirv.Constant 1 : i32 + spirv.mlir.selection { // CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) - spirv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32) + spirv.BranchConditional %true, ^true1(%one: 
i32), ^false1(%zero, %zero: i32, i32) // CHECK: [[true1]](%{{.*}}: i32): - ^true1(%arg0: i32): - spirv.Return + ^true1(%arg0: i32): + spirv.Return // CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): - ^false1(%arg1: i32, %arg2: i32): + ^false1(%arg1: i32, %arg2: i32): + spirv.Return + ^merge: + spirv.mlir.merge + } + spirv.Return } + + spirv.func @main() -> () "None" { + spirv.Return + } + spirv.EntryPoint "GLCompute" @main } diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir index 44625cc..d0ad118 100644 --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -1,5 +1,10 @@ // RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + // Selection with both then and else branches spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { @@ -136,19 +141,31 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK-NEXT: spirv.Load "Function" %[[VAR]] %cond = spirv.Load "Function" %var : i1 + spirv.mlir.selection { // CHECK: spirv.BranchConditional %1, ^[[THEN1:.+]](%{{.+}} : i32), ^[[ELSE1:.+]](%{{.+}}, %{{.+}} : i32, i32) - spirv.BranchConditional %cond, ^then1(%one: i32), ^else1(%zero, %zero: i32, i32) + spirv.BranchConditional %cond, ^then1(%one: i32), ^else1(%zero, %zero: i32, i32) // CHECK-NEXT: ^[[THEN1]](%{{.+}}: i32): // CHECK-NEXT: spirv.Return - ^then1(%arg0: i32): - spirv.Return + ^then1(%arg0: i32): + spirv.Return // CHECK-NEXT: ^[[ELSE1]](%{{.+}}: i32, %{{.+}}: i32): // CHECK-NEXT: spirv.Return - ^else1(%arg1: i32, %arg2: i32): + ^else1(%arg1: i32, %arg2: i32): + spirv.Return + ^merge: + spirv.mlir.merge + } + spirv.Return } + + spirv.func @main() -> () "None" { + spirv.Return + } + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 } // ----- @@ -203,3 +220,129 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { spirv.EntryPoint "GLCompute" @main spirv.ExecutionMode @main "LocalSize", 1, 1, 1 } + +// ----- + +// Selection with switch + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +// CHECK-LABEL: @selection_switch + spirv.func @selection_switch(%selector: i32) -> () "None" { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %four = spirv.Constant 4: i32 +// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr<i32, Function> + %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function> +// CHECK: spirv.mlir.selection { + spirv.mlir.selection { +// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ +// CHECK-NEXT: default: ^[[DEFAULT:.+]], +// CHECK-NEXT: 0: ^[[CASE0:.+]], +// CHECK-NEXT: 1: ^[[CASE1:.+]], +// CHECK-NEXT: 2: ^[[CASE2:.+]] + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +// CHECK: ^[[DEFAULT]] + ^default: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %one : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE0]] + ^case0: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %two : i32 +// CHECK-NEXT: spirv.Branch 
^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE1]] + ^case1: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %three : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE2]] + ^case2: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %four : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[MERGE]] + ^merge: +// CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge +// CHECK-NEXT: } + } +// CHECK-NEXT: spirv.Return + spirv.Return + } + + spirv.func @main() -> () "None" { + spirv.Return + } + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 +} + +// ----- + +// Selection with switch and block operands + +spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> { +// CHECK-LABEL: @selection_switch_operands + spirv.func @selection_switch_operands(%selector : si32) "None" { + %cst1 = spirv.Constant 1.000000e+00 : f32 + %vec0 = spirv.Undef : vector<3xf32> +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32> + %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32> + spirv.Branch ^bb1 + ^bb1: +// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> { + %vec4 = spirv.mlir.selection -> vector<3xf32> { +// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [ +// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>), +// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>), +// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>) + spirv.Switch %selector : si32, [ + default: ^bb3(%vec1 : vector<3xf32>), + 0: ^bb1(%vec1 : vector<3xf32>), + 1: ^bb2(%vec1 : vector<3xf32>) + ] +// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>) + ^bb1(%vecbb1: vector<3xf32>): + %cst3 = spirv.Constant 3.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32> + %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32> +// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>) + spirv.Branch ^bb3(%vec2 : vector<3xf32>) +// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>) + ^bb2(%vecbb2: vector<3xf32>): + %cst4 = spirv.Constant 4.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32> + %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32> +// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>) + spirv.Branch ^bb3(%vec3 : vector<3xf32>) +// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>) + ^bb3(%vecbb3: vector<3xf32>): +// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32> + spirv.mlir.merge %vecbb3 : vector<3xf32> +// CHECK-NEXT: } + } + %cst2 = spirv.Constant 2.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32> + %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32> + spirv.Return + } + + spirv.func @main() -> () "None" { + spirv.Return + } + + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 +} diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spvasm index 9642d0a..9642d0a 100644 --- a/mlir/test/Target/SPIRV/selection.spv +++ b/mlir/test/Target/SPIRV/selection.spvasm diff --git a/mlir/test/Target/SPIRV/selection_switch.spvasm b/mlir/test/Target/SPIRV/selection_switch.spvasm new file mode 100644 index 0000000..81fecf3 --- /dev/null +++ 
b/mlir/test/Target/SPIRV/selection_switch.spvasm @@ -0,0 +1,69 @@ +; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %} + +; This test is analogous to selection.spv but tests switch op. + +; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" { +; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32 +; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32> +; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32> +; CHECK-NEXT: spirv.Branch ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]]: +; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> { +; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [ +; CHECK-NEXT: default: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 0: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 1: ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK-NEXT: spirv.mlir.merge %8 : vector<3xf32> +; CHECK-NEXT: } +; CHECK: spirv.Return +; CHECK-NEXT: } +; CHECK: } + + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpName %switch "switch" + OpName %main "main" + %void = OpTypeVoid + %int = OpTypeInt 32 1 + %1 = OpTypeFunction %void %int + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %v3float = OpTypeVector %float 3 + %9 = OpUndef %v3float + %float_3 = OpConstant %float 3 + %float_4 = OpConstant %float 4 + %float_2 = OpConstant %float 2 + %25 = OpTypeFunction %void + %switch = OpFunction %void None %1 + %5 = OpFunctionParameter %int + %6 = OpLabel + OpBranch %12 + %12 = OpLabel + %11 = OpCompositeInsert %v3float %float_1 %9 0 + OpSelectionMerge %15 None + OpSwitch %5 %15 0 %13 1 %14 + %13 = OpLabel + %16 = OpPhi %v3float %11 %12 + %18 = OpCompositeInsert %v3float %float_3 %16 1 + OpBranch %15 + %14 = OpLabel + %19 = OpPhi %v3float %11 %12 + %21 = OpCompositeInsert %v3float %float_4 %19 1 + OpBranch %15 + %15 = OpLabel + %22 = OpPhi %v3float %21 %14 %18 %13 %11 %12 + %24 = OpCompositeInsert %v3float %float_2 %22 2 + OpReturn + OpFunctionEnd + %main = OpFunction %void None %25 + %27 = OpLabel + OpReturn + OpFunctionEnd diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir index 4984ee7..c423500 100644 --- a/mlir/test/Target/SPIRV/struct.mlir +++ b/mlir/test/Target/SPIRV/struct.mlir @@ -1,6 +1,11 @@ // RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Addresses, Float64, Int64, Linkage], [SPV_KHR_storage_buffer_storage_class]> { // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input> spirv.GlobalVariable @var0 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input> @@ -16,8 +21,8 @@ spirv.module Logical GLSL450 requires
#spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> - spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> + spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> @@ -34,14 +39,14 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> - // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> - spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer> [0]), Block>, StorageBuffer> + spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer> [0]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> - spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> + // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform> [0]), Block>, Uniform> [0]), Block>, Uniform> + spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform> [0]), Block>, Uniform> [0]), Block>, Uniform> - // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> - spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> + // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform> [0], !spirv.ptr<!spirv.struct<bxx>, Uniform> [8]), Block>, Uniform> [0]), Block>, Uniform> + spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform> [0], !spirv.ptr<!spirv.struct<bxx>, Uniform> [8]), Block>, Uniform> [0]), Block>, Uniform> // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), 
Block>, Output> diff --git a/mlir/test/Target/SPIRV/subgroup-block-intel.mlir b/mlir/test/Target/SPIRV/subgroup-block-intel.mlir new file mode 100644 index 0000000..14060e6 --- /dev/null +++ b/mlir/test/Target/SPIRV/subgroup-block-intel.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s + +// RUN: %if spirv-tools %{ rm -rf %t %} +// RUN: %if spirv-tools %{ mkdir %t %} +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --spirv-save-validation-files-with-prefix=%t/module %s %} +// RUN: %if spirv-tools %{ spirv-val %t %} + +spirv.module Physical64 GLSL450 requires #spirv.vce<v1.3, [Addresses, Shader, Linkage, SubgroupBufferBlockIOINTEL], + [SPV_KHR_storage_buffer_storage_class, SPV_INTEL_subgroups]> { + // CHECK-LABEL: @subgroup_block_read_intel + spirv.func @subgroup_block_read_intel(%ptr : !spirv.ptr<i32, StorageBuffer>) -> i32 "None" { + // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> i32 + %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> i32 + spirv.ReturnValue %0: i32 + } + // CHECK-LABEL: @subgroup_block_read_intel_vector + spirv.func @subgroup_block_read_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" { + // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32> + %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32> + spirv.ReturnValue %0: vector<3xi32> + } + // CHECK-LABEL: @subgroup_block_write_intel + spirv.func @subgroup_block_write_intel(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" { + // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32 + spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32 + spirv.Return + } + // CHECK-LABEL: @subgroup_block_write_intel_vector + spirv.func @subgroup_block_write_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" { + // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32> + spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32> + spirv.Return + } +} diff --git a/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir new file mode 100644 index 0000000..c99bde3 --- /dev/null +++ b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{modify-public-functions})' %s | FileCheck %s + +// Test if `public` functions' return values are transformed into out parameters +// when `buffer-results-to-out-params` is invoked with `modifyPublicFunctions`. 
+ +// CHECK-LABEL: func.func @basic( +// CHECK-SAME: %[[ARG0:.*]]: memref<f32>) { +// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<f32> +// CHECK: memref.copy %[[VAL_0]], %[[ARG0]] : memref<f32> to memref<f32> +// CHECK: return +// CHECK: } +func.func @basic() -> (memref<f32>) { + %0 = "test.source"() : () -> (memref<f32>) + return %0 : memref<f32> +} + +// CHECK-LABEL: func.func @presence_of_existing_arguments( +// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { +// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<2xf32> +// CHECK: memref.copy %[[VAL_0]], %[[ARG1]] : memref<2xf32> to memref<2xf32> +// CHECK: return +// CHECK: } +func.func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) { + %0 = "test.source"() : () -> (memref<2xf32>) + return %0 : memref<2xf32> +} + +// CHECK-LABEL: func.func @multiple_results( +// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { +// CHECK: %[[VAL_0:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) +// CHECK: memref.copy %[[VAL_0]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32> +// CHECK: memref.copy %[[VAL_0]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32> +// CHECK: return +// CHECK: } +func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) { + %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) + return %0, %1 : memref<1xf32>, memref<2xf32> +} diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index c1604e2..31a4f64d 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -880,6 +880,18 @@ func.func @no_speculate_divui( return } +func.func @no_speculate_udiv( +// CHECK-LABEL: @no_speculate_udiv( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %denom : i32 + } + + return +} + func.func @no_speculate_divsi( // CHECK-LABEL: @no_speculate_divsi( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -892,6 +904,18 @@ func.func @no_speculate_divsi( return } +func.func @no_speculate_sdiv( +// CHECK-LABEL: @no_speculate_sdiv( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %denom : i32 + } + + return +} + func.func @no_speculate_ceildivui( // CHECK-LABEL: @no_speculate_ceildivui( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -928,6 +952,18 @@ func.func @no_speculate_divui_const(%num: i32, %lb: index, %ub: index, %step: in return } +func.func @no_speculate_udiv_const(%num: i32, %lb: index, %ub: index, %step: index) { +// CHECK-LABEL: @no_speculate_udiv_const( + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %c0 : i32 + } + + return +} + func.func @speculate_divui_const( // CHECK-LABEL: @speculate_divui_const( %num: i32, %lb: index, %ub: index, %step: index) { @@ -941,6 +977,19 @@ func.func @speculate_divui_const( return } +func.func @speculate_udiv_const( +// CHECK-LABEL: @speculate_udiv_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = llvm.mlir.constant(5 : i32) : i32 +// CHECK: llvm.udiv +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = llvm.udiv %num, %c5 : i32 + } + + return +} 
+ func.func @no_speculate_ceildivui_const(%num: i32, %lb: index, %ub: index, %step: index) { // CHECK-LABEL: @no_speculate_ceildivui_const( %c0 = arith.constant 0 : i32 @@ -979,6 +1028,19 @@ func.func @no_speculate_divsi_const0( return } +func.func @no_speculate_sdiv_const0( +// CHECK-LABEL: @no_speculate_sdiv_const0( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %c0 : i32 + } + + return +} + func.func @no_speculate_divsi_const_minus1( // CHECK-LABEL: @no_speculate_divsi_const_minus1( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -992,6 +1054,19 @@ func.func @no_speculate_divsi_const_minus1( return } +func.func @no_speculate_sdiv_const_minus1( +// CHECK-LABEL: @no_speculate_sdiv_const_minus1( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %cm1 = arith.constant -1 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %cm1 : i32 + } + + return +} + func.func @speculate_divsi_const( // CHECK-LABEL: @speculate_divsi_const( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -1005,6 +1080,19 @@ func.func @speculate_divsi_const( return } +func.func @speculate_sdiv_const( +// CHECK-LABEL: @speculate_sdiv_const( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: llvm.sdiv +// CHECK: scf.for + %val = llvm.sdiv %num, %c5 : i32 + } + + return +} + func.func @no_speculate_ceildivsi_const0( // CHECK-LABEL: @no_speculate_ceildivsi_const0( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -1057,6 +1145,19 @@ func.func @no_speculate_divui_range( return } +func.func @no_speculate_udiv_range( +// CHECK-LABEL: @no_speculate_udiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %denom : i8 + } + + return +} + func.func @no_speculate_divsi_range( // CHECK-LABEL: @no_speculate_divsi_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1072,6 +1173,21 @@ func.func @no_speculate_divsi_range( return } +func.func @no_speculate_sdiv_range( +// CHECK-LABEL: @no_speculate_sdiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom0 = test.with_bounds {smax = -1: i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK-COUNT-2: llvm.sdiv + %val0 = llvm.sdiv %num, %denom0 : i8 + %val1 = llvm.sdiv %num, %denom1 : i8 + } + + return +} + func.func @no_speculate_ceildivui_range( // CHECK-LABEL: @no_speculate_ceildivui_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1113,6 +1229,19 @@ func.func @speculate_divui_range( return } +func.func @speculate_udiv_range( +// CHECK-LABEL: @speculate_udiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: llvm.udiv +// CHECK: scf.for + %val = llvm.udiv %num, %denom : i8 + } + + return +} + func.func @speculate_divsi_range( // CHECK-LABEL: 
@speculate_divsi_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1129,6 +1258,22 @@ func.func @speculate_divsi_range( return } +func.func @speculate_sdiv_range( +// CHECK-LABEL: @speculate_sdiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK-COUNT-2: llvm.sdiv +// CHECK: scf.for + %val0 = llvm.sdiv %num, %denom0 : i8 + %val1 = llvm.sdiv %num, %denom1 : i8 + + } + + return +} + func.func @speculate_ceildivui_range( // CHECK-LABEL: @speculate_ceildivui_range( %num: i8, %lb: index, %ub: index, %step: index) { diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 75d8386..3119fd3 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -238,25 +238,26 @@ module attributes {transform.with_named_sequence} { // ----- // Check simple move value definitions before insertion operation. -func.func @simple_move_values() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() : () -> (f32) - %3 = "foo"(%1, %2) : (f32, f32) -> (f32) - return %3 : f32 -} -// CHECK-LABEL: func @simple_move_values() -// CHECK: %[[MOVED1:.+]] = "moved_op_1" -// CHECK: %[[MOVED2:.+]] = "moved_op_2" +func.func @simple_move_values(%arg0 : index) -> index { + %c0 = arith.constant 0 : index + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %c0 {"moved_op_1"} : index + %2 = arith.subi %arg0, %c0 {"moved_op_2"} : index + %3 = "foo"(%1, %2) : (index, index) -> (index) + return %3 : index +} +// CHECK-LABEL: func @simple_move_values( +// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK: %[[MOVED2:.+]] = arith.subi {{.*}} {moved_op_2} // CHECK: %[[BEFORE:.+]] = "before" // CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]]) // CHECK: return %[[FOO]] module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0 + %op1 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0 + %op2 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -271,23 +272,26 @@ module attributes {transform.with_named_sequence} { // ----- // Compute slice including the implicitly captured values. 
-func.func @move_region_dependencies_values() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() ({ - %3 = "inner_op"(%1) : (f32) -> (f32) - "yield"(%3) : (f32) -> () - }) : () -> (f32) - return %2 : f32 +func.func @move_region_dependencies_values(%arg0 : index, %cond : i1) -> index { + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %arg0 {moved_op_1} : index + %2 = scf.if %cond -> index { + %3 = arith.muli %1, %1 {inner_op} : index + scf.yield %3 : index + } else { + scf.yield %1 : index + } + return %2 : index } -// CHECK-LABEL: func @move_region_dependencies_values() -// CHECK: %[[MOVED1:.+]] = "moved_op_1" -// CHECK: %[[MOVED2:.+]] = "moved_op_2" +// CHECK-LABEL: func @move_region_dependencies_values( +// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK: scf.if +// CHECK: arith.muli %[[MOVED1]], %[[MOVED1]] {inner_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0 + %op1 = transform.structured.match ops{["scf.if"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op2 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -301,31 +305,31 @@ module attributes {transform.with_named_sequence} { // ----- // Move operations in topological sort order -func.func @move_values_in_topological_sort_order() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() : () -> (f32) - %3 = "moved_op_3"(%1) : (f32) -> (f32) - %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32) - %5 = "moved_op_5"(%2) : (f32) -> (f32) - %6 = "foo"(%4, %5) : (f32, f32) -> (f32) - return %6 : f32 -} -// CHECK-LABEL: func @move_values_in_topological_sort_order() -// CHECK: %[[MOVED_1:.+]] = "moved_op_1" -// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]]) -// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]]) -// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2" -// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]]) +func.func @move_values_in_topological_sort_order(%arg0 : index, %arg1 : index) -> index { + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %arg0 {moved_op_1} : index + %2 = arith.addi %arg1, %arg1 {moved_op_2} : index + %3 = arith.muli %1, %1 {moved_op_3} : index + %4 = arith.andi %1, %3 {moved_op_4} : index + %5 = arith.subi %2, %2 {moved_op_5} : index + %6 = "foo"(%4, %5) : (index, index) -> (index) + return %6 : index +} +// CHECK-LABEL: func @move_values_in_topological_sort_order( +// CHECK: %[[MOVED_1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK-DAG: %[[MOVED_2:.+]] = arith.muli %[[MOVED_1]], %[[MOVED_1]] {moved_op_3} +// CHECK-DAG: %[[MOVED_3:.+]] = arith.andi %[[MOVED_1]], %[[MOVED_2]] {moved_op_4} +// CHECK-DAG: %[[MOVED_4:.+]] = arith.addi {{.*}} {moved_op_2} +// CHECK-DAG: %[[MOVED_5:.+]] = arith.subi %[[MOVED_4]], %[[MOVED_4]] {moved_op_5} // CHECK: %[[BEFORE:.+]] = "before" // CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]]) // CHECK: return %[[FOO]] module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0 + %op1 = transform.structured.match ops{["arith.andi"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match
ops{["moved_op_5"]} in %arg0 + %op2 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -341,17 +345,17 @@ module attributes {transform.with_named_sequence} { // Move only those value definitions that are not dominated by insertion point -func.func @move_only_required_defns() -> (f32, f32, f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "dummy_op"() : () -> (f32) - %2 = "before"() : () -> (f32) - %3 = "moved_op"() : () -> (f32) - return %0, %1, %2, %3 : f32, f32, f32, f32 +func.func @move_only_required_defns(%arg0 : index) -> (index, index, index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "dummy_op"() : () -> (index) + %2 = "before"() : () -> (index) + %3 = arith.addi %arg0, %arg0 {moved_op} : index + return %0, %1, %2, %3 : index, index, index, index } -// CHECK-LABEL: func @move_only_required_defns() +// CHECK-LABEL: func @move_only_required_defns( // CHECK: %[[UNMOVED:.+]] = "unmoved_op" // CHECK: %[[DUMMY:.+]] = "dummy_op" -// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { @@ -362,7 +366,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + %op4 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value @@ -374,19 +378,19 @@ module attributes {transform.with_named_sequence} { // ----- -// Move only those value definitions that are not dominated by insertion point +// Move only those value definitions that are not dominated by insertion point (duplicate test) -func.func @move_only_required_defns() -> (f32, f32, f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "dummy_op"() : () -> (f32) - %2 = "before"() : () -> (f32) - %3 = "moved_op"() : () -> (f32) - return %0, %1, %2, %3 : f32, f32, f32, f32 +func.func @move_only_required_defns_2(%arg0 : index) -> (index, index, index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "dummy_op"() : () -> (index) + %2 = "before"() : () -> (index) + %3 = arith.subi %arg0, %arg0 {moved_op} : index + return %0, %1, %2, %3 : index, index, index, index } -// CHECK-LABEL: func @move_only_required_defns() +// CHECK-LABEL: func @move_only_required_defns_2( // CHECK: %[[UNMOVED:.+]] = "unmoved_op" // CHECK: %[[DUMMY:.+]] = "dummy_op" -// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[MOVED:.+]] = arith.subi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { @@ -397,7 +401,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + %op4 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value %v2 = transform.get_result %op4[0] : (!transform.any_op) -> 
!transform.any_value @@ -410,23 +414,23 @@ module attributes {transform.with_named_sequence} { // ----- // Check handling of block arguments -func.func @move_only_required_defns() -> (f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - cf.br ^bb0(%0 : f32) - ^bb0(%arg0 : f32) : - %1 = "before"() : () -> (f32) - %2 = "moved_op"(%arg0) : (f32) -> (f32) - return %1, %2 : f32, f32 -} -// CHECK-LABEL: func @move_only_required_defns() -// CHECK: %[[MOVED:.+]] = "moved_op" +func.func @move_with_block_arguments() -> (index, index) { + %0 = "unmoved_op"() : () -> (index) + cf.br ^bb0(%0 : index) + ^bb0(%arg0 : index) : + %1 = "before"() : () -> (index) + %2 = arith.addi %arg0, %arg0 {moved_op} : index + return %1, %2 : index, index +} +// CHECK-LABEL: func @move_with_block_arguments() +// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + %op2 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value transform.test.move_value_defns %v1 before %op1 @@ -438,20 +442,20 @@ module attributes {transform.with_named_sequence} { // ----- // Do not move across basic blocks -func.func @no_move_across_basic_blocks() -> (f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "before"() : () -> (f32) - cf.br ^bb0(%0 : f32) - ^bb0(%arg0 : f32) : - %2 = "moved_op"(%arg0) : (f32) -> (f32) - return %1, %2 : f32, f32 +func.func @no_move_across_basic_blocks() -> (index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "before"() : () -> (index) + cf.br ^bb0(%0 : index) + ^bb0(%arg0 : index) : + %2 = arith.addi %arg0, %arg0 {moved_op} : index + return %1, %2 : index, index } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + %op2 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value // expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}} @@ -463,24 +467,22 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @move_isolated_from_above() -> () { - %1 = "before"() : () -> (f32) - %2 = "moved0"() : () -> (f32) - %3 = test.isolated_one_region_op %2 {} : f32 -> f32 - %4 = "moved1"(%3) : (f32) -> (f32) +func.func @move_isolated_from_above(%arg0 : index) -> () { + %1 = "before"() : () -> (index) + %2 = arith.addi %arg0, %arg0 {moved0} : index + %3 = arith.muli %2, %2 {moved1} : index return } -// CHECK-LABEL: func @move_isolated_from_above() -// CHECK: %[[MOVED0:.+]] = "moved0" -// CHECK: %[[ISOLATED:.+]] = test.isolated_one_region_op %[[MOVED0]] -// CHECK: %[[MOVED1:.+]] = "moved1"(%[[ISOLATED]]) +// CHECK-LABEL: func @move_isolated_from_above( +// CHECK: %[[MOVED0:.+]] = arith.addi {{.*}} {moved0} +// CHECK: %[[MOVED1:.+]] = arith.muli %[[MOVED0]], %[[MOVED0]] {moved1} // CHECK: 
%[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved1"]} in %arg0 + %op2 = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value transform.test.move_value_defns %v1 before %op1 diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index e730450..7130667 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) { // ----- +// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() { +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) { + // vector.print has a side effect but the op is dead. + vector.print %arg0 : i32 + return %arg0 : i32 +} + +// ----- + // %arg0 is not live because it is never used. %arg1 is not live because its // user `arith.addi` doesn't have any uses and the value that it is forwarded to // (%non_live_0) also doesn't have any uses. @@ -674,3 +685,32 @@ func.func @dead_value_loop_ivs_no_result(%lb: index, %ub: index, %step: index, % } return } + +// ----- + +// CHECK-LABEL: func @op_block_have_dead_arg +func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) { + scf.execute_region { + cf.cond_br %arg2, ^bb1(%arg0 : index), ^bb1(%arg1 : index) + ^bb1(%0: index): + scf.yield + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func private @remove_dead_branch_op() +// CHECK-NEXT: ub.unreachable +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + return %arg0 : i64 +^bb2: + return %arg1 : i64 +} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 42cec68..8da9109 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -72,3 +72,21 @@ builtin.module { } } + +// ----- + +// The region of "test.post_order_legalization" is converted before the op. + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { +func.func @test_preorder_legalization() { + // expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}} + "test.post_order_legalization"() ({ + ^bb0(%arg0: i64): + // Not-explicitly-legal ops are not allowed to survive. 
+ "test.remaining_consumer"(%arg0) : (i64) -> () + "test.invalid"(%arg0) : (i64) -> () + }) : () -> () + return +} +} diff --git a/mlir/test/Transforms/test-legalizer-no-materializations.mlir b/mlir/test/Transforms/test-legalizer-no-materializations.mlir new file mode 100644 index 0000000..82dd742 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-materializations.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND + +// CHECK-LABEL: func @dropped_input_in_use +// CHECK-KIND-LABEL: func @dropped_input_in_use +func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { + // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16 + // CHECK-NEXT: "work"(%[[cast]]) : (i16) + // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"} + // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16) + // expected-remark@+1 {{op 'work' is not legalizable}} + "work"(%arg) : (i16) -> () +} + +// ----- + +// CHECK-KIND-LABEL: func @test_lookup_without_converter +// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16 +// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"} +// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> () +// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> () +func.func @test_lookup_without_converter() { + %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) + "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () + // Make sure that the second "replace_with_valid_consumer" lowering does not + // lookup the materialization that was created for the above op. 
+ "test.replace_with_valid_consumer"(%0) : (i64) -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} + +// ----- + +// CHECK-LABEL: func @remap_moved_region_args +func.func @remap_moved_region_args() { + // CHECK-NEXT: return + // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32): + // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16 + // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64 + // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64 + // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32) + "test.region"() ({ + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () + }) : () -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} + +// ----- + +// CHECK-LABEL: func @remap_cloned_region_args +func.func @remap_cloned_region_args() { + // CHECK-NEXT: return + // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32): + // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16 + // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64 + // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64 + // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32) + "test.region"() ({ + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () + }) {legalizer.should_clone} : () -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir new file mode 100644 index 0000000..5f421a3 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @conditional_replacement( +// CHECK-SAME: %[[arg0:.*]]: i43) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42 +// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42 +// Uses were replaced for dummy_user_1. +// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> () +// Uses were also replaced for dummy_user_2, but not by value_replace. The uses +// were replaced due to the block signature conversion. +// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> () +// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> () +func.func @conditional_replacement(%arg0: i42) { + %repl = "test.legal_op"() : () -> (i42) + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> () + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_2"(%arg0) {} : (i42) -> () + // Perform a conditional 1:N replacement. 
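  // Only uses whose owner carries the `replace_uses` attribute are rewritten
  // by the pattern itself (TestValueReplace calls replaceUsesWithIf when the
  // op has the `conditional` attribute).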
+ "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir index 71e1178..4bcca6b 100644 --- a/mlir/test/Transforms/test-legalizer-rollback.mlir +++ b/mlir/test/Transforms/test-legalizer-rollback.mlir @@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 { "test.return"(%0) : (i32) -> () } } + +// ----- + +// CHECK-LABEL: func @test_failed_preorder_legalization +// CHECK: "test.post_order_legalization"() ({ +// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32 +// CHECK: "test.return"(%[[r]]) : (i32) -> () +// CHECK: }) : () -> () +// expected-remark @+1 {{applyPartialConversion failed}} +module { +func.func @test_failed_preorder_legalization() { + // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}} + "test.post_order_legalization"() ({ + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + }) : () -> () + return +} +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 94c5bb4..88a71cc 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND // CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B" // CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B" @@ -146,36 +145,6 @@ func.func @no_remap_nested() { // ----- -// CHECK-LABEL: func @remap_moved_region_args -func.func @remap_moved_region_args() { - // CHECK-NEXT: return - // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): - // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 - // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32) - "test.region"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () - }) : () -> () - // expected-remark@+1 {{op 'func.return' is not legalizable}} - return -} - -// ----- - -// CHECK-LABEL: func @remap_cloned_region_args -func.func @remap_cloned_region_args() { - // CHECK-NEXT: return - // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): - // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 - // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32) - "test.region"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () - }) {legalizer.should_clone} : () -> () - // expected-remark@+1 {{op 'func.return' is not legalizable}} - return -} - // CHECK-LABEL: func @remap_drop_region func.func @remap_drop_region() { // CHECK-NEXT: return @@ -191,12 +160,9 @@ func.func @remap_drop_region() { // ----- // CHECK-LABEL: func 
@dropped_input_in_use -// CHECK-KIND-LABEL: func @dropped_input_in_use func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16 // CHECK-NEXT: "work"(%[[cast]]) : (i16) - // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"} - // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16) // expected-remark@+1 {{op 'work' is not legalizable}} "work"(%arg) : (i16) -> () } @@ -452,11 +418,6 @@ func.func @test_multiple_1_to_n_replacement() { // CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64 // CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> () // CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> () -// CHECK-KIND-LABEL: func @test_lookup_without_converter -// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16 -// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"} -// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> () -// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> () func.func @test_lookup_without_converter() { %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () @@ -487,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) { "test.type_consumer"(%arg0) : (f16) -> () "test.return"() : () -> () } + +// ----- + +// The region of "test.post_order_legalization" is converted before the op. + +// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked +// CHECK: notifyOperationInserted: test.invalid +// CHECK: notifyBlockErased +// CHECK: notifyOperationInserted: test.valid, was unlinked +// CHECK: notifyOperationReplaced: test.invalid +// CHECK: notifyOperationErased: test.invalid +// CHECK: notifyOperationModified: test.post_order_legalization + +// CHECK-LABEL: func @test_preorder_legalization +// CHECK: "test.post_order_legalization"() ({ +// CHECK: ^{{.*}}(%[[arg0:.*]]: f64): +// Note: The survival of a not-explicitly-invalid operation does *not* cause +// a conversion failure in when applying a partial conversion. 
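// (Under applyFullConversion such survivors are not allowed, so the analogous
// test in test-legalizer-full.mlir expects the conversion to fail.)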
+// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64 +// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> () +// CHECK: "test.valid"(%[[arg0]]) : (f64) -> () +// CHECK: }) {is_legal} : () -> () +func.func @test_preorder_legalization() { + "test.post_order_legalization"() ({ + ^bb0(%arg0: i64): + // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}} + "test.remaining_consumer"(%arg0) : (i64) -> () + "test.invalid"(%arg0) : (i64) -> () + }) : () -> () + // expected-remark @+1 {{'func.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp index 8e2f03b..99f72c6 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp @@ -56,6 +56,17 @@ struct TestLivenessAnalysisPass liveness->print(os); os << "\n"; } + for (auto [regionIndex, region] : llvm::enumerate(op->getRegions())) { + os << " region: #" << regionIndex << ":\n"; + for (auto [argumntIndex, argument] : + llvm::enumerate(region.getArguments())) { + const Liveness *liveness = livenessAnalysis.getLiveness(argument); + assert(liveness && "expected a sparse lattice"); + os << " argument: #" << argumntIndex << ": "; + liveness->print(os); + os << "\n"; + } + } }); } }; diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp index 027b0a1..3ff0dc8 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp @@ -46,7 +46,7 @@ struct TestPointerLikeTypeInterfacePass Pass::Option<std::string> testMode{ *this, "test-mode", - llvm::cl::desc("Test mode: walk, alloc, copy, or free"), + llvm::cl::desc("Test mode: walk, alloc, copy, free, load, or store"), llvm::cl::init("walk")}; StringRef getArgument() const override { @@ -75,6 +75,10 @@ private: void testGenCopy(Operation *srcOp, Operation *destOp, Value srcResult, Value destResult, PointerLikeType pointerType, OpBuilder &builder); + void testGenLoad(Operation *op, Value result, PointerLikeType pointerType, + OpBuilder &builder); + void testGenStore(Operation *op, Value result, PointerLikeType pointerType, + OpBuilder &builder, Value providedValue = {}); struct PointerCandidate { Operation *op; @@ -92,9 +96,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { auto func = getOperation(); OpBuilder builder(&getContext()); - if (testMode == "alloc" || testMode == "free") { + if (testMode == "alloc" || testMode == "free" || testMode == "load" || + testMode == "store") { // Collect all candidates first SmallVector<PointerCandidate> candidates; + // For store mode, also look for a test value to use + Value testValue; func.walk([&](Operation *op) { if (op->hasAttr("test.ptr")) { for (auto result : op->getResults()) { @@ -105,6 +112,11 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { } } } + // Collect value marked with test.value for store tests + if (testMode == "store" && op->hasAttr("test.value")) { + if (op->getNumResults() > 0) + testValue = op->getResult(0); + } }); // Now test all candidates @@ -115,6 +127,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { else if (testMode == "free") testGenFree(candidate.op, candidate.result, candidate.pointerType, builder); + else if (testMode == "load") + testGenLoad(candidate.op, candidate.result, candidate.pointerType, + builder); + else 
if (testMode == "store") + testGenStore(candidate.op, candidate.result, candidate.pointerType, + builder, testValue); } } else if (testMode == "copy") { // Collect all source and destination candidates @@ -292,6 +310,105 @@ void TestPointerLikeTypeInterfacePass::testGenCopy( } } +void TestPointerLikeTypeInterfacePass::testGenLoad(Operation *op, Value result, + PointerLikeType pointerType, + OpBuilder &builder) { + Location loc = op->getLoc(); + + // Create a new builder with the listener and set insertion point + OperationTracker tracker; + OpBuilder newBuilder(op->getContext()); + newBuilder.setListener(&tracker); + newBuilder.setInsertionPointAfter(op); + + // Call the genLoad API + auto typedResult = cast<TypedValue<PointerLikeType>>(result); + Value loadRes = pointerType.genLoad(newBuilder, loc, typedResult, Type()); + + if (loadRes) { + llvm::errs() << "Successfully generated load for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "\tLoaded value type: "; + loadRes.getType().print(llvm::errs()); + llvm::errs() << "\n"; + + // Print all operations that were inserted + for (Operation *insertedOp : tracker.insertedOps) { + llvm::errs() << "\tGenerated: "; + insertedOp->print(llvm::errs()); + llvm::errs() << "\n"; + } + } else { + llvm::errs() << "Failed to generate load for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + } +} + +void TestPointerLikeTypeInterfacePass::testGenStore(Operation *op, Value result, + PointerLikeType pointerType, + OpBuilder &builder, + Value providedValue) { + Location loc = op->getLoc(); + + // Create a new builder with the listener and set insertion point + OperationTracker tracker; + OpBuilder newBuilder(op->getContext()); + newBuilder.setListener(&tracker); + newBuilder.setInsertionPointAfter(op); + + // Use provided value if available, otherwise create a constant + Value valueToStore = providedValue; + if (!valueToStore) { + // Create a test value to store - use a constant matching the element type + Type elementType = pointerType.getElementType(); + if (!elementType) { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + return; + } + + if (elementType.isIntOrIndex()) { + auto attr = newBuilder.getIntegerAttr(elementType, 42); + valueToStore = + arith::ConstantOp::create(newBuilder, loc, elementType, attr); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + auto attr = newBuilder.getFloatAttr(floatType, 42.0); + valueToStore = + arith::ConstantOp::create(newBuilder, loc, floatType, attr); + } else { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + return; + } + } + + // Call the genStore API + auto typedResult = cast<TypedValue<PointerLikeType>>(result); + bool success = + pointerType.genStore(newBuilder, loc, valueToStore, typedResult); + + if (success) { + llvm::errs() << "Successfully generated store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + + // Print all operations that were inserted + for (Operation *insertedOp : tracker.insertedOps) { + llvm::errs() << "\tGenerated: "; + insertedOp->print(llvm::errs()); + llvm::errs() << "\n"; + } + } else { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + } +} + } // namespace //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp 
b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp index 35f092c..2506ca4 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp @@ -93,6 +93,29 @@ void TestRecipePopulatePass::runOnOperation() { if (!recipe) { op->emitError("Failed to create firstprivate recipe for ") << varName; } + } else if (recipeType == "private_from_firstprivate") { + // First create a firstprivate recipe, then use it to drive creation of a + // matching private recipe via the convenience overload. Give each recipe + // a stable, predictable name so tests can check both. + std::string firstprivName = "first_firstprivate_" + varName; + std::string privName = "private_from_firstprivate_" + varName; + + auto firstpriv = FirstprivateRecipeOp::createAndPopulate( + builder, loc, firstprivName, var.getType(), varName, bounds); + + if (!firstpriv) { + op->emitError("Failed to create firstprivate recipe for ") << varName; + return; + } + + auto priv = PrivateRecipeOp::createAndPopulate(builder, loc, privName, + *firstpriv); + + if (!priv) { + op->emitError( + "Failed to create private recipe (from firstprivate) for ") + << varName; + } } } } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 21d75f5..43392d7 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -37,7 +37,6 @@ #include "llvm/Support/Base64.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 4d4ec02..8689265 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( } //===----------------------------------------------------------------------===// -// OpWithResultShapePerDimInterfaceOp +// ReifyShapedTypeUsingReifyResultShapesOp //===----------------------------------------------------------------------===// -LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( +LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { Location loc = getLoc(); shapes.reserve(getNumOperands()); @@ -345,6 +345,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( } //===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyShapeOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + SmallVector<OpFoldResult> shape = + tensor::getMixedSizes(builder, loc, sourceOperand); + return shape; +} + +//===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyDimOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult 
ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr<OpFoldResult> +ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, + int dim) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim); + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapesOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapesOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + Location loc = getLoc(); + shapes.resize(1); + shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0), + OpFoldResult()}; + return success(); +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + SmallVector<OpFoldResult> shape = { + tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()}; + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr<OpFoldResult> +UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, int dim) { + if (dim == 0) + return tensor::getMixedSize(builder, getLoc(), getOperand(), 0); + return failure(); +} + +//===----------------------------------------------------------------------===// // SideEffectOp //===----------------------------------------------------------------------===// @@ -1052,6 +1149,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( } //===----------------------------------------------------------------------===// +// TilingNoDpsOp +//===----------------------------------------------------------------------===// + +SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) { + return {}; +} + +SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() { + return {}; +} + +FailureOr<TilingResult> +TilingNoDpsOp::getTiledImplementation(OpBuilder &builder, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) { + return failure(); +} + +LogicalResult TilingNoDpsOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, + SmallVector<OpFoldResult> &resultSizes) { + return failure(); +} + 
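For context, a hypothetical payload for the new op in generic form (shapes chosen arbitrarily); because the interface methods above return an empty iteration domain or failure, this op presumably serves as a negative test for how tiling drivers handle a TilingInterface op without destination-passing-style semantics:

  %0 = "test.tiling_no_dps_op"(%lhs, %rhs)
      : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>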
+//===----------------------------------------------------------------------===// // OpWithShapedTypeInferTypeAdaptorInterfaceOp //===----------------------------------------------------------------------===// @@ -1514,3 +1637,14 @@ test::TestCreateTensorOp::getBufferType( return convertTensorToBuffer(getOperation(), options, type); } + +// Define a custom builder for ManyRegionsOp declared in TestOps.td. +// OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$firstRegion, +// "::std::unique_ptr<::mlir::Region>":$secondRegion)> +void test::ManyRegionsOp::build( + mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &®ions) { + for (auto &®ionPtr : std::move(regions)) + state.addRegion(std::move(regionPtr)); + ManyRegionsOp::build(builder, state, {}, regions.size()); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index 4201ade..6792743 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -42,6 +42,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" namespace test { class TestDialect; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a3430ba..5417ae9 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/ValueBoundsOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" @@ -119,6 +120,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def SymbolWithResultOp : TEST_Op<"symbol_with_result", [Symbol]> { + let summary = "invalid symbol operation that produces an SSA result"; + let arguments = (ins StrAttr:$sym_name, + OptionalAttr<StrAttr>:$sym_visibility); + let results = (outs AnyType:$result); +} + def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, ]> { @@ -914,13 +922,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface", let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } -def OpWithResultShapePerDimInterfaceOp : - TEST_Op<"op_with_result_shape_per_dim_interface", - [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { +def ReifyShapedTypeUsingReifyResultShapesOp : + TEST_Op<"reify_shaped_type_using_reify_result_shapes", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyShapeOfResult` nor implements `reifyDimOfResult` + calls into the implementation of `reifyResultShapes` to get the required value. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. 
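    For illustration (generic form, arbitrary shapes and value names):

      %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%a, %b)
          : (tensor<?x10xf32>, tensor<20x?xf32>) -> (tensor<20x?xf32>, tensor<?x10xf32>)

    Resolving any dimension of %0#0 is expected to go through `reifyResultShapes`
    and yield the corresponding size of %b.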
+ }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def ReifyShapedTypeUsingReifyShapeOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_shape_of_result", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyDimOfResult` but implements `reifyShapeOfResult`, which + is used to get the required value. `reifyResultShapes` is implemented as a failure + (which is also the default implementation) to ensure it is not called. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. + }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def ReifyShapedTypeUsingReifyDimOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_dim_of_result", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that implements `reifyDimOfResult`, which is used to get the required value. + `reifyResultShapes` and `reifyShapeOfResult` are implemented as failures + to ensure they are not called. The op semantics is that the first result has + the same shape as the second operand and the second result has the same shape + as the first operand. + }]; let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } +def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyResultShapes` is implemented. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyShapeOfResult` is implemented, + but not `reifyDimOfResult` with `reifyResultShapes` implemented as a failure. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyDimOfResult` is implemented, + and `reifyDimOfResult` with `reifyResultShapes` are implemented as a failure. 
+ + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>; def UpdateAttr : Pat<(I32ElementsAttrOp $attr), @@ -1107,6 +1199,12 @@ def TestLocationDstNoResOp : TEST_Op<"loc_dst_no_res"> { let results = (outs); } +def TestLocationAttrOp : TEST_Op<"op_with_loc_attr"> { + let arguments = (ins LocationAttr:$loc_attr); + let results = (outs ); + let assemblyFormat = "$loc_attr attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// @@ -2254,6 +2352,24 @@ def IsolatedGraphRegionOp : TEST_Op<"isolated_graph_region", [ let assemblyFormat = "attr-dict-with-keyword $region"; } +def ManyRegionsOp : TEST_Op<"many_regions", []> { + let summary = "operation created with move-only objects"; + let description = [{ + Test op with multiple regions with a `create` function that + takes parameters containing move-only objects. + }]; + + let regions = (region VariadicRegion<AnyRegion>:$regions); + let builders = + [OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$singleRegion), [{ + $_state.addRegion(std::move(singleRegion)); + build($_builder, $_state, {}, /*regionsCount=*/1); + }]>, + // Define in TestOps.cpp. + OpBuilder<(ins "::llvm::SmallVectorImpl<::std::unique_ptr<::mlir::" + "Region>>&&":$regions)>]; +} + def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> { let summary = "affine scope operation"; let description = [{ @@ -2888,6 +3004,20 @@ def TestLinalgFillOp : } //===----------------------------------------------------------------------===// +// Test TilingInterface. +//===----------------------------------------------------------------------===// + +def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op", + [Pure, DeclareOpInterfaceMethods<TilingInterface, + ["getIterationDomain", + "getLoopIteratorTypes", + "getResultTilePosition", + "getTiledImplementation"]>]> { + let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); + let results = (outs AnyRankedTensor:$result); +} + +//===----------------------------------------------------------------------===// // Test NVVM RequiresSM trait. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index fd2b943..7eabaae 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern { // Replace the first operand with 2x the second operand. Value from = op->getOperand(0); Value repl = op->getOperand(1); - rewriter.replaceAllUsesWith(from, {repl, repl}); + if (op->hasAttr("conditional")) { + rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) { + return use.getOwner()->hasAttr("replace_uses"); + }); + } else { + rewriter.replaceAllUsesWith(from, {repl, repl}); + } rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered. 
@@ -1418,6 +1424,22 @@ public: } }; +class TestPostOrderLegalization : public ConversionPattern { +public: + TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + for (Region &r : op->getRegions()) + if (failed(rewriter.legalize(&r))) + return failure(); + rewriter.modifyOpInPlace( + op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); }); + return success(); + } +}; + /// Test unambiguous overload resolution of replaceOpWithMultiple. This /// function is just to trigger compiler errors. It is never executed. [[maybe_unused]] void testReplaceOpWithMultipleOverloads( @@ -1532,7 +1554,8 @@ struct TestLegalizePatternDriver patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, TestPassthroughInvalidOp, TestMultiple1ToNReplacement, TestValueReplace, TestReplaceWithValidConsumer, - TestTypeConsumerOpPattern>(&getContext(), converter); + TestTypeConsumerOpPattern, TestPostOrderLegalization>( + &getContext(), converter); patterns.add<TestConvertBlockArgs>(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1553,14 +1576,16 @@ struct TestLegalizePatternDriver [](Type type) { return type.isF32(); }); }); target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp<func::CallOp>( [&](func::CallOp op) { return converter.isLegal(op); }); target.addDynamicallyLegalOp( OperationName("test.value_replace", &getContext()), [](Operation *op) { return op->hasAttr("is_legal"); }); + target.addDynamicallyLegalOp( + OperationName("test.post_order_legalization", &getContext()), + [](Operation *op) { return op->hasAttr("is_legal"); }); // TestCreateUnregisteredOp creates `arith.constant` operation, // which was not added to target intentionally to test @@ -2156,8 +2181,7 @@ struct TestTypeConversionDriver recursiveType.getName() == "outer_converted_type"); }); target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { // Allow casts from F64 to F32. 
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index ea20597..9859bd0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -470,4 +470,11 @@ def TestMemrefType : Test_Type<"TestMemref", }]; } +// Test implementation of an interface with methods specifying a +// method body +def TestBaseBody : Test_Type<"TestBaseBody", + [DeclareTypeInterfaceMethods<TestBaseTypeInterfacePrintTypeA>]> { + let mnemonic = "test_base_body"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 614121f..9cf64a8 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -569,11 +569,17 @@ TestTensorType::getBufferType( ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - auto testMemref = dyn_cast<TestMemrefType>(bufferType); - if (!testMemref) - return emitError() << "expected TestMemrefType"; + if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) { + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); + } + + if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) { + const bool valid = getShape() == builtinMemref.getShape() && + getElementType() == builtinMemref.getElementType(); + return mlir::success(valid); + } - const bool valid = getShape() == testMemref.getShape() && - getElementType() == testMemref.getElementType(); - return mlir::success(valid); + return emitError() << "expected MemRefType or TestMemrefType"; } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bb..f834d0c 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -180,6 +180,34 @@ struct TestVectorUnrollingPatterns })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef<int64_t>{8, 8}) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::CreateMaskOp>(op)); + })); + populateVectorUnrollPatterns( + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional<SmallVector<int64_t>> { + auto shapeCast = dyn_cast<vector::ShapeCastOp>(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector<int64_t>{1, 16}; + } + // Default case: [2,4] for all tests. 
+ return SmallVector<int64_t>{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::ShapeCastOp>(op)); + })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) .setFilterConstraint([](Operation *op) { return success(isa<vector::TransposeOp>(op)); diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 76d4611..93d5144 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeOffsets = + sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeOffsets)) return failure(); diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 326fec3..583d68b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" @@ -170,9 +171,71 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, // TestFuseConsumerOp //===----------------------------------------------------------------------===// +/// Fuse the consumer and store both the original consumer operation as well as +/// the fused consumer operation. +static LogicalResult +applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, + Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops, + TransformResults &transformResults) { + SmallVector<Operation *> fusedConsumerOps; + rewriter.setInsertionPoint(consumer); + + FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = + scf::tileAndFuseConsumer(rewriter, consumer, loops); + if (failed(fuseConsumerResults)) + return consumer->emitOpError("failed to fuse consumer of slice"); + + // Report back the relevant handles to the transform op. + for (OpOperand *tiledAndFusedConsumerOperand : + fuseConsumerResults->tiledAndFusedConsumerOperands) { + fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); + } + transformResults.set(transformOp->getOpResult(0), fusedConsumerOps); + for (auto [index, loop] : llvm::enumerate(loops)) { + transformResults.set(transformOp->getOpResult(index + 1), {loop}); + } + return success(); +} + +DiagnosedSilenceableFailure +transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + Operation *consumer = *state.getPayloadOps(getConsumer()).begin(); + + SmallVector<LoopLikeOpInterface> loops; + // Since the matcher works inside-out, we need to iterate the loops in + // reverse. 
+ for (auto loop : llvm::reverse(getLoops())) { + auto loopLikeOp = + dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin()); + if (!loopLikeOp) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + loops.push_back(loopLikeOp); + } + LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer, + loops, transformResults); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); +} + +void transform::TestFuseConsumerOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getConsumerMutable(), effects); + consumesHandle(getLoopsMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// TestFuseConsumerUsingSliceOp +//===----------------------------------------------------------------------===// + /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. -static LogicalResult applyFuseConsumer( +static LogicalResult applyFuseConsumerUsingSlices( RewriterBase &rewriter, Operation *transformOp, ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse, TransformResults &transformResults) { @@ -204,10 +267,9 @@ static LogicalResult applyFuseConsumer( return success(); } -DiagnosedSilenceableFailure -transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, - TransformResults &transformResults, - TransformState &state) { +DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { SmallVector<Operation *> slices; for (auto op : getTargets()) { auto sliceOp = *state.getPayloadOps(op).begin(); @@ -224,13 +286,13 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, loops.push_back(loopLikeOp); } LogicalResult result = - applyFuseConsumer(rewriter, getOperation(), slices, loops, - getNumConsumerToFuse(), transformResults); + applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops, + getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } -void transform::TestFuseConsumerOp::getEffects( +void transform::TestFuseConsumerUsingSliceOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { consumesHandle(getTargetsMutable(), effects); consumesHandle(getLoopsMutable(), effects); @@ -622,6 +684,110 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestQueryProducerFusability +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); + if (!tilingInterfaceOp) { + return emitSilenceableError() + << "target operation does not implement TilingInterface"; + } + + // Collect operand numbers and their corresponding producer insert_slice + // offsets and sizes. 
+ SmallVector<unsigned> operandNumbers; + SmallVector<SmallVector<OpFoldResult>> allOffsets; + SmallVector<SmallVector<OpFoldResult>> allSizes; + + for (OpOperand &operand : target->getOpOperands()) { + Value operandValue = operand.get(); + Operation *definingOp = operandValue.getDefiningOp(); + + // Look for a producer tensor.insert_slice. This is only for testing + // purposes and otherwise is not a useful transformation. + if (auto insertSliceOp = + dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) { + operandNumbers.push_back(operand.getOperandNumber()); + allOffsets.push_back(insertSliceOp.getMixedOffsets()); + allSizes.push_back(insertSliceOp.getMixedSizes()); + } + } + + if (!operandNumbers.empty()) { + bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices( + operandNumbers, allOffsets, allSizes); + + if (isFusable) { + target->emitRemark() + << "can be fused with producer tensor.insert_slice ops"; + } else { + target->emitRemark() + << "cannot be fused with producer tensor.insert_slice ops"; + } + } + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TestQueryProducerFusability::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsPayload(effects); +} + +//===----------------------------------------------------------------------===// +// TestQueryConsumerFusability +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); + if (!tilingInterfaceOp) { + return emitSilenceableError() + << "target operation does not implement TilingInterface"; + } + + // Look for tensor.extract_slice ops that consume results of the tilable op. + for (OpResult result : target->getResults()) { + for (OpOperand &use : result.getUses()) { + Operation *user = use.getOwner(); + + // Look for a consumer tensor.extract_slice. This is only for testing + // purposes and otherwise is not a useful transformation. 
+ if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) { + bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice( + result.getResultNumber(), extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedSizes()); + + if (isFusable) { + target->emitRemark() + << "can be fused with consumer tensor.extract_slice op"; + } else { + target->emitRemark() + << "cannot be fused with consumer tensor.extract_slice op"; + } + } + } + } + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TestQueryConsumerFusability::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsPayload(effects); +} + #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 694c422..8c4f64d 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -49,14 +49,19 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield", }]; } -def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", +def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_using_slice", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, ReportTrackingListenerFailuresOpTrait]> { let description = [{ - Fuses the consumer of the operation pointed to by the target handle - using the options provided as attributes. + For the `insert_slice`-like operations (that are typically generated through tiling), + within the loop nests passed in as `loops` (that are typically generated through tiling), + find the consumer that these slices map to (have to be the same consumer) and fuse + the consumer into the loop. + + Returns a handle to the original consumer operation and the consumer operation after + fusion. }]; let arguments = (ins @@ -73,6 +78,32 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", }]; } +def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + For the `consumer` that uses the result of the outer-most loop of a loop nest passed in + as `loops` (that are typically generated through tiling), fuse the consumer into the + loop. + + Returns a handle to the consumer operation after fusion and the loops that might be + modified. 
+ }]; + + let arguments = (ins + TransformHandleTypeInterface:$consumer, + Variadic<TransformHandleTypeInterface>:$loops); + let results = (outs TransformHandleTypeInterface:$fused_consumer, + Variadic<TransformHandleTypeInterface>:$result_loops); + + let assemblyFormat = [{ + $consumer `into` `(` $loops `)` + attr-dict `:` functional-type(operands, results) + }]; +} + + def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall", [DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, @@ -166,11 +197,55 @@ def TestTileUsingCustomLoopOp : Op< DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes); let results = (outs TransformHandleTypeInterface:$tiled_ops, Variadic<TransformHandleTypeInterface>:$loops); - + let assemblyFormat = [{ $root_op `tile_sizes` `=` $tile_sizes attr-dict `:` functional-type(operands, results) }]; } +def TestQueryProducerFusability : Op< + Transform_Dialect, "test.query_producer_fusability", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + let description = [{ + Test operation for the producer fusability query method in the + TilingInterface. + + For each operation in the target handle, this looks for tensor.insert_slice + ops that produce operands to the tilable op. The offset/sizes from those + inserts is used as the arguments to `isOpFusableWithProducerSlices` and + emits a remark with the result of the query. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + +def TestQueryConsumerFusability + : Op<Transform_Dialect, "test.query_consumer_fusability", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + let description = [{ + Test operation for the consumer fusability query method in the + TilingInterface. + + For each operation in the target handle, this looks for tensor.extract_slice + ops that consume results of the tilable op. The offset/sizes from those + extracts is used as the arguments to `isOpFusableWithConsumerSlice` and + emits a remark with the result of the query. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td index 9b0a260..bc53b23 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.td +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -44,7 +44,9 @@ def TestMoveValueDefns : DeclareOpInterfaceMethods<TransformOpInterface>, ReportTrackingListenerFailuresOpTrait]> { let description = [{ - Moves all dependencies of on operation before another operation. + Moves all dependencies of a list of values before another operation. + Only pure operations are moved. If there is a side effecting op in the + dependency chain no operations are moved. }]; let arguments = diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 6ff12d6..675ded3 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -44,7 +44,7 @@ config.suffixes = [ ".test", ".pdll", ".c", - ".spv", + ".spvasm", ] # test_source_root: The root path where tests are located. 
@@ -214,6 +214,11 @@ tools = [ "not", ] +if "Linux" in config.host_os: + # TODO: Run only on Linux until we figure out how to build + # mlir_apfloat_wrappers in a platform-independent way. + tools.extend([add_runtime("mlir_apfloat_wrappers")]) + if config.enable_vulkan_runner: tools.extend([add_runtime("mlir_vulkan_runtime")]) diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td index d51e1a5f..3f2e5cd 100644 --- a/mlir/test/mlir-tblgen/constraint-unique.td +++ b/mlir/test/mlir-tblgen/constraint-unique.td @@ -16,7 +16,7 @@ def AType : Type<ATypePred, "a type">; def OtherType : Type<ATypePred, "another type">; def AnAttrPred : CPred<"attrPred($_self, $_op)">; -def AnAttr : Attr<AnAttrPred, "an attribute">; +def AnAttr : Attr<AnAttrPred, "an attribute (got {{reformat($_self)}})">; def OtherAttr : Attr<AnAttrPred, "another attribute">; def ASuccessorPred : CPred<"successorPred($_self, $_op)">; @@ -24,7 +24,7 @@ def ASuccessor : Successor<ASuccessorPred, "a successor">; def OtherSuccessor : Successor<ASuccessorPred, "another successor">; def ARegionPred : CPred<"regionPred($_self, $_op)">; -def ARegion : Region<ARegionPred, "a region">; +def ARegion : Region<ARegionPred, "a region ({{find(foo)}})">; def OtherRegion : Region<ARegionPred, "another region">; // OpA and OpB have the same type, attribute, successor, and region constraints. @@ -71,10 +71,10 @@ def OpC : NS_Op<"op_c"> { // CHECK: static ::llvm::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( // CHECK: if (attr && !((attrPred(attr, *op)))) // CHECK-NEXT: return emitError() << "attribute '" << attrName -// CHECK-NEXT: << "' failed to satisfy constraint: an attribute"; +// CHECK-NEXT: << "' failed to satisfy constraint: an attribute (got " << reformat(attr) << ")"; /// Test that duplicate attribute constraint was not generated. -// CHECK-NOT: << "' failed to satisfy constraint: an attribute"; +// CHECK-NOT: << "' failed to satisfy constraint: an attribute /// Test that a attribute constraint with a different description was generated. // CHECK: static ::llvm::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( @@ -103,7 +103,7 @@ def OpC : NS_Op<"op_c"> { // CHECK: if (!((regionPred(region, *op)))) { // CHECK-NEXT: return op->emitOpError("region #") << regionIndex // CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ") -// CHECK-NEXT: << "failed to verify constraint: a region"; +// CHECK-NEXT: << "failed to verify constraint: a region (" << find(foo) << ")"; /// Test that duplicate region constraint was not generated. // CHECK-NOT: << "failed to verify constraint: a region"; diff --git a/mlir/test/mlir-tblgen/cpp-class-comments.td b/mlir/test/mlir-tblgen/cpp-class-comments.td index 9dcf975..0d3445d 100644 --- a/mlir/test/mlir-tblgen/cpp-class-comments.td +++ b/mlir/test/mlir-tblgen/cpp-class-comments.td @@ -36,6 +36,7 @@ def A_SomeOp1 : Op<A_Dialect, "some_op1", []>{ let cppNamespace = "OP1"; // OP: namespace OP1 +// OP-EMPTY: // OP-NEXT: /// Some Op1 summary line1 // OP-NEXT: /// summary line2 // OP-NEXT: /// Some Op1 description @@ -97,6 +98,7 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> { let methods = [ ]; // ATTR-INTERFACE: namespace mlir::a::traits { +// ATTR-INTERFACE-EMPTY: // ATTR-INTERFACE-NEXT: /// Common trait for all layouts. 
// ATTR-INTERFACE-NEXT: class EncodingTrait; } @@ -104,6 +106,7 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> { def SimpleEncodingTrait : AttrInterface<"SimpleEncodingTrait"> { let cppNamespace = "a::traits"; // ATTR-INTERFACE: namespace a::traits { +// ATTR-INTERFACE-EMPTY: // ATTR-INTERFACE-NEXT: class SimpleEncodingTrait; } @@ -114,6 +117,7 @@ def SimpleOpInterface : OpInterface<"SimpleOpInterface"> { Simple Op Interface description }]; // OP-INTERFACE: namespace a::traits { +// OP-INTERFACE-EMPTY: // OP-INTERFACE-NEXT: /// Simple Op Interface description // OP-INTERFACE-NEXT: class SimpleOpInterface; } diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td new file mode 100644 index 0000000..ff39fd9 --- /dev/null +++ b/mlir/test/mlir-tblgen/dialect-interface.td @@ -0,0 +1,65 @@ +// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL + +include "mlir/IR/Interfaces.td" + +def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> { + let description = [{ + This is an example dialect interface without default method body. + }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins) + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod> +// DECL: public: +// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {} +// DECL: virtual bool isExampleDialect() const {} +// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {} + +def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> { + let description = [{ + This is an example dialect interface with default method bodies. 
+ }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins), + /*methodBody=*/ [{ + return true; + }] + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: virtual bool isExampleDialect() const { +// DECL-NEXT: return true; +// DECL-NEXT: } + diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 549830e..a3cb9a4 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -69,19 +69,19 @@ def AOp : NS_Op<"a_op", []> { // DEF: ::llvm::LogicalResult AOpAdaptor::verify // DEF-NEXT: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr; -// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); +// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op requires attribute 'aAttr'"); // DEF-NEXT: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr; // DEF-NEXT: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr; // DEF-NEXT: auto tblgen_dAttr = getProperties().dAttr; (void)tblgen_dAttr; // DEF: if (tblgen_aAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'aAttr' failed to satisfy constraint: some attribute kind"); // DEF: if (tblgen_bAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind"); // DEF: if (tblgen_cAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind"); // DEF: if (tblgen_dAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'dAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'dAttr' failed to satisfy constraint: some attribute kind"); // Test getter methods // --- @@ -219,13 +219,13 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> { // DEF: ::llvm::LogicalResult AgetOpAdaptor::verify // DEF: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr; -// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'"); +// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op requires attribute 'aAttr'"); // DEF: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr; // DEF: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr; // DEF: if (tblgen_bAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind"); // DEF: if (tblgen_cAttr && !((some-condition))) -// DEF-NEXT: return emitError(loc, 
"'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); +// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind"); // Test getter methods // --- diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index 87b41f9..80dedb84 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -65,6 +65,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // CHECK: ::std::optional< ::llvm::APFloat > getSomeAttr2(); // CHECK: ::mlir::Region &getSomeRegion() { // CHECK: ::mlir::RegionRange getSomeRegions() { +// CHECK-NEXT: return odsRegions.drop_front(1); +// CHECK: ::mlir::RegionRange getRegions() { +// CHECK-NEXT: return odsRegions; // CHECK: }; // CHECK: } @@ -152,10 +155,6 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // Check that `getAttrDictionary()` is used when not using properties. -// DECLS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions() -// DECLS-NEXT: return odsRegions.drop_front(1); -// DECLS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions() - // Check AttrSizedOperandSegments // --- @@ -236,14 +235,14 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint", // DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) { // DEFS: ::mlir::OperationState __state__(location, getOperationName()); -// DEFS: build(builder, __state__, a); +// DEFS: build(builder, __state__, std::forward<decltype(a)>(a)); // DEFS: auto __res__ = ::llvm::dyn_cast<FOp>(builder.create(__state__)); // DEFS: assert(__res__ && "builder didn't return the right type"); // DEFS: return __res__; // DEFS: } // DEFS: FOp FOp::create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::Value a) { -// DEFS: return create(builder, builder.getLoc(), a); +// DEFS: return create(builder, builder.getLoc(), std::forward<decltype(a)>(a)); // DEFS: } def NS_GOp : NS_Op<"op_with_fixed_return_type", []> { diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td index af09ee7..7cc9633 100644 --- a/mlir/test/mlir-tblgen/op-properties-predicates.td +++ b/mlir/test/mlir-tblgen/op-properties-predicates.td @@ -74,7 +74,7 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> { // Note: comprehensive emission of verifiers is tested in verifyINvariantsImpl() below // CHECK: int64_t tblgen_scalar = this->getScalar(); // CHECK: if (!((tblgen_scalar >= 0))) -// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t"); +// CHECK: return emitError(loc, "'test.op_with_predicates' op property 'scalar' failed to satisfy constraint: non-negative int64_t"); // CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl() // Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused diff --git a/mlir/test/mlir-tblgen/op-properties.td b/mlir/test/mlir-tblgen/op-properties.td index a9c784c..cb9bd3d 100644 --- a/mlir/test/mlir-tblgen/op-properties.td +++ b/mlir/test/mlir-tblgen/op-properties.td @@ -32,7 +32,7 @@ def OpWithProps : NS_Op<"op_with_props"> { ArrayProp<StringProp>:$strings, DefaultValuedProp<I32Prop, "0">:$default_int, OptionalProp<I64Prop>:$optional, - DefaultI64Array:$intArray + DefaultI64Array:$value ); } @@ -94,10 +94,10 @@ def OpWithOptionalPropsAndAttrs : // DECL: ::llvm::ArrayRef<std::string> getStrings() 
// DECL: using default_intTy = int32_t; // DECL: default_intTy default_int = 0; -// DECL: intArrayTy intArray = ::llvm::SmallVector<int64_t>{}; -// DECL: ::llvm::ArrayRef<int64_t> getIntArray() +// DECL: valueTy value = ::llvm::SmallVector<int64_t>{}; +// DECL: ::llvm::ArrayRef<int64_t> getValue() // DECL: return ::llvm::ArrayRef<int64_t>{propStorage} -// DECL: void setIntArray(::llvm::ArrayRef<int64_t> propValue) +// DECL: void setValue(::llvm::ArrayRef<int64_t> propValue) // DECL: propStorage.assign // DECL-LABEL: class OpWithProps : // DECL: setString(::llvm::StringRef newString) @@ -111,14 +111,14 @@ def OpWithOptionalPropsAndAttrs : // DECL-SAME: ::llvm::ArrayRef<std::string> strings, // DECL-SAME: /*optional*/int32_t default_int = 0, // DECL-SAME: /*optional*/std::optional<int64_t> optional = std::nullopt, -// DECL-SAME: /*optional*/::llvm::ArrayRef<int64_t> intArray = ::llvm::ArrayRef<int64_t>{}); +// DECL-SAME: /*optional*/::llvm::ArrayRef<int64_t> value = ::llvm::ArrayRef<int64_t>{}); // DEFS-LABEL: OpWithProps::computePropertiesHash -// DEFS: hash_intArray +// DEFS: hash_value_ // DEFS: using ::llvm::hash_value; // DEFS-NEXT: return hash_value(::llvm::ArrayRef<int64_t>{propStorage}) // DEFS: hash_value(prop.optional) -// DEFS: hash_intArray(prop.intArray) +// DEFS: hash_value_(prop.value) // ----- diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 42de7e4..ff16ad8 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32, F32:$f32, I64); // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.OpResult: + // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[2] let results = (outs I32:$i32, AnyFloat, I64:$i64); } @@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.Value: + // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]: // CHECK: return self.operation.operands[0] // // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32:$i32, F32:$f32); // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def f64(self) -> _ods_ir.OpResult: + // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]: // CHECK: return self.operation.results[1] let results = (outs I64:$i64, AnyFloat:$f64); } diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index c1fcd3f..41e041f 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -55,7 +55,7 @@ def OpF : 
NS_Op<"op_for_int_min_val", []> { // CHECK-LABEL: OpFAdaptor::verify // CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() >= 10) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10" def OpFX : NS_Op<"op_for_int_max_val", []> { let arguments = (ins ConfinedAttr<I32Attr, [IntMaxValue<10>]>:$attr); @@ -63,7 +63,7 @@ def OpFX : NS_Op<"op_for_int_max_val", []> { // CHECK-LABEL: OpFXAdaptor::verify // CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() <= 10) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10" def OpG : NS_Op<"op_for_arr_min_count", []> { let arguments = (ins ConfinedAttr<ArrayAttr, [ArrayMinCount<8>]>:$attr); @@ -71,7 +71,7 @@ def OpG : NS_Op<"op_for_arr_min_count", []> { // CHECK-LABEL: OpGAdaptor::verify // CHECK: (::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() >= 8) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements" def OpH : NS_Op<"op_for_arr_value_at_index", []> { let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr); @@ -79,7 +79,7 @@ def OpH : NS_Op<"op_for_arr_value_at_index", []> { // CHECK-LABEL: OpHAdaptor::verify // CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() == 8))))) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8" def OpI: NS_Op<"op_for_arr_min_value_at_index", []> { let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr); @@ -87,7 +87,7 @@ def OpI: NS_Op<"op_for_arr_min_value_at_index", []> { // CHECK-LABEL: OpIAdaptor::verify // CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 8))))) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8" def OpJ: NS_Op<"op_for_arr_max_value_at_index", []> { let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMaxValue<0, 8>]>:$attr); @@ -95,7 +95,7 @@ def OpJ: NS_Op<"op_for_arr_max_value_at_index", []> { // CHECK-LABEL: OpJAdaptor::verify // CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8))))) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8" def OpK: NS_Op<"op_for_arr_in_range_at_index", []> { let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemInRange<0, 4, 8>]>:$attr); @@ -103,7 +103,7 @@ def OpK: 
NS_Op<"op_for_arr_in_range_at_index", []> { // CHECK-LABEL: OpKAdaptor::verify // CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 4)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8))))) -// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8" +// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8" def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [ PredOpTrait<"operands indexed at 0, 2, 3 should all have " @@ -121,7 +121,7 @@ def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [ // CHECK: ::llvm::all_equal(::llvm::map_range( // CHECK-SAME: ::mlir::ArrayRef<unsigned>({0, 2, 3}), // CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); })) -// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type" +// CHECK: failed to verify that operands indexed at 0, 2, 3 should all have the same type" def OpM : NS_Op<"op_for_AnyTensorOf", []> { let arguments = (ins TensorOf<[F32, I32]>:$x); diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt index 2c12381..c81f75f 100644 --- a/mlir/test/python/CMakeLists.txt +++ b/mlir/test/python/CMakeLists.txt @@ -11,7 +11,7 @@ add_public_tablegen_target(MLIRPythonTestIncGen) add_subdirectory(lib) -set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules mlir-runner) +set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules mlir-runner mlir_c_runner_utils mlir_runner_utils) if(NOT MLIR_STANDALONE_BUILD) list(APPEND MLIR_PYTHON_TEST_DEPENDS FileCheck count not) endif() diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py index 3945c99..1a009b7 100644 --- a/mlir/test/python/dialects/gpu/dialect.py +++ b/mlir/test/python/dialects/gpu/dialect.py @@ -133,9 +133,10 @@ def testGPUFuncOp(): ), func.known_grid_size func = gpu.GPUFuncOp( - func_type, + ir.FunctionType.get(inputs=[T.index()], results=[]), sym_name="non_kernel_func", body_builder=builder, + arg_attrs=[{"gpu.some_attribute": ir.StringAttr.get("foo")}], ) assert not func.is_kernel assert func.known_block_size is None @@ -154,10 +155,11 @@ def testGPUFuncOp(): # CHECK: %[[VAL_0:.*]] = gpu.global_id x # CHECK: gpu.return # CHECK: } - # CHECK: gpu.func @non_kernel_func() { - # CHECK: %[[VAL_0:.*]] = gpu.global_id x - # CHECK: gpu.return - # CHECK: } + # CHECK: gpu.func @non_kernel_func( + # CHECK-SAME: %[[ARG0:.*]]: index {gpu.some_attribute = "foo"}) { + # CHECK: %[[GLOBAL_ID_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } # CHECK-LABEL: testGPULaunchFuncOp diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 709a1d2..92591cd 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -1,7 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.dialects import arith, func, linalg, tensor, memref +from mlir.dialects import arith, func, linalg, tensor, memref, builtin from mlir.dialects.linalg.opdsl.lang import * +from mlir.extras import types as T from mlir.ir import * @@ -857,3 +858,76 @@ def testElementwiseOp(): ) print(module) + + +@run +def testReduceOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) 
+ def reduce_op(input): + c1 = arith.constant(f32, 1.0) + single_result = ir.RankedTensorType.get((), f32) + dims = ir.DenseI64ArrayAttr.get([0]) + init = tensor.splat(single_result, c1, []) + + @linalg.reduce( + result=[single_result], + inputs=[input], + inits=[init], + dimensions=dims, + ) + def reduced(element: f32, acc: f32): + return arith.mulf(acc, element) + + return tensor.extract(reduced, []) + + print(module) + + +# CHECK-LABEL: func.func @reduce_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32 +# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32> +# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0] +# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32> +# CHECK: return %[[EXTRACT_0]] : f32 +# CHECK: } + + +@run +def testMapOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) + def map_op(input): + empty = tensor.empty(tensor_type.shape, f32) + + @linalg.map( + result=[tensor_type], + inputs=[input, input], + init=empty, + ) + def add(element: f32, acc: f32, init: f32): + return arith.addf(element, acc) + + return add + + module.verify() + print(module) + + +# CHECK-LABEL: func.func @map_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> { +# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32> +# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>) +# CHECK: return %[[MAP_0]] : tensor<10xf32> +# CHECK: } diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 5f7cb6a..8ab53b4 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -208,3 +208,43 @@ def test_get_indexing_maps_attr(): assert maps[0] == a_map assert maps[1] == b_map assert maps[2] == c_map + + +@run +def test_infer_contraction_dimensions_from_maps(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + # === Test valid contraction (matmul) === + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map]) + assert dims is not None + + # Expect m=[0], n=[1], k=[2] as per standard matmul. 
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}" + + # === Test invalid input (wrong number of maps) === + invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map]) + assert invalid_dims is None + + # === Test element-wise operation === + dim_i = AffineDimExpr.get(0) + dim_j = AffineDimExpr.get(1) + elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j]) + elementwise_dims = linalg.infer_contraction_dimensions_from_maps( + [elementwise_map, elementwise_map, elementwise_map] + ) + assert elementwise_dims is not None + assert len(elementwise_dims.m) == 0 + assert len(elementwise_dims.n) == 0 + assert len(elementwise_dims.k) == 0 + assert list(elementwise_dims.batch) == [0, 1] diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 8ea0fdd..305ed9a 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -98,6 +98,9 @@ def testStructType(): assert opaque.opaque # CHECK: !llvm.struct<"opaque", opaque> + typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>') + assert isinstance(typ, llvm.StructType) + # CHECK-LABEL: testSmoke @constructAndPrintInModule @@ -120,6 +123,9 @@ def testPointerType(): # CHECK: !llvm.ptr<1> print(ptr_with_addr) + typ = Type.parse("!llvm.ptr<1>") + assert isinstance(typ, llvm.PointerType) + # CHECK-LABEL: testConstant @constructAndPrintInModule diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index 3eb62be..d795524 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -15,7 +15,9 @@ def constructAndPrintInModule(f): module = Module.create() with InsertionPoint(module.body): f() + print(module) + module.operation.verify() return f @@ -89,3 +91,133 @@ def test_inline_ptx(): arith.addf(a, b) arith.addi(c, d) arith.addf(wo0, wo1) + + +@constructAndPrintInModule +def test_barriers(): + i32 = T.i32() + f32 = T.f32() + + @func.FuncOp.from_py_func(i32, i32, f32) + def barriers(mask, vi32, vf32): + c0 = arith.constant(T.i32(), 0) + cffff = arith.constant(T.i32(), 0xFFFF) + res = nvvm.barrier( + res=i32, + barrier_id=c0, + number_of_threads=cffff, + ) + + for reduction in ( + nvvm.BarrierReduction.AND, + nvvm.BarrierReduction.OR, + nvvm.BarrierReduction.POPC, + ): + res = nvvm.barrier( + res=i32, + reduction_op=reduction, + reduction_predicate=res, + ) + + nvvm.barrier0() + nvvm.bar_warp_sync(mask) + nvvm.cluster_arrive() + nvvm.cluster_arrive(aligned=True) + nvvm.cluster_arrive_relaxed() + nvvm.cluster_arrive_relaxed(aligned=True) + nvvm.cluster_wait() + nvvm.cluster_wait(aligned=True) + nvvm.fence_mbarrier_init() + nvvm.bar_warp_sync(mask) + return res + + +# CHECK-LABEL: func.func @barriers( +# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32 +# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32 +# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32 +# CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32 +# CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32 +# CHECK: nvvm.barrier0 +# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32 
+# CHECK: nvvm.cluster.arrive +# CHECK: nvvm.cluster.arrive {aligned} +# CHECK: nvvm.cluster.arrive.relaxed +# CHECK: nvvm.cluster.arrive.relaxed {aligned} +# CHECK: nvvm.cluster.wait +# CHECK: nvvm.cluster.wait {aligned} +# CHECK: nvvm.fence.mbarrier.init +# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32 +# CHECK: return %[[BARRIER_3]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def test_reductions(): + i32 = T.i32() + f32 = T.f32() + + @func.FuncOp.from_py_func(i32, i32, f32) + def reductions(mask, vi32, vf32): + for abs in (True, False): + for nan in (True, False): + for kind in ( + nvvm.ReduxKind.AND, + nvvm.ReduxKind.MAX, + nvvm.ReduxKind.MIN, + nvvm.ReduxKind.OR, + nvvm.ReduxKind.UMAX, + nvvm.ReduxKind.UMIN, + nvvm.ReduxKind.XOR, + ): + nvvm.redux_sync(i32, vi32, kind, vi32) + + for kind in ( + nvvm.ReduxKind.FMIN, + nvvm.ReduxKind.FMAX, + ): + nvvm.redux_sync(f32, vf32, kind, vi32, abs=abs, nan=nan) + + +# CHECK-LABEL: func.func @reductions( +# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) { +# CHECK: %[[REDUX_0:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_1:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_2:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_3:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_4:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_5:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_6:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_7:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32 +# CHECK: %[[REDUX_8:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32 +# CHECK: %[[REDUX_9:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_10:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_11:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_12:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_13:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_14:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_15:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_16:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32 +# CHECK: %[[REDUX_17:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32 +# CHECK: %[[REDUX_18:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_19:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_20:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_21:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_22:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_23:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_24:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_25:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32 +# CHECK: %[[REDUX_26:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32 +# CHECK: %[[REDUX_27:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_28:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: 
%[[REDUX_29:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_30:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_31:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_32:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_33:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_34:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] : f32 -> f32 +# CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32 +# CHECK: return +# CHECK: } diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 1194e32..f0f74eb 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -554,7 +554,7 @@ def testOptionalOperandOp(): ) assert ( typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"] - is OpResult + == OpResult[IntegerType] ) assert type(op1.result) is OpResult @@ -663,6 +663,13 @@ def testCustomType(): @run +# CHECK-LABEL: TEST: testValue +def testValue(): + # Check that Value is a generic class at runtime. + assert hasattr(Value, "__class_getitem__") + + +@run # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py index a4a50af..c73a536 100644 --- a/mlir/test/python/dialects/rocdl.py +++ b/mlir/test/python/dialects/rocdl.py @@ -29,13 +29,12 @@ def testSmoke(): a_frag = arith.constant(v16f32, f32_array) b_frag = arith.constant(v16f32, f32_array) c_frag = arith.constant(v16f32, f32_array) - false = arith.constant(T.bool(), False) - c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false]) - # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16 + c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, a_frag, b_frag, c_frag, opsel=False) + # CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16" print(c_frag) assert isinstance(c_frag, OpView) - # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16 - c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false]) + # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16" + c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False) print(c_frag) assert isinstance(c_frag, Value) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 62d11d5..0c0c9b9 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -1,10 +1,14 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -from mlir.dialects import arith -from mlir.dialects import func -from mlir.dialects import memref -from mlir.dialects import scf +from mlir.extras import types as T +from mlir.dialects import ( + arith, + func, + memref, + scf, + cf, +) from mlir.passmanager import PassManager @@ -355,3 +359,117 @@ def testIfWithElse(): # CHECK: scf.yield %[[TWO]], %[[THREE]] # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1 # CHECK: return + + +@constructAndPrintInModule +def testIndexSwitch(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + switch_op = scf.IndexSwitchOp([i32], index, range(3)) + + assert switch_op.regions[0] == switch_op.default_region + assert switch_op.regions[1] == switch_op.case_regions[0] + assert switch_op.regions[1] == switch_op.case_region(0) + assert 
len(switch_op.case_regions) == 3 + assert len(switch_op.regions) == 4 + + with InsertionPoint(switch_op.default_block): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + for i, block in enumerate(switch_op.case_blocks): + with InsertionPoint(block): + scf.yield_([arith.constant(i32, i)]) + + func.return_([switch_op.results[0]]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def testIndexSwitchWithBodyBuilders(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + + def default_body_builder(switch_op): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + def case_body_builder(switch_op, case_index: int, case_value: int): + scf.yield_([arith.constant(i32, case_value)]) + + result = scf.index_switch( + results=[i32], + arg=index, + cases=range(3), + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + ) + + func.return_([result]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" 
+# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 6c5e4e5..f58442d 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -51,6 +51,26 @@ def testSequenceOp(module: Module): transform.AnyOpType.get(), ) with InsertionPoint(sequence.body): + res = transform.CastOp(transform.AnyOpType.get(), sequence.bodyTarget) + res2 = transform.cast(transform.any_op_t(), res.result) + transform.YieldOp([res2]) + # CHECK-LABEL: TEST: testSequenceOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op + # CHECK: %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op + # CHECK: yield %[[RES2]] : !transform.any_op + # CHECK: } + + +@run +def testSequenceOp(module: Module): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): transform.YieldOp([sequence.bodyTarget]) # CHECK-LABEL: TEST: testSequenceOp # CHECK: = transform.sequence -> !transform.any_op failures(propagate) { @@ -58,6 +78,7 @@ def testSequenceOp(module: Module): # CHECK: yield %[[ARG0]] : !transform.any_op # CHECK: } + @run def testNestedSequenceOp(module: Module): sequence = transform.SequenceOp( @@ -103,55 +124,65 @@ def testSequenceOpWithExtras(module: Module): # CHECK-LABEL: TEST: testSequenceOpWithExtras # CHECK: transform.sequence failures(propagate) # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): + sequence = transform.sequence( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], + ) + with InsertionPoint(sequence.body): + transform.yield_() + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): @run def testNestedSequenceOpWithExtras(module: Module): - sequence = transform.SequenceOp( + sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], ) - with InsertionPoint(sequence.body): - nested = transform.SequenceOp( + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget, sequence.bodyExtraArgs, ) - with InsertionPoint(nested.body): - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras - # CHECK: transform.sequence failures(propagate) - # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): - # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, 
!transform.op<"foo.bar">) @run def testTransformPDLOps(module: Module): - withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(withPdl.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [transform.AnyOpType.get()], - withPdl.bodyTarget, - ) - with InsertionPoint(sequence.body): - match = transform_pdl.PDLMatchOp( - transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" - ) - transform.YieldOp(match) - # CHECK-LABEL: TEST: testTransformPDLOps - # CHECK: transform.with_pdl_patterns { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] - # CHECK: yield %[[RES]] : !transform.any_op - # CHECK: } - # CHECK: } + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(withPdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + withPdl.bodyTarget, + ) + with InsertionPoint(sequence.body): + match = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" + ) + transform.YieldOp(match) + # CHECK-LABEL: TEST: testTransformPDLOps + # CHECK: transform.with_pdl_patterns { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] + # CHECK: yield %[[RES]] : !transform.any_op + # CHECK: } + # CHECK: } @run @@ -161,32 +192,53 @@ def testNamedSequenceOp(module: Module): "__transform_main", [transform.AnyOpType.get()], [transform.AnyOpType.get()], - arg_attrs = [{"transform.consumed": UnitAttr.get()}]) + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) with InsertionPoint(named_sequence.body): transform.YieldOp([named_sequence.bodyTarget]) # CHECK-LABEL: TEST: testNamedSequenceOp # CHECK: module attributes {transform.with_named_sequence} { - # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { - # CHECK: yield %[[ARG0]] : !transform.any_op + # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG0]] : !transform.any_op + named_sequence = transform.named_sequence( + "other_seq", + [transform.AnyOpType.get()], + [transform.AnyOpType.get()], + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) + with InsertionPoint(named_sequence.body): + transform.yield_([named_sequence.bodyTarget]) + # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG1]] : !transform.any_op @run def testGetParentOp(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - transform.GetParentOp( - transform.AnyOpType.get(), - sequence.bodyTarget, - isolated_from_above=True, - nth_parent=2, + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) - transform.YieldOp() - # CHECK-LABEL: TEST: testGetParentOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: 
!transform.any_op): - # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + with InsertionPoint(sequence.body): + transform.GetParentOp( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + ) + transform.get_parent_op( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + allow_empty_results=True, + op_name="func.func", + deduplicate=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testGetParentOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"} @run @@ -195,43 +247,58 @@ def testMergeHandlesOp(module: Module): transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): - transform.MergeHandlesOp([sequence.bodyTarget]) + res = transform.MergeHandlesOp([sequence.bodyTarget]) + transform.merge_handles([res.result], deduplicate=True) transform.YieldOp() # CHECK-LABEL: TEST: testMergeHandlesOp # CHECK: transform.sequence # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = merge_handles %[[ARG1]] + # CHECK: %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op + # CHECK: = merge_handles deduplicate %[[RES1]] : !transform.any_op @run def testApplyPatternsOpCompact(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOpCompact - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.any_op + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + with InsertionPoint( + transform.apply_patterns( + sequence.bodyTarget, + apply_cse=True, + max_iterations=3, + max_num_rewrites=5, + ).patterns + ): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOpCompact + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } : !transform.any_op + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op @run def testApplyPatternsOpWithType(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get('test.dummy') - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOp - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.op<"test.dummy"> + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("test.dummy"), + ) + with InsertionPoint(sequence.body): 
+ with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOp + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: !transform.op<"test.dummy"> @run @@ -249,11 +316,13 @@ def testReplicateOp(module: Module): transform.AnyOpType.get(), sequence.bodyTarget, "second" ) transform.ReplicateOp(m1, [m2]) + transform.replicate(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp # CHECK: %[[FIRST:.+]] = pdl_match # CHECK: %[[SECOND:.+]] = pdl_match # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] + # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] # CHECK-LABEL: TEST: testApplyRegisteredPassOp diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py index 819a3be..ca9ce5d 100644 --- a/mlir/test/python/dialects/transform_interpreter.py +++ b/mlir/test/python/dialects/transform_interpreter.py @@ -32,6 +32,20 @@ def print_self(): @test_in_context +def print_self_via_apply_method(): + m = ir.Module.parse( + print_root_module.replace("from interpreter", "print_self_via_apply_method") + ) + m.body.operations[0].apply(m) + + +# CHECK-LABEL: print_self_via_apply_method +# CHECK: transform.named_sequence @__transform_main +# CHECK: transform.print +# CHECK: transform.yield + + +@test_in_context def print_other(): transform = ir.Module.parse( print_root_module.replace("from interpreter", "print_other") diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index d6b70dc..e58b764 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -627,12 +627,16 @@ def testVectorizeChildrenAndApplyPatternsAllAttrs(target): disable_transfer_permutation_map_lowering_patterns=True, vectorize_nd_extract=True, vectorize_padding=True, + flatten_1d_depthwise_conv=True, + fold_type_extensions_into_contract=True, ) # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsAllAttrs # CHECK: transform.sequence # CHECK: = transform.structured.vectorize # CHECK-SAME: disable_multi_reduction_to_contract_patterns # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns + # CHECK-SAME: flatten_1d_depthwise_conv + # CHECK-SAME: fold_type_extensions_into_contract # CHECK-SAME: vectorize_nd_extract # CHECK-SAME: vectorize_padding @@ -646,12 +650,16 @@ def testVectorizeChildrenAndApplyPatternsNoAttrs(target): disable_transfer_permutation_map_lowering_patterns=False, vectorize_nd_extract=False, vectorize_padding=False, + flatten_1d_depthwise_conv=False, + fold_type_extensions_into_contract=False, ) # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsNoAttrs # CHECK: transform.sequence # CHECK: = transform.structured.vectorize # CHECK-NOT: disable_multi_reduction_to_contract_patterns # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns + # CHECK-NOT: flatten_1d_depthwise_conv + # CHECK-NOT: fold_type_extensions_into_contract # CHECK-NOT: vectorize_nd_extract # CHECK-NOT: vectorize_padding diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py new file mode 100644 index 0000000..2b11acb0 --- /dev/null +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -0,0 +1,296 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform 
+from mlir.dialects.transform import xegpu +from mlir.dialects.transform import structured, AnyValueType + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def getDescOpDefaultIndex(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + desc_handle = xegpu.get_desc_op(operand) + transform.YieldOp() + # CHECK-LABEL: TEST: getDescOpDefaultIndex + # CHECK: transform.xegpu.get_desc_op % + + +@run +def setDescLayoutMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16]) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutMinimal + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + + +@run +def setDescLayoutInstData(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.set_desc_layout( + sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutInstData + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def setDescLayoutSlice(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.set_desc_layout( + sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutSlice + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: slice_dims = [0] + + +@run +def setOpLayoutAttrOperandMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.set_op_layout_attr( + sequence.bodyTarget, + sg_layout=[6, 4], + sg_data=[32, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOpLayoutAttr + # CHECK: transform.xegpu.set_op_layout_attr % + # NO-CHECK: index = 0 + # NO-CHECK: result + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # NO-CHECK: inst_data + + +@run +def setOpLayoutAttrResult(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.set_op_layout_attr( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16], + result=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOpLayoutAttrResult + # CHECK: transform.xegpu.set_op_layout_attr % + # NO-CHECK: index = 0 + # CHECK: result + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def setOpLayoutAttrResultSlice(): + sequence = transform.SequenceOp( + 
transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.set_op_layout_attr( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16], + slice_dims=[0], + result=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice + # CHECK: transform.xegpu.set_op_layout_attr % + # NO-CHECK: index = 0 + # CHECK: result + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + # CHECK: slice_dims = [0] + + +@run +def setGPULaunchThreadsOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("gpu.launch"), + ) + with InsertionPoint(sequence.body): + xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1]) + transform.YieldOp() + # CHECK-LABEL: TEST: setGPULaunchThreadsOp + # CHECK: transform.xegpu.set_gpu_launch_threads + # CHECK: threads = [8, 4, 1] + + +@run +def insertPrefetch0(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.insert_prefetch( + operand, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetch0 + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + + +@run +def insertPrefetchNbPrefetch(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.insert_prefetch( + operand, + nb_prefetch=2, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetch + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = 2 + + +@run +def insertPrefetchNbPrefetchParam(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + int32_t = IntegerType.get_signless(32) + param_int32_t = transform.ParamType.get(int32_t) + nb_param = transform.ParamConstantOp( + param_int32_t, + IntegerAttr.get(int32_t, 2), + ) + xegpu.insert_prefetch( + operand, + nb_prefetch=nb_param, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2 + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = %[[PARAM_OP]] + + +@run +def ConvertLayoutMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.convert_layout( + operand, + input_sg_layout=[6, 4], + input_sg_data=[32, 16], + target_sg_layout=[6, 4], + target_sg_data=[8, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayoutMinimal + # CHECK: transform.xegpu.convert_layout % + # CHECK: input_sg_layout = [6, 4] + # CHECK: input_sg_data = [32, 16] + # CHECK: target_sg_layout = [6, 4] + # CHECK: target_sg_data = [8, 16] + + +@run 
+def ConvertLayout(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1]) + xegpu.convert_layout( + operand, + input_sg_layout=[6, 4], + input_sg_data=[32, 32], + input_inst_data=[32, 16], + target_sg_layout=[6, 4], + target_sg_data=[32, 32], + target_inst_data=[8, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayout + # CHECK: transform.xegpu.convert_layout % + # CHECK: input_sg_layout = [6, 4] + # CHECK: input_sg_data = [32, 32] + # CHECK: input_inst_data = [32, 16] + # CHECK: target_sg_layout = [6, 4] + # CHECK: target_sg_data = [32, 32] + # CHECK: target_inst_data = [8, 16] diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 146e213a..b11340f 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -71,6 +71,7 @@ def testInvalidModule(): func.func @foo() { return } """ ) + # CHECK: error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: func.func # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine. try: execution_engine = ExecutionEngine(module) @@ -806,6 +807,7 @@ def testDumpToObjectFile(): # because RTDyldObjectLinkingLayer::emit will try to resolve symbols before dumping # (see the jitLinkForORC call at the bottom there). shared_libs=[MLIR_C_RUNNER_UTILS], + enable_pic=True, ) # CHECK: Object file exists: True diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index 8f20231..8eff573 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -25,13 +25,13 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} { %O1 = memref.alloc() : memref<16xi32> %O2 = memref.alloc() : memref<4x16xi32> - %val0 = arith.constant 1.0 : f32 - %val1 = arith.constant 2.0 : f32 - %val2 = arith.constant 3.0 : f32 + %val0 = arith.constant 1 : i32 + %val1 = arith.constant 2 : i32 + %val2 = arith.constant 3 : i32 - call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> () - call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> () - call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> () + call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> () + call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> () + call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> () %c0 = arith.constant 0 : index %res0 = memref.load %O0[] : memref<i32> @@ -149,19 +149,18 @@ def transform(module, boilerplate): def test_fill_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() - f32 = F32Type.get() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): linalg.fill(value, outs=[out]) @@ -184,19 
+183,18 @@ test_fill_builtin() def test_fill_generic(): with Context() as ctx, Location.unknown(): module = Module.create() - f32 = F32Type.get() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py index 8316890..1747c66a 100644 --- a/mlir/test/python/ir/auto_location.py +++ b/mlir/test/python/ir/auto_location.py @@ -15,17 +15,10 @@ def run(f): assert Context._get_live_count() == 0 -@contextmanager -def with_infer_location(): - _cext.globals.set_loc_tracebacks_enabled(True) - yield - _cext.globals.set_loc_tracebacks_enabled(False) - - # CHECK-LABEL: TEST: testInferLocations @run def testInferLocations(): - with Context() as ctx, with_infer_location(): + with Context() as ctx, loc_tracebacks(): ctx.allow_unregistered_dialects = True op = Operation.create("custom.op1") @@ -34,24 +27,26 @@ def testInferLocations(): two = arith.constant(IndexType.get(), 2) # fmt: off - # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":31:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))) + # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))) # fmt: on print(op.location) - # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) - # fmt: on - print(one.location) + # Test nesting of loc_tracebacks(). 
+ with loc_tracebacks(): + # fmt: off + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) + # fmt: on + print(one.location) # fmt: off - # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":34:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))) + # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))) # fmt: on print(two.location) _cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__) three = arith.constant(IndexType.get(), 3) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))) # fmt: on print(three.location) @@ -60,7 +55,7 @@ def testInferLocations(): print(four.location) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) # fmt: on 
foo() @@ -86,13 +81,13 @@ def testInferLocations(): _cext.globals.set_loc_tracebacks_frame_limit(2) # fmt: off - # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":80:16 to :22))) + # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :22))) # fmt: on bar1() _cext.globals.set_loc_tracebacks_frame_limit(1) # fmt: off - # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61)) + # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61)) # fmt: on bar1() diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py index ced5fce..e876c00 100644 --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -191,3 +191,18 @@ def testBlockEraseArgs(): blocks[0].erase_argument(0) # CHECK: ^bb0: op.print(enable_debug_info=True) + + +# CHECK-LABEL: TEST: testBlockArgSetLocation +# CHECK: ^bb0(%{{.+}}: f32 loc("new_loc")): +@run +def testBlockArgSetLocation(): + with Context() as ctx, Location.unknown(ctx) as loc: + ctx.allow_unregistered_dialects = True + f32 = F32Type.get() + op = Operation.create("test", regions=1, loc=Location.unknown()) + blocks = op.regions[0].blocks + blocks.append(f32) + arg = blocks[0].arguments[0] + arg.set_location(Location.name("new_loc")) + op.print(enable_debug_info=True) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index f5fa4da..d124c28 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -2,12 +2,12 @@ import gc import io -import itertools from tempfile import NamedTemporaryFile from mlir.ir import * from mlir.dialects.builtin import ModuleOp -from mlir.dialects import arith +from mlir.dialects import arith, func, scf, shape from mlir.dialects._ods_common import _cext +from mlir.extras import types as T def run(f): @@ -43,6 +43,10 @@ def testTraverseOpRegionBlockIterators(): ) op = module.operation assert op.context is ctx + # Note, __nb_signature__ stores the fully-qualified signature - the actual type stub emitted is + # class RegionSequence(Sequence[Region]) + # CHECK: class RegionSequence(collections.abc.Sequence[mlir._mlir_libs._mlir.ir.Region]) + print(RegionSequence.__nb_signature__) # Get the block using iterators off of the named collections. 
regions = list(op.regions[:]) blocks = list(regions[0].blocks) @@ -774,6 +778,21 @@ def testKnownOpView(): print(repr(constant)) +# CHECK-LABEL: TEST: testFailedGenericOperationCreationReportsError +@run +def testFailedGenericOperationCreationReportsError(): + with Context(), Location.unknown(): + c0 = shape.const_shape([]) + c1 = shape.const_shape([1, 2, 3]) + try: + shape.MeetOp.build_generic(operands=[c0, c1]) + except MLIRError as e: + # CHECK: unequal shape cardinality + print(e) + else: + assert False, "Expected exception" + + # CHECK-LABEL: TEST: testSingleResultProperty @run def testSingleResultProperty(): @@ -1199,3 +1218,25 @@ def testGetOwnerConcreteOpview(): r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw) for u in a.result.uses: assert isinstance(u.owner, arith.AddIOp) + + +# CHECK-LABEL: TEST: testIndexSwitch +@run +def testIndexSwitch(): + with Context() as ctx, Location.unknown(): + i32 = T.i32() + module = Module.create() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func(T.index()) + def index_switch(index): + c1 = arith.constant(i32, 1) + switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3)) + + assert len(switch_op.regions) == 4 + assert len(switch_op.regions[2:]) == 2 + assert len([i for i in switch_op.regions[2:]]) == 2 + assert len(switch_op.caseRegions) == 3 + assert len([i for i in switch_op.caseRegions]) == 3 + assert len(switch_op.caseRegions[1:]) == 2 + assert len([i for i in switch_op.caseRegions[1:]]) == 2 |
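The transform-dialect Python tests above exercise the snake_case wrappers (get_parent_op, merge_handles, apply_patterns, replicate) alongside the existing *Op classes. A minimal sketch of the wrapper style, assuming the in-tree mlir Python bindings and reusing the same keyword options the tests pass:

from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
        )
        with InsertionPoint(sequence.body):
            # Wrapper form of GetParentOp, including the newer keyword options.
            parent = transform.get_parent_op(
                transform.AnyOpType.get(),
                sequence.bodyTarget,
                isolated_from_above=True,
                nth_parent=2,
                allow_empty_results=True,
                op_name="func.func",
                deduplicate=True,
            )
            # merge_handles takes a list of handles plus an optional deduplicate flag.
            transform.merge_handles([parent], deduplicate=True)
            transform.YieldOp()
    print(module)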
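print_self_via_apply_method drives the interpreter through the new .apply(...) convenience on the parsed named sequence instead of calling the interpreter module directly. A sketch under the assumption that .apply takes the payload op as its only required argument (as in the test) and that importing the transform dialect and interpreter modules is enough to make it available; the inline transform module is illustrative:

from mlir import ir
from mlir.dialects import transform  # loads the transform dialect Python bindings
from mlir.dialects.transform import interpreter  # assumed; may not be strictly required

TRANSFORM_MODULE = r"""
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root {name = "payload"}
    transform.yield
  }
}
"""

with ir.Context(), ir.Location.unknown():
    m = ir.Module.parse(TRANSFORM_MODULE)
    # The only top-level op is the named sequence; apply it to a payload --
    # here, the transform module itself, as in the test.
    m.body.operations[0].apply(m)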
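The new transform_xegpu_ext.py covers the transform.xegpu helpers. A small sketch combining get_desc_op and insert_prefetch on one xegpu.dpas-typed sequence, following the call shapes above; the operand index and prefetch count are placeholders:

from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform
from mlir.dialects.transform import AnyValueType, xegpu

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            transform.OperationType.get("xegpu.dpas"),
        )
        with InsertionPoint(sequence.body):
            # Handle to operand #0 of the matched xegpu.dpas op.
            operand = transform.GetOperandOp(
                AnyValueType.get(), sequence.bodyTarget, [0]
            )
            # Resolve the descriptor-creating op that feeds this operand ...
            xegpu.get_desc_op(operand)
            # ... and request two prefetch stages on the same operand.
            xegpu.insert_prefetch(operand, nb_prefetch=2)
            transform.YieldOp()
    print(module)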
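execution_engine.py now passes enable_pic=True when dumping the JIT'd module to an object file, matching the new enablePIC parameter on mlirExecutionEngineCreate. A sketch of that flow with a hand-written LLVM-dialect module; the function body and output path are illustrative, and the keyword names follow the Python test above:

import os
import tempfile

from mlir.ir import Context, Location, Module
from mlir.execution_engine import ExecutionEngine

LLVM_MODULE = r"""
llvm.func @add(%a: i64, %b: i64) -> i64 {
  %0 = llvm.add %a, %b : i64
  llvm.return %0 : i64
}
"""

with Context(), Location.unknown():
    module = Module.parse(LLVM_MODULE)
    # Keep the compiled object around and build it as position-independent code.
    engine = ExecutionEngine(
        module, opt_level=2, enable_object_dump=True, enable_pic=True
    )
    object_path = os.path.join(tempfile.gettempdir(), "add_example.o")
    engine.dump_to_object_file(object_path)
    print("Object file exists:", os.path.exists(object_path))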
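opsrun.py now feeds linalg.fill an i32 constant so the fill value type matches the i32 buffers. A minimal builder-side sketch of the same pattern (no lowering or execution here):

from mlir.ir import Context, InsertionPoint, IntegerType, Location, MemRefType, Module
from mlir.dialects import func, linalg

with Context(), Location.unknown():
    module = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
        def fill_1d_on_buffers(value, out):
            # The scalar fill value and the memref element type are both i32 now.
            linalg.fill(value, outs=[out])

    print(module)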
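auto_location.py switches to a reusable loc_tracebacks() context manager and loosens the line numbers in its CHECK patterns. The sketch below scopes the same feature with the raw _cext.globals toggles that the removed in-test helper used, since the loc_tracebacks import lies outside the hunks shown here; the frame limit mirrors the set_loc_tracebacks_frame_limit(2) call above:

from contextlib import contextmanager

from mlir.ir import Context, Operation
from mlir.dialects._ods_common import _cext


@contextmanager
def traceback_locations(frame_limit=None):
    # Enable Python-traceback-derived locations for ops created in this scope.
    _cext.globals.set_loc_tracebacks_enabled(True)
    if frame_limit is not None:
        # Cap how many Python frames end up in the callsite chain.
        _cext.globals.set_loc_tracebacks_frame_limit(frame_limit)
    try:
        yield
    finally:
        _cext.globals.set_loc_tracebacks_enabled(False)


with Context() as ctx, traceback_locations(frame_limit=2):
    ctx.allow_unregistered_dialects = True
    op = Operation.create("custom.op1")
    # Prints a loc(callsite(...)) chain built from the surrounding Python frames.
    print(op.location)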
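testBlockArgSetLocation exercises the new BlockArgument.set_location(...) binding. A compact sketch of the same flow; the op name and location string are illustrative:

from mlir.ir import Context, F32Type, Location, Operation

with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    f32 = F32Type.get()
    op = Operation.create("test.op", regions=1)
    blocks = op.regions[0].blocks
    blocks.append(f32)
    # Re-tag the entry block's only argument with a named location.
    blocks[0].arguments[0].set_location(Location.name("new_loc"))
    # With debug info enabled, the block header shows loc("new_loc") on the argument.
    op.print(enable_debug_info=True)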
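testIndexSwitch builds scf.index_switch from Python and checks that the region accessors and their slices behave like regular sequences. A sketch of the construction side; the case count is arbitrary:

from mlir.ir import Context, IndexType, InsertionPoint, IntegerType, Location, Module
from mlir.dialects import func, scf

with Context(), Location.unknown():
    module = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(IndexType.get())
        def index_switch(index):
            # Three cases plus the default region -> four regions in total.
            switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3))
            assert len(switch_op.regions) == 4
            assert len(switch_op.caseRegions) == 3
            # Region sequences support slicing, as the test above verifies.
            assert len(switch_op.caseRegions[1:]) == 2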
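testFailedGenericOperationCreationReportsError checks that a failed build_generic surfaces the verifier diagnostic as an MLIRError instead of an opaque failure. The same probe, standalone:

from mlir.ir import Context, Location, MLIRError
from mlir.dialects import shape

with Context(), Location.unknown():
    # Two shape constants with different ranks; shape.meet requires them to agree.
    c0 = shape.const_shape([])
    c1 = shape.const_shape([1, 2, 3])
    try:
        shape.MeetOp.build_generic(operands=[c0, c1])
    except MLIRError as e:
        # The diagnostic reports the unequal shape cardinality.
        print(e)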
