Diffstat (limited to 'mlir/test')
-rw-r--r--  mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir | 14
-rw-r--r--  mlir/test/CAPI/execution_engine.c | 4
-rw-r--r--  mlir/test/CAPI/global_constructors.c | 2
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 1
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir | 445
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir | 8
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir | 9
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir | 11
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir | 24
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir | 20
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | 22
-rw-r--r--  mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir | 329
-rw-r--r--  mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 48
-rw-r--r--  mlir/test/Conversion/ConvertToSPIRV/vector.mlir | 36
-rw-r--r--  mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir | 6
-rw-r--r--  mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir | 2
-rw-r--r--  mlir/test/Conversion/FuncToLLVM/func-memref.mlir | 1
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 1
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/memref.mlir | 1
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir | 23
-rw-r--r--  mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir | 74
-rw-r--r--  mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir | 37
-rw-r--r--  mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir | 2
-rw-r--r--  mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir | 30
-rw-r--r--  mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 20
-rw-r--r--  mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 43
-rw-r--r--  mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir | 19
-rw-r--r--  mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir | 14
-rw-r--r--  mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir | 1
-rw-r--r--  mlir/test/Conversion/SCFToGPU/parallel_loop.mlir | 48
-rw-r--r--  mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir | 12
-rw-r--r--  mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir | 9
-rw-r--r--  mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir | 6
-rw-r--r--  mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir | 19
-rw-r--r--  mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir | 39
-rw-r--r--  mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir | 39
-rw-r--r--  mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir | 119
-rw-r--r--  mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir | 51
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir | 43
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir | 36
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir | 112
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir | 56
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir | 80
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir | 53
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir | 36
-rw-r--r--  mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir | 21
-rw-r--r--  mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir | 26
-rw-r--r--  mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir | 19
-rw-r--r--  mlir/test/Dialect/AMDGPU/canonicalize.mlir | 36
-rw-r--r--  mlir/test/Dialect/AMDGPU/invalid.mlir | 84
-rw-r--r--  mlir/test/Dialect/AMDGPU/ops.mlir | 176
-rw-r--r--  mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir | 100
-rw-r--r--  mlir/test/Dialect/Affine/loop-coalescing.mlir | 28
-rw-r--r--  mlir/test/Dialect/Affine/value-bounds-reification.mlir | 4
-rw-r--r--  mlir/test/Dialect/Arith/canonicalize.mlir | 13
-rw-r--r--  mlir/test/Dialect/Arith/ops.mlir | 24
-rw-r--r--  mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir | 99
-rw-r--r--  mlir/test/Dialect/Bufferization/invalid.mlir | 60
-rw-r--r--  mlir/test/Dialect/Bufferization/ops.mlir | 37
-rw-r--r--  mlir/test/Dialect/ControlFlow/canonicalize.mlir | 22
-rw-r--r--  mlir/test/Dialect/EmitC/invalid_ops.mlir | 54
-rw-r--r--  mlir/test/Dialect/EmitC/ops.mlir | 10
-rw-r--r--  mlir/test/Dialect/Func/duplicate-function-elimination.mlir | 2
-rw-r--r--  mlir/test/Dialect/GPU/invalid.mlir | 4
-rw-r--r--  mlir/test/Dialect/IRDL/variadics.mlir | 68
-rw-r--r--  mlir/test/Dialect/Index/inliner-interface.mlir | 15
-rw-r--r--  mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir | 19
-rw-r--r--  mlir/test/Dialect/LLVMIR/func.mlir | 4
-rw-r--r--  mlir/test/Dialect/LLVMIR/inlining.mlir | 14
-rw-r--r--  mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir | 14
-rw-r--r--  mlir/test/Dialect/LLVMIR/invalid.mlir | 6
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir | 221
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir | 411
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir | 390
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir | 11
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm.mlir | 28
-rw-r--r--  mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir | 70
-rw-r--r--  mlir/test/Dialect/LLVMIR/rocdl.mlir | 158
-rw-r--r--  mlir/test/Dialect/LLVMIR/roundtrip.mlir | 4
-rw-r--r--  mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir | 214
-rw-r--r--  mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir | 75
-rw-r--r--  mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir | 26
-rw-r--r--  mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir | 8
-rw-r--r--  mlir/test/Dialect/Linalg/invalid.mlir | 18
-rw-r--r--  mlir/test/Dialect/Linalg/reshape_fusion.mlir | 75
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir | 34
-rw-r--r--  mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir | 4
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir | 70
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir | 93
-rw-r--r--  mlir/test/Dialect/MemRef/canonicalize.mlir | 13
-rw-r--r--  mlir/test/Dialect/MemRef/expand-strided-metadata.mlir | 17
-rw-r--r--  mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir | 49
-rw-r--r--  mlir/test/Dialect/MemRef/invalid.mlir | 16
-rw-r--r--  mlir/test/Dialect/MemRef/mem2reg.mlir | 2
-rw-r--r--  mlir/test/Dialect/MemRef/transform-ops.mlir | 84
-rw-r--r--  mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir | 109
-rw-r--r--  mlir/test/Dialect/OpenACC/acc-implicit-data.mlir | 224
-rw-r--r--  mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir | 175
-rw-r--r--  mlir/test/Dialect/OpenACC/canonicalize.mlir | 27
-rw-r--r--  mlir/test/Dialect/OpenACC/invalid.mlir | 126
-rw-r--r--  mlir/test/Dialect/OpenACC/legalize-data.mlir | 24
-rw-r--r--  mlir/test/Dialect/OpenACC/legalize-serial.mlir | 164
-rw-r--r--  mlir/test/Dialect/OpenACC/ops.mlir | 270
-rw-r--r--  mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir | 29
-rw-r--r--  mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir | 39
-rw-r--r--  mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir | 38
-rw-r--r--  mlir/test/Dialect/SCF/canonicalize.mlir | 50
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir | 215
-rw-r--r--  mlir/test/Dialect/Tensor/bufferize.mlir | 40
-rw-r--r--  mlir/test/Dialect/Tosa/availability.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/canonicalize.mlir | 72
-rw-r--r--  mlir/test/Dialect/Tosa/invalid.mlir | 20
-rw-r--r--  mlir/test/Dialect/Tosa/invalid_extension.mlir | 7
-rw-r--r--  mlir/test/Dialect/Tosa/ops.mlir | 74
-rw-r--r--  mlir/test/Dialect/Tosa/quant-test.mlir | 14
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir | 100
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir | 23
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir | 21
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir | 81
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir | 162
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir | 42
-rw-r--r--  mlir/test/Dialect/Tosa/verifier.mlir | 8
-rw-r--r--  mlir/test/Dialect/Transform/include-failure-propagation.mlir | 38
-rw-r--r--  mlir/test/Dialect/Transform/test-pass-application.mlir | 2
-rw-r--r--  mlir/test/Dialect/UB/ops.mlir | 6
-rw-r--r--  mlir/test/Dialect/Vector/bufferize.mlir | 20
-rw-r--r--  mlir/test/Dialect/Vector/invalid.mlir | 4
-rw-r--r--  mlir/test/Dialect/Vector/ops.mlir | 14
-rw-r--r--  mlir/test/Dialect/Vector/vector-scan-transforms.mlir | 94
-rw-r--r--  mlir/test/Dialect/Vector/vector-sink.mlir | 17
-rw-r--r--  mlir/test/Dialect/Vector/vector-unroll-options.mlir | 134
-rw-r--r--  mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir | 344
-rw-r--r--  mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir | 681
-rw-r--r--  mlir/test/Dialect/XeGPU/invalid.mlir | 33
-rw-r--r--  mlir/test/Dialect/XeGPU/ops.mlir | 30
-rw-r--r--  mlir/test/Dialect/XeGPU/optimize-transpose.mlir | 280
-rw-r--r--  mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir | 73
-rw-r--r--  mlir/test/Dialect/XeGPU/propagate-layout.mlir | 167
-rw-r--r--  mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir | 910
-rw-r--r--  mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 126
-rw-r--r--  mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir | 157
-rw-r--r--  mlir/test/Dialect/XeGPU/transform-ops.mlir | 509
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir | 37
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir | 6
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 88
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 96
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 272
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 123
-rw-r--r--  mlir/test/Examples/NVGPU/Ch0.py | 26
-rw-r--r--  mlir/test/Examples/NVGPU/Ch1.py | 43
-rw-r--r--  mlir/test/Examples/NVGPU/Ch2.py | 59
-rw-r--r--  mlir/test/Examples/NVGPU/Ch3.py | 84
-rw-r--r--  mlir/test/Examples/NVGPU/Ch4.py | 166
-rw-r--r--  mlir/test/Examples/NVGPU/Ch5.py | 181
-rw-r--r--  mlir/test/Examples/NVGPU/lit.local.cfg | 2
-rw-r--r--  mlir/test/Examples/NVGPU/tools/nvdsl.py | 35
-rw-r--r--  mlir/test/Examples/NVGPU/tools/nvgpucompiler.py | 4
-rw-r--r--  mlir/test/IR/invalid-ops.mlir | 8
-rw-r--r--  mlir/test/IR/locations.mlir | 7
-rw-r--r--  mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir | 26
-rw-r--r--  mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir | 82
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir | 93
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir | 15
-rw-r--r--  mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir | 9
-rw-r--r--  mlir/test/Integration/Dialect/Transform/match_matmul.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir | 63
-rw-r--r--  mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir | 73
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/assert.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/async.mlir | 7
-rw-r--r--  mlir/test/Integration/GPU/CUDA/command-line-arg.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/dump-ptx.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/dump-sass.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/printf.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/shuffle.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/CUDA/two-modules.mlir | 2
-rw-r--r--  mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir | 16
-rw-r--r--  mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir | 135
-rw-r--r--  mlir/test/Interfaces/TilingInterface/query-fusability.mlir | 70
-rw-r--r--  mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir | 1156
-rw-r--r--  mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir | 380
-rw-r--r--  mlir/test/Pass/invalid-unsupported-operation.mlir | 10
-rw-r--r--  mlir/test/Pass/pipeline-invalid.mlir | 2
-rw-r--r--  mlir/test/Target/Cpp/common-cpp.mlir | 19
-rw-r--r--  mlir/test/Target/Cpp/expressions.mlir | 30
-rw-r--r--  mlir/test/Target/LLVMIR/Import/debug-info-records.ll | 87
-rw-r--r--  mlir/test/Target/LLVMIR/Import/function-attributes.ll | 6
-rw-r--r--  mlir/test/Target/LLVMIR/Import/import-failure.ll | 12
-rw-r--r--  mlir/test/Target/LLVMIR/Import/instructions.ll | 8
-rw-r--r--  mlir/test/Target/LLVMIR/Import/intrinsic.ll | 32
-rw-r--r--  mlir/test/Target/LLVMIR/Import/metadata-profiling.ll | 36
-rw-r--r--  mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir | 99
-rw-r--r--  mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir | 121
-rw-r--r--  mlir/test/Target/LLVMIR/anonymous-tbaa.mlir | 21
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir | 34
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir.mlir | 25
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/barrier.mlir | 27
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir | 87
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir | 118
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir | 89
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/fence.mlir | 85
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir | 47
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir | 68
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir | 68
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir | 103
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir | 103
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir | 29
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir | 29
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir | 56
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir | 138
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir | 73
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir | 147
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/membar.mlir | 14
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir | 43
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir | 64
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir | 49
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir | 22
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir | 9
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir | 229
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir | 229
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir | 119
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir | 442
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir | 229
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir | 229
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir | 442
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir | 634
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir | 633
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir | 133
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir | 133
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir | 133
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir | 133
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir | 11
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir | 8
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir.mlir | 74
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir | 20
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir | 35
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir | 34
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 17
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir | 15
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir | 8
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-nowait.mlir | 27
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir | 65
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir | 18
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir | 4
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir | 105
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir | 12
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir | 2
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir | 2
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir | 14
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-cancel.mlir | 81
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir | 14
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir | 34
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir | 205
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-llvm.mlir | 74
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir | 6
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir | 4
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir | 14
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir | 4
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir | 16
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir | 2
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-target-spmd.mlir | 2
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-todo.mlir | 50
-rw-r--r--  mlir/test/Target/LLVMIR/rocdl.mlir | 436
-rw-r--r--  mlir/test/Target/LLVMIR/target-ext-type.mlir | 6
-rw-r--r--  mlir/test/Target/SPIRV/consecutive-selection.spvasm (renamed from mlir/test/Target/SPIRV/consecutive-selection.spv) | 0
-rw-r--r--  mlir/test/Target/SPIRV/decorations.mlir | 7
-rw-r--r--  mlir/test/Target/SPIRV/group-ops.mlir | 30
-rw-r--r--  mlir/test/Target/SPIRV/loop.mlir | 7
-rw-r--r--  mlir/test/Target/SPIRV/mlir-translate.mlir | 1
-rw-r--r--  mlir/test/Target/SPIRV/module.mlir | 1
-rw-r--r--  mlir/test/Target/SPIRV/phi.mlir | 62
-rw-r--r--  mlir/test/Target/SPIRV/selection.mlir | 151
-rw-r--r--  mlir/test/Target/SPIRV/selection.spvasm (renamed from mlir/test/Target/SPIRV/selection.spv) | 0
-rw-r--r--  mlir/test/Target/SPIRV/selection_switch.spvasm | 69
-rw-r--r--  mlir/test/Target/SPIRV/struct.mlir | 23
-rw-r--r--  mlir/test/Target/SPIRV/subgroup-block-intel.mlir | 34
-rw-r--r--  mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir | 40
-rw-r--r--  mlir/test/Transforms/loop-invariant-code-motion.mlir | 145
-rw-r--r--  mlir/test/Transforms/move-operation-deps.mlir | 182
-rw-r--r--  mlir/test/Transforms/remove-dead-values.mlir | 40
-rw-r--r--  mlir/test/Transforms/test-legalizer-full.mlir | 18
-rw-r--r--  mlir/test/Transforms/test-legalizer-no-materializations.mlir | 67
-rw-r--r--  mlir/test/Transforms/test-legalizer-no-rollback.mlir | 23
-rw-r--r--  mlir/test/Transforms/test-legalizer-rollback.mlir | 19
-rw-r--r--  mlir/test/Transforms/test-legalizer.mlir | 71
-rw-r--r--  mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp | 11
-rw-r--r--  mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp | 121
-rw-r--r--  mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp | 23
-rw-r--r--  mlir/test/lib/Dialect/Test/TestDialect.cpp | 1
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 138
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOps.h | 1
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOps.td | 136
-rw-r--r--  mlir/test/lib/Dialect/Test/TestPatterns.cpp | 36
-rw-r--r--  mlir/test/lib/Dialect/Test/TestTypeDefs.td | 7
-rw-r--r--  mlir/test/lib/Dialect/Test/TestTypes.cpp | 18
-rw-r--r--  mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 28
-rw-r--r--  mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 3
-rw-r--r--  mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp | 182
-rw-r--r--  mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td | 83
-rw-r--r--  mlir/test/lib/Transforms/TestTransformsOps.td | 4
-rw-r--r--  mlir/test/lit.cfg.py | 7
-rw-r--r--  mlir/test/mlir-tblgen/constraint-unique.td | 10
-rw-r--r--  mlir/test/mlir-tblgen/cpp-class-comments.td | 4
-rw-r--r--  mlir/test/mlir-tblgen/dialect-interface.td | 65
-rw-r--r--  mlir/test/mlir-tblgen/op-attribute.td | 16
-rw-r--r--  mlir/test/mlir-tblgen/op-decl-and-defs.td | 11
-rw-r--r--  mlir/test/mlir-tblgen/op-properties-predicates.td | 2
-rw-r--r--  mlir/test/mlir-tblgen/op-properties.td | 14
-rw-r--r--  mlir/test/mlir-tblgen/op-python-bindings.td | 14
-rw-r--r--  mlir/test/mlir-tblgen/predicate.td | 16
-rw-r--r--  mlir/test/python/CMakeLists.txt | 2
-rw-r--r--  mlir/test/python/dialects/gpu/dialect.py | 12
-rw-r--r--  mlir/test/python/dialects/linalg/ops.py | 76
-rw-r--r--  mlir/test/python/dialects/linalg/utils.py | 40
-rw-r--r--  mlir/test/python/dialects/llvm.py | 6
-rw-r--r--  mlir/test/python/dialects/nvvm.py | 132
-rw-r--r--  mlir/test/python/dialects/python_test.py | 9
-rw-r--r--  mlir/test/python/dialects/rocdl.py | 9
-rw-r--r--  mlir/test/python/dialects/scf.py | 126
-rw-r--r--  mlir/test/python/dialects/transform.py | 215
-rw-r--r--  mlir/test/python/dialects/transform_interpreter.py | 14
-rw-r--r--  mlir/test/python/dialects/transform_structured_ext.py | 8
-rw-r--r--  mlir/test/python/dialects/transform_xegpu_ext.py | 296
-rw-r--r--  mlir/test/python/execution_engine.py | 2
-rw-r--r--  mlir/test/python/integration/dialects/linalg/opsrun.py | 26
-rw-r--r--  mlir/test/python/ir/auto_location.py | 31
-rw-r--r--  mlir/test/python/ir/blocks.py | 15
-rw-r--r--  mlir/test/python/ir/operation.py | 45
344 files changed, 21909 insertions, 2521 deletions
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index 3748be7..768f1cf 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -184,6 +184,18 @@ func.func private @private0(%0 : i32) -> i32 {
// CHECK-NEXT: result #0: live
// CHECK-LABEL: test_tag: y:
// CHECK-NEXT: result #0: not live
+// CHECK-LABEL: test_tag: for:
+// CHECK-NEXT: operand #0: live
+// CHECK-NEXT: operand #1: live
+// CHECK-NEXT: operand #2: live
+// CHECK-NEXT: operand #3: live
+// CHECK-NEXT: operand #4: not live
+// CHECK-NEXT: result #0: live
+// CHECK-NEXT: result #1: not live
+// CHECK-NEXT: region: #0:
+// CHECK-NEXT: argument: #0: live
+// CHECK-NEXT: argument: #1: not live
+// CHECK-NEXT: argument: #2: not live
func.func @test_7_type_3(%arg0: memref<i32>) {
%c0 = arith.constant {tag = "zero"} 0 : index
%c10 = arith.constant {tag = "ten"} 10 : index
@@ -194,7 +206,7 @@ func.func @test_7_type_3(%arg0: memref<i32>) {
%1 = arith.addi %x, %x : i32
%2 = func.call @private0(%1) : (i32) -> i32
scf.yield %2, %arg3 : i32, i32
- }
+ } {tag = "for"}
memref.store %0#0, %arg0[] : memref<i32>
return
}
diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c
index 4751288..4df232f 100644
--- a/mlir/test/CAPI/execution_engine.c
+++ b/mlir/test/CAPI/execution_engine.c
@@ -69,7 +69,7 @@ void testSimpleExecution(void) {
mlirRegisterAllLLVMTranslations(ctx);
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
- /*enableObjectDump=*/false);
+ /*enableObjectDump=*/false, /*enablePIC=*/false);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
@@ -125,7 +125,7 @@ void testOmpCreation(void) {
// against the OpenMP library.
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
- /*enableObjectDump=*/false);
+ /*enableObjectDump=*/false, /*enablePIC=*/false);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Engine creation failed with OpenMP");
exit(2);
diff --git a/mlir/test/CAPI/global_constructors.c b/mlir/test/CAPI/global_constructors.c
index bd2fe14..9aacaf2 100644
--- a/mlir/test/CAPI/global_constructors.c
+++ b/mlir/test/CAPI/global_constructors.c
@@ -79,7 +79,7 @@ void testGlobalCtorJitCallback(void) {
// Create execution engine with initialization disabled
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
- /*enableObjectDump=*/false);
+ /*enableObjectDump=*/false, /*enablePIC=*/false);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 2fd3df6d..432b887 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -456,3 +456,4 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}
+
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
new file mode 100644
index 0000000..a94e17a
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -0,0 +1,445 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: @scaled_ext_packed_matrix_fp4
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed_matrix_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed_matrix_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+ // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+ func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+
+// CHECK-LABEL: @scaled_ext_packed_matrix_fp6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed_matrix_bf6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+ // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+ // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+ // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+ // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (16, 2) for f8.}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op failed to verify that all of {source, res} have same shape}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op operand #0 must be}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ return %ret0: vector<16xf16>
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed_matrix_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) {
+ // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op result #0 must be vector}}
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
+ return %ret0: vector<16xf64>
+}
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xf32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_base' op failed to verify that all of {global, lds} have same element type}}
+ %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xf32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32>
+ return %0 : !amdgpu.tdm_base<i32>
+}
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi7, #gpu_global_addrspace>, %smem: memref<8xi7,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i7>) {
+ // expected-error@+1 {{'amdgpu.make_dma_base' op element type must be 1, 2, 4, or 8 bytes long but type was 7 bits long.}}
+ %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi7, #gpu_global_addrspace>, memref<8xi7, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i7>
+ return %0 : !amdgpu.tdm_base<i7>
+}
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+// CHECK-LABEL: func @make_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
+func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) {
+ // CHECK-DAG: %[[INT:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
+ // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
+
+ // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
+ // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
+
+ // CHECK-DAG: %[[MEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[MEM_BASE_PTR]][%[[INT]]]
+ // CHECK-DAG: %[[SMEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[SMEM_BASE_PTR]][%[[INT]]]
+
+ // CHECK-DAG: %[[MEM_INT:.+]] = llvm.ptrtoint %[[MEM_BASE_OFFSET]] : !llvm.ptr<1> to i64
+ // CHECK-DAG: %[[SMEM_INT:.+]] = llvm.ptrtoint %[[SMEM_BASE_OFFSET]] : !llvm.ptr<3> to i32
+
+ // CHECK: %[[MEM_INT_LOW:.+]] = llvm.trunc %[[MEM_INT]] : i64 to i32
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64)
+ // CHECK: %[[SHIFTED_MEM_INT:.+]] = llvm.lshr %[[MEM_INT]], %[[SHIFT]]
+ // CHECK: %[[MEM_INT_HIGH:.+]] = llvm.trunc %[[SHIFTED_MEM_INT]] : i64 to i32
+ // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
+ // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
+
+ // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
+ // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+ // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
+ // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]
+ // CHECK: %[[V4I32_0_3:.+]] = llvm.insertelement %[[MEM_INT_LOW]], %[[V4I32_0_2]][%[[C2]] : i32]
+ // CHECK: %[[V4I32_0_4:.+]] = llvm.insertelement %[[MEM_INT_HIGH_TYPE]], %[[V4I32_0_3]][%[[C3]] : i32]
+
+ %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32>
+
+ func.return %0 : !amdgpu.tdm_base<i32>
+}
+
+// -----
+
+// CHECK-LABEL: func @make_dma_descriptor
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
+func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor {
+ // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]]
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32)
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32)
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32)
+ // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32)
+ // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32)
+ // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32)
+ // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32)
+
+ // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]]
+
+ // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]]
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]]
+
+ // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]]
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]]
+ // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]]
+
+ // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]]
+ // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]]
+
+ // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32)
+
+ // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]]
+ // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]]
+ // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32
+
+ // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64)
+ // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]]
+ // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]]
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]]
+ // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]]
+ // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]]
+
+ // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32>
+ // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32]
+ // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32]
+ // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32]
+ // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32]
+ // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32]
+ // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32]
+ // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32]
+ // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32]
+
+ // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor
+ %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return %descriptor : !amdgpu.tdm_descriptor
+}
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+// CHECK-LABEL: func @make_dma_descriptor_atomic_barrier
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[BARRIER:.+]]: {{.*}}, %[[IDX:.+]]: index)
+func.func @make_dma_descriptor_atomic_barrier(%base: !amdgpu.tdm_base<i32>, %barrier : memref<8xi32, #gpu_lds_addrspace>, %idx: index) -> !amdgpu.tdm_descriptor {
+ // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[BARRIER_MEMREF_DESC:.+]] = builtin.unrealized_conversion_cast %[[BARRIER]]
+ // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]]
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32)
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32)
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32)
+ // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32)
+ // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32)
+ // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32)
+ // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32)
+
+ // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR0_0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]]
+
+ // CHECK-DAG: %[[ATOMIC_BARRIER_ENABLE_OFFSET:.+]] = llvm.mlir.constant(18 : i32)
+ // CHECK: %[[ATOMIC_BARRIER_ENABLE_FIELD:.+]] = llvm.shl %[[C1]], %[[ATOMIC_BARRIER_ENABLE_OFFSET]]
+ // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[ATOMIC_BARRIER_ENABLE_FIELD]]
+
+ // CHECK: %[[ATOMIC_BARRIER_ALIGNED_PTR:.+]] = llvm.extractvalue %[[BARRIER_MEMREF_DESC]][1]
+ // CHECK: %[[ATOMIC_BARRIER_ADDR:.+]] = llvm.getelementptr %[[ATOMIC_BARRIER_ALIGNED_PTR]][%[[INDEX]]
+ // CHECK: %[[ATOMIC_BARRIER_I32:.+]] = llvm.ptrtoint %[[ATOMIC_BARRIER_ADDR]] : !llvm.ptr<3> to i32
+ // CHECK: %[[ATOMIC_BARRIER_NO_3_LSB:.+]] = llvm.lshr %[[ATOMIC_BARRIER_I32]], %[[C3]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32)
+ // CHECK: %[[ATOMIC_BARRIER:.+]] = llvm.and %[[ATOMIC_BARRIER_NO_3_LSB]], %[[MASK]]
+
+ // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]]
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR1_0:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]]
+ // CHECK: %[[SGPR1:.+]] = llvm.or %[[ATOMIC_BARRIER]], %[[SGPR1_0]]
+
+ // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32>
+ // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32]
+ // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32]
+
+ %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64]
+ globalStride [64, 1]
+ sharedSize [128, 64]
+ atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu_lds_addrspace>)
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return %descriptor : !amdgpu.tdm_descriptor
+}
+
+// -----
+
+// CHECK-LABEL: func @make_dma_descriptor_workgroup_mask
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1)
+func.func @make_dma_descriptor_workgroup_mask(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1) -> !amdgpu.tdm_descriptor {
+ // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]]
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32)
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32)
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32)
+ // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32)
+ // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32)
+ // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32)
+ // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32)
+
+ // CHECK-DAG: %[[WG_MASK_EXT:.+]] = llvm.zext %[[WG_MASK]]
+ // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[DATA_SIZE_SHIFTED:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]]
+ // CHECK: %[[SGPR0_BASE:.+]] = llvm.or %[[WG_MASK_EXT]], %[[DATA_SIZE_SHIFTED]]
+ // CHECK-DAG: %[[C21:.+]] = llvm.mlir.constant(21 : i32)
+ // CHECK: %[[TIMEOUT_SHIFTED:.+]] = llvm.shl %[[C1]], %[[C21]]
+ // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_BASE]], %[[TIMEOUT_SHIFTED]]
+
+ // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]]
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]]
+
+ // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]]
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]]
+ // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]]
+
+ // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32)
+ // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32)
+ // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]]
+ // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]]
+
+ // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32)
+
+ // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]]
+ // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]]
+ // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32
+
+ // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64)
+ // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]]
+ // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]]
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]]
+ // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32
+ // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]]
+ // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]]
+
+ // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32>
+ // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32]
+ // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32]
+ // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32]
+ // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32]
+ // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32]
+ // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32]
+ // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32]
+ // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32]
+
+ // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor
+ %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] workgroupMask %wg_mask earlyTimeout %timeout : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return %descriptor : !amdgpu.tdm_descriptor
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir
index 1016ee8..537ef59 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12

// CHECK-LABEL: func @memory_counter_wait
func.func @memory_counter_wait() {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir
new file mode 100644
index 0000000..5b29e01
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+
+// CHECK-LABEL: func @memory_counter_wait_tensor
+func.func @memory_counter_wait_tensor() {
+ // CHECK: rocdl.s.wait.tensorcnt 3
+ amdgpu.memory_counter_wait tensor(3)
+
+ return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir
new file mode 100644
index 0000000..1d2f692
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx942
+// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1030
+// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1100
+
+func.func @memory_counter_wait_tensor() {
+ // expected-error @below{{failed to legalize operation 'amdgpu.memory_counter_wait'}}
+ // expected-error @below{{'amdgpu.memory_counter_wait' op unsupported chipset}}
+ amdgpu.memory_counter_wait tensor(0)
+
+ return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 9fcc147..4e6aa17 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -6,30 +6,30 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
- // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} {opsel = true} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
- // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
- // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+ // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} {opsel = true} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
- // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+ // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
- // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 5788347..978227b 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -20,15 +20,15 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
- // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>) -> vector<8xf16>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
- // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>) -> vector<4xf16>
amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
- // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+ // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
- // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+ // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>) -> vector<4xi16>
amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
@@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
- // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
- // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<8xi32>) -> vector<8xi32>
amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
- // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
func.return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 5e77a3ad..37259f6 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -14,13 +14,13 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
// CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>)
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
- // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+ // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>)
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
return
@@ -29,31 +29,31 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
// CHECK-LABEL: @wmma_k64
func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
%arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
- // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+ // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {clamp = true, signA = true, signB = true}
amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
// CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
return
@@ -65,25 +65,25 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
// CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
// CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
- // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
return
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
new file mode 100644
index 0000000..bd4a9da
--- /dev/null
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -0,0 +1,329 @@
+// RUN: mlir-opt %s --convert-arith-to-apfloat --split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+
+// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
+// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
+// CHECK: }
+
+// CHECK-LABEL: func.func @bar() -> f6E3M2FN {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN
+// CHECK: return %[[CONSTANT_0]] : f6E3M2FN
+// CHECK: }
+
+// Illustrate that it is fine for both f8E4M3FN and f6E3M2FN to call the same
+// _mlir_apfloat_add: each passes its own semantics enum and is bitcast/extui'd
+// and trunci'd to and from its own width.
+// CHECK-LABEL: func.func @full_example() {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
+// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN
+// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
+// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64
+// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8
+// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64
+// // fltSemantics enum value for f8E4M3FN
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32
+// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64
+// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8
+// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN
+// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN
+
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN
+// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN
+// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6
+// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64
+// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6
+// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64
+// // fltSemantics enum value for f6E3M2FN
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32
+// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64
+// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6
+// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN
+// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN
+// CHECK: return
+// CHECK: }
+
+// Put the RHS into a separate function so that it won't be constant-folded.
+func.func @foo() -> f8E4M3FN {
+ %cst = arith.constant 2.2 : f8E4M3FN
+ return %cst : f8E4M3FN
+}
+
+func.func @bar() -> f6E3M2FN {
+ %cst = arith.constant 3.2 : f6E3M2FN
+ return %cst : f6E3M2FN
+}
+
+func.func @full_example() {
+ %a = arith.constant 1.4 : f8E4M3FN
+ %b = func.call @foo() : () -> (f8E4M3FN)
+ %c = arith.addf %a, %b : f8E4M3FN
+ vector.print %c : f8E4M3FN
+
+ %d = arith.constant 2.4 : f6E3M2FN
+ %e = func.call @bar() : () -> (f6E3M2FN)
+ %f = arith.addf %d, %e : f6E3M2FN
+ vector.print %f : f6E3M2FN
+ return
+}
+
+// -----
+
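+// f4E2M1FN arithmetic is likewise emulated via the runtime; 18 is the
+// fltSemantics enum value for f4E2M1FN.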
+// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.addf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// Test a declaration collision: same function name, different type.
+// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}}
+func.func private @_mlir_apfloat_add(i32, i32, f32) -> index
+func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.addf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.subf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @mulf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.mulf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @divf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.divf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.remf %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32
+// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32
+// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64
+func.func @extf(%arg0: f4E2M1FN) {
+ %0 = arith.extf %arg0 : f4E2M1FN to f32
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64
+func.func @truncf(%arg0: bf16) {
+ %0 = arith.truncf %arg0 : bf16 to f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
+// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant false
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+// CHECK: arith.trunci %[[res]] : i64 to i4
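+// The runtime returns the converted integer widened to i64; the lowering
+// truncates it back to the requested width (i4 here).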
+func.func @fptosi(%arg0: f16) {
+ %0 = arith.fptosi %arg0 : f16 to i4
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32
+// CHECK: %[[out_width:.*]] = arith.constant 4 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant true
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+// CHECK: arith.trunci %[[res]] : i64 to i4
+func.func @fptoui(%arg0: f16) {
+ %0 = arith.fptoui %arg0 : f16 to i4
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant false
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @sitofp(%arg0: i32) {
+ %0 = arith.sitofp %arg0 : i32 to f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant true
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @uitofp(%arg0: i32) {
+ %0 = arith.uitofp %arg0 : i32 to f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
+// CHECK: %[[c3:.*]] = arith.constant 3 : i8
+// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
+// CHECK: %[[c0:.*]] = arith.constant 0 : i8
+// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
+// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
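+// The i8 returned by the runtime follows APFloat::cmpResult ordering
+// (cmpLessThan = 0, cmpUnordered = 3), so "ult" is (cmp == 3) || (cmp == 0).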
+func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+ %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @negf(%arg0: f32) {
+ %0 = arith.negf %arg0 : f32
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_minimum(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_minimum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @minimumf(%arg0: f32, %arg1: f32) {
+ %0 = arith.minimumf %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_maximum(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_maximum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @maximumf(%arg0: f32, %arg1: f32) {
+ %0 = arith.maximumf %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_minnum(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_minnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @minnumf(%arg0: f32, %arg1: f32) {
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_maxnum(i32, i64, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_maxnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
+func.func @maxnumf(%arg0: f32, %arg1: f32) {
+ %0 = arith.maxnumf %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unsupported_bitwidth
+// CHECK: arith.addf {{.*}} : f128
+// CHECK: arith.negf {{.*}} : f128
+// CHECK: arith.cmpf {{.*}} : f128
+// CHECK: arith.extf {{.*}} : f32 to f128
+// CHECK: arith.truncf {{.*}} : f128 to f32
+// CHECK: arith.fptosi {{.*}} : f128 to i32
+// CHECK: arith.fptosi {{.*}} : f32 to i92
+// CHECK: arith.sitofp {{.*}} : i1 to f128
+// CHECK: arith.sitofp {{.*}} : i92 to f32
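+// f128 and i92 values do not fit the i64-based runtime ABI, so all of these
+// ops are expected to remain unconverted.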
+func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) {
+ %0 = arith.addf %arg0, %arg1 : f128
+ %1 = arith.negf %arg0 : f128
+ %2 = arith.cmpf "ult", %arg0, %arg1 : f128
+ %3 = arith.extf %arg2 : f32 to f128
+ %4 = arith.truncf %arg0 : f128 to f32
+ %5 = arith.fptosi %arg0 : f128 to i32
+ %6 = arith.fptosi %arg2 : f32 to i92
+ %7 = arith.sitofp %2 : i1 to f128
+ %8 = arith.sitofp %6 : i92 to f32
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @addf_vector
+// CHECK-COUNT-2: vector.to_elements
+
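+// Each of the four lanes is routed through the scalar emulation path:
+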
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: vector.from_elements
+func.func @addf_vector(%arg0: vector<4xf4E2M1FN>, %arg1: vector<4xf4E2M1FN>) {
+ %0 = arith.addf %arg0, %arg1 : vector<4xf4E2M1FN>
+ return
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff2..b53c52d 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -738,6 +738,22 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
// -----
+// CHECK-LABEL: @ops_supporting_exact
+func.func @ops_supporting_exact(i32, i32) {
+^bb0(%arg0: i32, %arg1: i32):
+// CHECK: = llvm.ashr exact %arg0, %arg1 : i32
+ %0 = arith.shrsi %arg0, %arg1 exact : i32
+// CHECK: = llvm.lshr exact %arg0, %arg1 : i32
+ %1 = arith.shrui %arg0, %arg1 exact : i32
+// CHECK: = llvm.sdiv exact %arg0, %arg1 : i32
+ %2 = arith.divsi %arg0, %arg1 exact : i32
+// CHECK: = llvm.udiv exact %arg0, %arg1 : i32
+ %3 = arith.divui %arg0, %arg1 exact : i32
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @memref_bitcast
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -747,3 +763,35 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
func.return %2 : memref<?xbf16>
}
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+// CHECK: arith.addf {{.*}} : f4E2M1FN
+// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
+// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+// CHECK: arith.cmpf {{.*}} : f4E2M1FN
+// CHECK: llvm.select {{.*}} : i1, i4
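+// The f4E2M1FN ops above stay in arith because the type has no LLVM
+// counterpart; arith.select still lowers, operating on the converted i4 bits.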
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
+ %0 = arith.addf %arg0, %arg0 : f4E2M1FN
+ %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
+ %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+ %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @supported_fp_type
+// CHECK: llvm.fadd {{.*}} : f32
+// CHECK: llvm.fadd {{.*}} : vector<4xf32>
+// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
+// CHECK: llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
+ %0 = arith.addf %arg0, %arg0 : f32
+ %1 = arith.addf %arg1, %arg1 : vector<4xf32>
+ %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+ return
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index a75f30d..cd8cfc8 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -275,6 +275,42 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
// -----
+// CHECK-LABEL: spirv.func @reduction_minnumf(
+// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
+// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
+// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MIN0:.*]] = spirv.GL.FMin %[[S0]], %[[S1]] : f32
+// CHECK: %[[MIN1:.*]] = spirv.GL.FMin %[[MIN0]], %[[S2]] : f32
+// CHECK: %[[MIN2:.*]] = spirv.GL.FMin %[[MIN1]], %[[S]] : f32
+// CHECK: spirv.ReturnValue %[[MIN2]] : f32
+// CHECK: }
+func.func @reduction_minnumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <minnumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL: spirv.func @reduction_maxnumf(
+// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
+// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
+// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MAX0:.*]] = spirv.GL.FMax %[[S0]], %[[S1]] : f32
+// CHECK: %[[MAX1:.*]] = spirv.GL.FMax %[[MAX0]], %[[S2]] : f32
+// CHECK: %[[MAX2:.*]] = spirv.GL.FMax %[[MAX1]], %[[S]] : f32
+// CHECK: spirv.ReturnValue %[[MAX2]] : f32
+// CHECK: }
+func.func @reduction_maxnumf(%v : vector<3xf32>, %s: f32) -> f32 {
+ %reduce = vector.reduction <maxnumf>, %v, %s : vector<3xf32> into f32
+ return %reduce : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_maxsi
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
diff --git a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir
index cb6bc35..c16702f 100644
--- a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt -set-llvm-module-datalayout -convert-func-to-llvm %s | FileCheck %s
-// RUN-32: mlir-opt -set-llvm-module-datalayout='data-layout=p:32:32:32' -convert-func-to-llvm %s \
-// RUN-32: | FileCheck %s
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=p:32:32:32' -convert-func-to-llvm %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-32
// CHECK: module attributes {llvm.data_layout = ""}
-// CHECK-32: module attributes {llvm.data_layout ="p:32:32:32"}
+// CHECK-32: module attributes {llvm.data_layout = "p:32:32:32"}
module {}
diff --git a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
index ae1dc70..bd28162 100644
--- a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
@@ -32,7 +32,7 @@ func.func @pass_through(%arg0: () -> ()) -> (() -> ()) {
func.func private @llvmlinkage(i32) attributes { "llvm.linkage" = #llvm.linkage<extern_weak> }
// CHECK-LABEL: llvm.func @llvmreadnone(i32)
-// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>
+// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>
func.func private @llvmreadnone(i32) attributes { llvm.readnone }
// CHECK-LABEL: llvm.func @body(i32)
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
index 15a9654..16ed484 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm),convert-func-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm),convert-func-to-llvm{use-bare-ptr-memref-call-conv=1},convert-cf-to-llvm,reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32},convert-cf-to-llvm{index-bitwidth=32},reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s --check-prefix=CHECK32
// BAREPTR-LABEL: func @check_noalias
// BAREPTR-SAME: %{{.*}}: !llvm.ptr {llvm.noalias}, %{{.*}}: !llvm.ptr {llvm.noalias}
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index a4b5dde..f1cc1eb 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allow-pattern-rollback=0' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
diff --git a/mlir/test/Conversion/GPUToNVVM/memref.mlir b/mlir/test/Conversion/GPUToNVVM/memref.mlir
index e164ca9..a4e8ead 100644
--- a/mlir/test/Conversion/GPUToNVVM/memref.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/memref.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm="allow-pattern-rollback=0" | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='use-bare-ptr-memref-call-conv=1' \
// RUN: | FileCheck %s --check-prefix=BARE
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index b479467..a080144 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm="allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
@@ -81,6 +82,28 @@ gpu.module @test_module {
gpu.module @test_module {
+ // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
+ // CHECK-SAME: f64
+ // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
+ func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
+ return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
+ // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
+ // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
+ // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
+ // CHECK: llvm.return %[[LOAD]] : f64
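+  // Note: the f64 "AOp" fragment is a single element per lane, so the lowered
+  // load yields a plain f64 rather than a struct of vectors.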
+ }
+}
+
+// -----
+
+gpu.module @test_module {
+
// CHECK-LABEL: func @gpu_wmma_store_op
// CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK32-LABEL: func @gpu_wmma_store_op
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
index c1627a0..19e1c7a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
-// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefixes=NOCPP,CHECK
func.func @alloc_copy(%arg0: memref<999xi32>) {
%alloc = memref.alloc() : memref<999xi32>
@@ -9,42 +9,46 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
return
}
-// CHECK: module {
// NOCPP: emitc.include <"stdlib.h">
// NOCPP-NEXT: emitc.include <"string.h">
// CPP: emitc.include <"cstdlib">
// CPP-NEXT: emitc.include <"cstring">
-// CHECK-LABEL: alloc_copy
-// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
-// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
-// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
-// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
-// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
-// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
-// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
-// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
-// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
-// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
-// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
-// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
-// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
-// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
-// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
-// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
-// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
-// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
-// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
-// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
-// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
-// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
-// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
-// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
-// CHECK-NEXT: return
+// CHECK-LABEL: func.func @alloc_copy(
+// CHECK-SAME: %[[ARG0:.*]]: memref<999xi32>) {
+// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<999xi32> to !emitc.array<999xi32>
+// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_0]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK: %[[CALL_OPAQUE_1:.*]] = emitc.call_opaque "malloc"(%[[MUL_0]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK: %[[CAST_0:.*]] = emitc.cast %[[CALL_OPAQUE_1]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[CAST_0]] : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<i32>
+// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_2]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<i32>
+// CHECK: %[[CALL_OPAQUE_2:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK: %[[VAL_3:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK: %[[MUL_1:.*]] = emitc.mul %[[CALL_OPAQUE_2]], %[[VAL_3]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_1]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK: %[[CALL_OPAQUE_3:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK: %[[MUL_2:.*]] = emitc.mul %[[CALL_OPAQUE_3]], %[[VAL_4]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK: %[[CALL_OPAQUE_4:.*]] = emitc.call_opaque "malloc"(%[[MUL_2]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK: %[[CAST_1:.*]] = emitc.cast %[[CALL_OPAQUE_4]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK: %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CAST_1]] : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK: %[[VAL_5:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_2:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_5]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK: %[[ADDRESS_OF_2:.*]] = emitc.address_of %[[SUBSCRIPT_2]] : !emitc.lvalue<i32>
+// CHECK: %[[VAL_6:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_3:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_2]]{{\[}}%[[VAL_6]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK: %[[ADDRESS_OF_3:.*]] = emitc.address_of %[[SUBSCRIPT_3]] : !emitc.lvalue<i32>
+// CHECK: %[[CALL_OPAQUE_5:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK: %[[VAL_7:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK: %[[MUL_3:.*]] = emitc.mul %[[CALL_OPAQUE_5]], %[[VAL_7]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_3]], %[[ADDRESS_OF_2]], %[[MUL_3]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK: return
+// CHECK: }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index d151d1b..3de2d25 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
-// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefixes=NOCPP,CHECK
func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
@@ -10,20 +10,21 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
// NOCPP: emitc.include <"string.h">
// CPP: emitc.include <"cstring">
-// CHECK-LABEL: copying
-// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-// CHECK-NEXT:}
+// CHECK-LABEL: func.func @copying(
+// CHECK-SAME: %[[ARG0:.*]]: memref<9x4x5x7xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<9x4x5x7xf32>) {
+// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_0]], %[[VAL_0]], %[[VAL_0]], %[[VAL_0]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<f32>
+// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<f32>
+// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_2]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_0]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK: return
+// CHECK: }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda3..c7b043b 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -53,7 +53,7 @@ module @globals {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
- // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+ // CHECK-NEXT: emitc.address_of %1 : !emitc.lvalue<i32>
%1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 68c3e9f..e0e4a61 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm -lower-affine -convert-arith-to-llvm -cse %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm='index-bitwidth=32' -lower-affine -convert-arith-to-llvm='index-bitwidth=32' -cse %s -split-input-file | FileCheck %s --check-prefix=CHECK32
//
// This test demonstrates a full "memref to llvm" pipeline where
// we first expand some of the memref operations (using affine,
@@ -441,10 +442,31 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK: }
-// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
-// CHECK32: llvm.mlir.constant(1 : index) : i32
-// CHECK32: llvm.mlir.constant(4 : index) : i32
-// CHECK32: llvm.mlir.constant(1 : index) : i32
+// CHECK32-LABEL: func.func @collapse_shape_dynamic_with_non_identity_layout(
+// CHECK32-SAME: %[[ARG:.*]]: memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> memref<4x?xf32, strided<[?, ?], offset: ?>> {
+// CHECK32: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> to !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK32: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i32,
+// CHECK32: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i32,
+// CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK32: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK32: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK32: %[[FINAL_SIZE1_I32:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw> : i32
+// CHECK32: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1_I32]] : i32 to index
+// CHECK32: %[[FINAL_SIZE1_CAST:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i32
+// CHECK32: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[C4_I32:.*]] = llvm.mlir.constant(4 : index) : i32
+// CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[C4_I32]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1_CAST]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[C1_I32:.*]] = llvm.mlir.constant(1 : index) : i32
+// CHECK32: %[[DESC6:.*]] = llvm.insertvalue %[[C1_I32]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)>
+// CHECK32: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> to memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK32: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK32: }
// -----
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8cce630..0eb4478 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -491,12 +491,12 @@ func.func @mbarrier() {
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.shared %[[barPtr2]]
+ // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive %[[barPtr2]]
%token = nvgpu.mbarrier.arrive %barrier[%c0] : !barrierType -> !tokenType
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]]
+ // CHECK: nvvm.mbarrier.test.wait %[[barPtr3]], %[[token]]
%isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType
func.return
@@ -521,12 +521,12 @@ func.func @mbarrier_nocomplete() {
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.nocomplete.shared %[[barPtr2]]
+ // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.nocomplete %[[barPtr2]]
%token = nvgpu.mbarrier.arrive.nocomplete %barrier[%c0], %count : !barrierType -> !tokenType
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]]
+ // CHECK: nvvm.mbarrier.test.wait %[[barPtr3]], %[[token]]
%isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType
func.return
@@ -572,7 +572,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64
// CHECK: %[[S4:.+]] = llvm.extractvalue %[[CARG0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
-// CHECK: nvvm.mbarrier.test.wait.shared {{.*}}, %[[CARG1]]
+// CHECK: nvvm.mbarrier.test.wait {{.*}}, %[[CARG1]]
%mbarId = arith.remui %i, %numBarriers : index
%isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType
}
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index fbc4c0a..8fb36ac 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -16,17 +16,13 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
- nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1
llvm.return
@@ -44,7 +40,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
}
@@ -88,10 +84,10 @@ func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
// CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
- // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}}
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
- // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true}
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
+ // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}}
+ nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
+ // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
+ nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
llvm.return
}
@@ -544,8 +540,8 @@ func.func @elect_one_leader_sync() {
// -----
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
+// CHECK-LABEL: @test_nvvm_prefetch
+llvm.func @test_nvvm_prefetch(%desc : !llvm.ptr, %pred : i1) {
//CHECK: nvvm.prefetch tensormap, %{{.*}}
nvvm.prefetch tensormap, %desc : !llvm.ptr
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
@@ -588,29 +584,6 @@ func.func @cp_async_bulk_wait_group() {
// -----
-func.func @fence_mbarrier_init() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;"
- nvvm.fence.mbarrier.init
- func.return
-}
-// -----
-
-func.func @fence_proxy() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
- func.return
-}
-
-// -----
-
// CHECK-LABEL: @llvm_nvvm_barrier_arrive
// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index f2fbe91..b122f42 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -615,3 +615,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>
// CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr)
omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr)
}
+
+// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) {
+func.func @omp_dist_schedule(%arg0: i32) {
+ %c1_i32 = arith.constant 1 : i32
+ // CHECK: %1 = llvm.mlir.constant(1024 : i32) : i32
+ %c1024_i32 = arith.constant 1024 : i32
+ %c16_i32 = arith.constant 16 : i32
+ %c8_i32 = arith.constant 8 : i32
+ omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) {
+ // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) {
+ omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) {
+ omp.terminator
+ }
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir
new file mode 100644
index 0000000..3bd9bb4
--- /dev/null
+++ b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -convert-openmp-to-llvm -split-input-file -verify-diagnostics %s
+
+// Checks the case where type conversion fails for the omp.map.info op.
+// In this specific case, the `tensor` type (used in a TypeAttr) cannot be converted
+// to an LLVM type. This test ensures that the conversion fails gracefully with a
+// legalization error instead of crashing.
+func.func @fail_map_info_tensor_type(%arg0: memref<?xf32>) {
+ // expected-error@+1 {{failed to legalize operation 'omp.map.info' that was explicitly marked illegal}}
+ %map_info = omp.map.info var_ptr(%arg0: memref<?xf32>, tensor<?xf32>) map_clauses(to) capture(ByRef) -> memref<?xf32>
+ omp.target_update map_entries(%map_info: memref<?xf32>) {
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 483c7b3..0c4f20e 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="allow-pattern-rollback=0" -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
index 26f5a3e..2f192df 100644
--- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
@@ -673,3 +673,51 @@ func.func @nested_parallel_with_side_effect() {
// CHECK: gpu.launch
// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @scf2gpu_index_creation_2d() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+
+ // Single 2-D scf.parallel mapped to block_x and thread_x.
+ // Use both IVs so the conversion must compute indices.
+ scf.parallel (%bx, %tx) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) {
+ %u = arith.addi %bx, %c0 : index
+ %v = arith.addi %tx, %c0 : index
+ } {
+ mapping = [
+ #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+ #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
+ ]
+ }
+ return
+}
+
+// CHECK-LABEL: func @scf2gpu_index_creation_2d
+// CHECK: gpu.launch
+// CHECK: %[[IDX:.*]] = affine.apply
+// CHECK: arith.addi %[[IDX]],
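+// (the affine.apply rebuilds the loop IV from the hardware id, effectively lb + id * step)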
+
+// -----
+
+func.func @scf2gpu_index_creation_1d() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c64 = arith.constant 64 : index
+
+ scf.parallel (%t) = (%c0) to (%c64) step (%c1) {
+ %w = arith.addi %t, %c0 : index
+ } {
+ mapping = [
+ #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
+ ]
+ }
+ return
+}
+
+// CHECK-LABEL: func @scf2gpu_index_creation_1d
+// CHECK: gpu.launch
+// CHECK: %[[IDX:.*]] = affine.apply
+// CHECK: arith.addi %[[IDX]],
diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
index e1936e2..b17e1c4 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
@@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" {
// CHECK-LABEL: @tan
spirv.func @tan(%arg0: f32) "None" {
- // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32
- // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32
- // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32
+ // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
%0 = spirv.GL.Tan %arg0 : f32
spirv.Return
}
@@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" {
// CHECK-LABEL: @tanh
spirv.func @tanh(%arg0: f32) "None" {
- // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32
- // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32
- // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
- // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
- // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32
- // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32
+ // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32
%0 = spirv.GL.Tanh %arg0 : f32
spirv.Return
}
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index b69c2d0..65c6e05 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -79,3 +79,12 @@ func.func @tensor_2d_empty() -> () {
%x = arith.constant dense<> : tensor<2x0xi32>
return
}
+
+// Tensors with more than UINT32_MAX elements cannot fit in a spirv.array.
+// Test that they are not lowered.
+// CHECK-LABEL: func @very_large_tensor
+// CHECK-NEXT: arith.constant dense<1>
+func.func @very_large_tensor() -> () {
+ %x = arith.constant dense<1> : tensor<4294967296xi32>
+ return
+}
diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
index 6c0b111..0fe63f5 100644
--- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
+++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
@@ -17,3 +17,9 @@ func.func @check_poison() {
%3 = ub.poison : !llvm.ptr
return
}
+
+// CHECK-LABEL: @check_unreachable
+func.func @check_unreachable() {
+// CHECK: llvm.unreachable
+ ub.unreachable
+}
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index f497eb3..9c277cf 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-ub-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-ub-to-spirv %s | FileCheck %s
module attributes {
spirv.target_env = #spirv.target_env<
@@ -19,3 +19,20 @@ func.func @check_poison() {
}
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: @check_unreachable
+func.func @check_unreachable(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+// CHECK: spirv.Unreachable
+ ub.unreachable
+^bb2:
+ return
+}
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205..ae5141d 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
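+// Note: the lowering now folds the leading offsets into a rank-reducing
+// memref.subview and passes the remaining offsets to load_nd directly.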
// -----
@@ -28,35 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
- %offset: index) -> vector<8x16xf32> {
- %0 = vector.load %source[%offset, %offset, %offset]
+ %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
@@ -72,9 +67,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dc..1a10d91 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// -----
@@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
func.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.store %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.store %vec, %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return
}
@@ -47,18 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-LABEL: @store_dynamic_source(
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// -----
@@ -74,9 +69,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79a..8bb272b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -11,13 +11,15 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-ND-LABEL: @load_1D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
-// LOAD-ND: return %[[VEC]]
+// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-ND-COUNT2: arith.muli {{.*}} : index
+// LOAD-ND-COUNT2: arith.addi {{.*}} : index
+// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
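+// (the 1-D transfer_read now lowers to a masked gather: per-lane linearized
+// offsets applied to the flat base pointer)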
// LOAD-GATHER-LABEL: @load_1D_vector(
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
@@ -46,11 +48,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-LABEL: @load_2D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_2D_vector(
@@ -83,9 +86,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
@@ -109,9 +112,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET1:.+]]: index,
// LOAD-ND-SAME: %[[OFFSET2:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND-SAME: -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -143,16 +146,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index
-// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index
-// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -184,10 +182,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source2(
-// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
// LOAD-GATHER-LABEL: @load_dynamic_source2(
@@ -406,7 +405,7 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
// -----
gpu.module @xevm_module {
-gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
%c0 = arith.constant 0.0 : f16
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
%0 = vector.transfer_read %subview[%off2, %off2], %c0
@@ -414,18 +413,23 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
gpu.return %0 : vector<8xf16>
}
-// LOAD-ND-LABEL: @load_from_subview(
+// LOAD-ND-LABEL: @load_from_subview_1D(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
-// LOAD-ND: return %[[VEC]]
-
-// LOAD-GATHER-LABEL: @load_from_subview(
+// LOAD-ND: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-ND: arith.muli {{.*}} : index
+// LOAD-ND: arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-ND: arith.addi {{.*}} : index
+// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+
+// LOAD-GATHER-LABEL: @load_from_subview_1D(
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
@@ -441,3 +445,42 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> {
+ %c0 = arith.constant 0.0 : f16
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ %0 = vector.transfer_read %subview[%off2, %off2], %c0
+ {in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8x16xf16>
+ gpu.return %0 : vector<8x16xf16>
+}
+
+// LOAD-ND-LABEL: @load_from_subview_2D(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SUBVIEW]]
+// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_from_subview_2D(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc941..43a1a72 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
@@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// STORE-SCATTER-LABEL: @store_1D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>,
@@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_2D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// -----
gpu.module @xevm_module {
gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.transfer_write %vec, %source[%i, %j, %k]
{in_bounds = [true, true]}
: vector<8x16xf32>, memref<?x?x?xf32>
gpu.return
@@ -83,18 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-LABEL: @store_dynamic_source(
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// STORE-ND-SAME: %[[OFFSET:.+]]: index
-// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index
-// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index
-// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -126,9 +121,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_out_of_bounds(
// STORE-SCATTER: vector.transfer_write
@@ -298,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
-// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
-// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
// STORE-SCATTER-LABEL: @store_to_subview(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 09ef76c..9a1e2cb 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -7,42 +7,41 @@ gpu.module @create_nd_tdesc {
// CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
+ // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+ // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
// CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
// CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+ // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
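+  // Payload layout (vector<8xi32>): i64 base address in lanes [0-1], shape W in [2],
+  // shape H in [3], pitch in [4].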
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
%srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
- // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
- // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
- // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+ // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -51,20 +50,16 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
- // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
- // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
- // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
- // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
- // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
- // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
- // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
- // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+ // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
%dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
gpu.return
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
new file mode 100644
index 0000000..aebec7f
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+ // CHECK-LABEL: @load_store(
+ // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1>
+ gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel {
+ // CHECK: %[[C512:.*]] = arith.constant 512 : i64
+ // CHECK: %[[C384:.*]] = arith.constant 384 : i64
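+    // (byte offsets: the element offsets 96 and 128 below, scaled by 4 bytes per f32,
+    // give 384 and 512)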
+
+ // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
+ %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
+ // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
+ %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
+ // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
+ // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
+ // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
+ // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
+ // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
+ %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
+ // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
+ // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
+ xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index d4cb493..3a3769f 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x32xf32>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
- //CHECK-LABEL: load_store_matrix_1
- gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+ //CHECK-LABEL: load_store_matrix_plain
+ gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,12 +26,40 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f32
}
+ //CHECK-LABEL: load_store_matrix_plain_2d_input
+ gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+
+ %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+
+ %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+ %tid_x = gpu.thread_id x
+
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return %1: f32
+ }
+
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
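+  // derivation: blocked shape = [32/16, 64/16, 16, 16]; a 16x16 block holds 256 elements
+  // with column-major in-block strides [1, 16]; blocks step 256 along dim0 and 2*256 = 512 along dim1.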
- //CHECK-LABEL: load_store_matrix_2
- gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK-LABEL: load_store_matrix_blocked_strided
+ gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -39,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -53,39 +81,39 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
-
+
%tid_x = gpu.thread_id x
%c13 = arith.constant 13 : index
%1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
-
- xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
+ xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
gpu.return %1: f16
}
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
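+  // derivation: with the default row-major order a block keeps in-block strides [16, 1];
+  // blocks step 256 along dim1 and 4*256 = 1024 along dim0.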
- //CHECK-LABEL: load_store_matrix_3
- gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ //CHECK-LABEL: load_store_matrix_blocked_nostride
+ gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
-
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -97,32 +125,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
-
+
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
-
+
//CHECK: gpu.return %[[loaded]] : f16
gpu.return %1: f16
}
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
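+  // With @strides=[1, 32] the in-block stride along dim 0 is 1, so the
+  // eight rows (16..23) of one block are contiguous in memory and can be
+  // read as a single vector<8xf16> from the linearized offset.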
- //CHECK-LABEL: load_store_matrix_4
- gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
+ gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
-
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -136,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
-
+
%tid_x = gpu.thread_id x
%c16 = arith.constant 16 : index
%1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
@@ -147,28 +172,26 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: vector<8xf16>
}
-
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
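+  // With {subgroup_block_io}, the f16 accesses lower to xevm.blockload /
+  // xevm.blockstore on i16 data. Offsets (16, 48) split into block coords
+  // (1, 3) and in-block coords (0, 0), i.e. element offset 1*1024 + 3*256
+  // = 1792, or byte offset 1792*2 = 3584 after the muli by 2 below.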
- //CHECK-LABEL: load_store_matrix_5
- gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
-
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+ //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
+ gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index
-
%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
//CHECK: %[[c2:.*]] = arith.constant 2 : i32
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
- //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
@@ -191,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
//CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
- //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+ //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
gpu.return %1: vector<8xf16>
}
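+  // The lowered code produces flat 1-D vectors; vector.shape_cast ops act
+  // as source/target materializations to the 2-D vector types still used
+  // by unconverted ops such as arith.addf.
+  // CHECK-LABEL: gpu.func @matrix_vector_materialization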
+ gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
+ // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16>
+ // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16>
+ // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16>
+ // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3>
+ %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16>
+ %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16>
+ xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ gpu.return
+ }
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf2..4c73c9c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,72 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
gpu.module @load_store_check {
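+  // With -canonicalize in the RUN line, the tensor-descriptor payload
+  // arithmetic folds away, so the CHECKs pin only the folded constants and
+  // the 2-D block load/store intrinsics.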
+ // CHECK-LABEL: gpu.func @load_store(
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
- // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
- //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
- //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
- //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
%tid_x = gpu.thread_id x
%tid_x_i32 = arith.index_cast %tid_x : index to i32
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
- //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
- // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
- // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
- //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
- //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
- //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
- //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
- //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
- //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+ //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
new file mode 100644
index 0000000..97e5ce1
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
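+  // Sub-byte (i4) tiles are modeled with wider elements so the 2-D block
+  // ops operate on byte-aligned rows: a 64-wide i4 row is 32 bytes, i.e.
+  // 16 x i16, hence elem_size_in_bits = 16 and tile_width = 16 below.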
+ // CHECK-LABEL: gpu.func @load_store_matrix_a
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
+ gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+ // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+ // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+ // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+ %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+ // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+ // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+ // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+ %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
+
+ // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+ // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+ // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+ // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+ // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+ // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
+
+ // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+
+ // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+ // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+ // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+ // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+ xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+
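+  // With the packed request, i4 data is instead viewed as 8-bit elements:
+  // a 32-wide i4 row is 16 bytes, so the CHECK expects elem_size_in_bits = 8,
+  // tile_width = 16, and pack_register = true.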
+ // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
+ gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
+ %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
+
+ // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+ // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
+ // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+ // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
+
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 9c552d8..d606cf5 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -1,15 +1,16 @@
-// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm -canonicalize | FileCheck %s
gpu.module @test {
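+// The mask is a function argument rather than a constant-true vector so that
+// -canonicalize cannot fold away the scf.if guarding the masked access.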
// CHECK-LABEL: @load_gather_i64_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
+// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1>
+gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1>
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
@@ -17,11 +18,12 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
// CHECK: scf.yield %[[VAR7]] : f16
// CHECK: } else {
- // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
// CHECK: scf.yield %[[CST_0]] : f16
// CHECK: }
- %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
gpu.return
}
}
@@ -30,16 +32,16 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
gpu.module @test {
// CHECK-LABEL: @source_materialize_single_elem_vec
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
-gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
- %1 = arith.constant dense<1>: vector<1xi1>
- %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1>
+gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) {
+ %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAR_IF:.*]] = scf.if
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
%c0 = arith.constant 0 : index
- vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
+ vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
gpu.return
}
}
@@ -48,24 +50,21 @@ gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>
gpu.module @test {
// CHECK-LABEL: @store_scatter_i64_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xi1>
+gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %mask: vector<1xi1>) {
+ // CHECK: %[[CST_0:.*]] = arith.constant 2.900000e+00 : f32
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG2]][0] : i1 from vector<1xi1>
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
- // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
- %2 = arith.constant dense<2.9>: vector<1xf32>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ %0 = arith.constant dense<2.9>: vector<1xf32>
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64
// CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
// CHECK: scf.if %[[VAR2]] {
- // CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
+ // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
// CHECK: }
- xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ xegpu.store %0, %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<1xf32>, i64, vector<1xindex>, vector<1xi1>
gpu.return
}
@@ -76,9 +75,9 @@ gpu.module @test {
// CHECK-LABEL: @prefetch_i64_src_value_offset
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64
// CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1>
@@ -94,11 +93,11 @@ gpu.module @test {
// CHECK-LABEL: @prefetch_memref_src_value_offset
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
// CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 873478a..43df721 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -1,34 +1,18 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
-gpu.module @fence_check {
- gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+gpu.module @prefetch_nd_check {
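+  // With -canonicalize, the descriptor arithmetic folds to constants, so
+  // only the folded width/height/offset values and the blockprefetch2d op
+  // are checked.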
+ // CHECK-LABEL: gpu.func @prefetch_nd
+ gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32
+ // CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
- %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
-
- // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
- // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
#xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
- //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
- //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
- //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
- //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
+ //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+ //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]],
+ //CHECK-SAME: %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
//CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
new file mode 100644
index 0000000..f925472
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
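+  // As in the load/store sub-byte tests, i4 rows are modeled as i16
+  // elements: 64 x i4 = 32 bytes = 16 x i16, so the prefetch is checked
+  // with elem_size_in_bits = 16 and tile_width = 16.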
+ // CHECK-LABEL: gpu.func @prefetch_matrix_a
+ gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+ // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+ // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+ xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x64xi4>
+
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 72e70ff..7f01526 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -175,7 +175,7 @@ llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %bas
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes
-// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind}
// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
@@ -187,7 +187,7 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i
// CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind,
// CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
<{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
@@ -200,13 +200,13 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i
// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes
// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none,
-// CHECK-SAME: inaccessibleMem = none>, no_unwind, will_return}
+// CHECK-SAME: inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> {
llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
// CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type =
// CHECK-SAME: !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind,
// CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return}
// CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted
@@ -230,13 +230,13 @@ llvm.func @memfence() {
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes
-// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind}
// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) {
llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
llvm.return
@@ -352,7 +352,7 @@ llvm.func @local_id.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.local_id.x : i32
llvm.return %1 : i32
@@ -380,7 +380,7 @@ llvm.func @local_size.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.local_size.x : i32
llvm.return %1 : i32
@@ -408,7 +408,7 @@ llvm.func @group_id.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.group_id.x : i32
llvm.return %1 : i32
@@ -436,7 +436,7 @@ llvm.func @group_count.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.group_count.x : i32
llvm.return %1 : i32
@@ -463,7 +463,7 @@ llvm.func @group_count.z() -> i32 {
llvm.func @lane_id() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.lane_id : i32
llvm.return %1 : i32
@@ -474,7 +474,7 @@ llvm.func @lane_id() -> i32 {
llvm.func @subgroup_size() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.subgroup_size : i32
llvm.return %1 : i32
@@ -485,7 +485,7 @@ llvm.func @subgroup_size() -> i32 {
llvm.func @subgroup_id() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
- // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.subgroup_id : i32
llvm.return %1 : i32
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
new file mode 100644
index 0000000..9d43c99
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
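+// Constant `index` operands fold into the static attribute forms of
+// globalSize/globalStride/sharedSize, while the non-constant %idx operands
+// of `iterate` cannot fold and stay dynamic.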
+// CHECK-LABEL: @make_dma_descriptor_fold
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[IDX:.+]]: index)
+func.func @make_dma_descriptor_fold(%base: !amdgpu.tdm_base<i32>, %idx: index) -> !amdgpu.tdm_descriptor {
+ %c64 = arith.constant 64 : index
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ %0 = amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [%c64, %c64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [%c64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [%c64, %c64]
+ iterate %idx, %idx, %idx
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return %0 : !amdgpu.tdm_descriptor
+}
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index fee0c00..cff1d3f 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -244,3 +244,39 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
%res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
+
+// -----
+
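+// Adjacent waits fuse by taking the per-counter minimum:
+// min(1,4)=1, min(2,3)=2, min(3,2)=2, min(4,1)=1, min(5,0)=0.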
+// CHECK-LABEL: fuse_memory_counter_wait
+func.func @fuse_memory_counter_wait() {
+ // CHECK: amdgpu.memory_counter_wait
+ // CHECK-SAME: load(1) store(2) ds(2) exp(1) tensor(0)
+ // CHECK-NEXT: return
+ amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
+ amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) tensor(0)
+ return
+}
+
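+// A counter omitted from one wait imposes no bound there, so the fused op
+// keeps each counter from whichever wait specifies it.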
+// CHECK-LABEL: fuse_memory_counter_wait_different_counters
+func.func @fuse_memory_counter_wait_different_counters() {
+ // CHECK: amdgpu.memory_counter_wait
+ // CHECK-SAME: load(1) store(2) ds(3) exp(4)
+ // CHECK-NEXT: return
+ amdgpu.memory_counter_wait load(1) store(2)
+ amdgpu.memory_counter_wait ds(3) exp(4)
+ return
+}
+
+func.func private @use()
+
+// CHECK-LABEL: fuse_memory_counter_wait_not_adjacent
+func.func @fuse_memory_counter_wait_not_adjacent() {
+ // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+ // CHECK-NEXT: call @use()
+ // CHECK-NEXT: amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+ // CHECK-NEXT: return
+ amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+ func.call @use() : () -> ()
+ amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+ return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 4c6f62a..6308ea9 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -333,48 +333,86 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return
+func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
}
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
- func.return
+func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}}
+ %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_base' op lds memref must have workgroup address space attribute.}}
+ amdgpu.make_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_base<i32>
+  func.return
+}
+
+// -----
+
+func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
+ // expected-error@+1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}}
+ amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+  func.return
+}
+
+// -----
+
+func.func @make_dma_descriptor_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {
+  // expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}}
+  amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+  func.return
+}
// -----
-func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
- // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
+// CHECK-LABEL: func @make_dma_descriptor_invalid_empty_strides
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
+func.func @make_dma_descriptor_invalid_empty_strides(%base: !amdgpu.tdm_base<i32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides must not be empty.}}
+ amdgpu.make_dma_descriptor %base globalSize [0, 1] globalStride [] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
func.return
}
// -----
-func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
- // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
- %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
- func.return %0 : vector<16xf32>
+// CHECK-LABEL: func @make_dma_descriptor_invalid_innermost_stride
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
+func.func @make_dma_descriptor_invalid_innermost_stride(%base: !amdgpu.tdm_base<i32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides for the innermost dimension must be 1.}}
+ amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [1, 2] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return
}
// -----
-func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
- // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
- %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
- func.return %0 : vector<16xf32>
+// CHECK-LABEL: func @make_dma_descriptor_invalid_size_and_stride_sizes
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
+func.func @make_dma_descriptor_invalid_size_and_stride_sizes(%base: !amdgpu.tdm_base<i32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides and sizes must have same rank.}}
+ amdgpu.make_dma_descriptor %base globalSize [1, 1, 1] globalStride [1, 1] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return
}
// -----
-func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
- // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}}
- %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
- func.return %0 : vector<16xf32>
+// CHECK-LABEL: func @make_dma_descriptor_invalid_shared_and_global_rank
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
+func.func @make_dma_descriptor_invalid_shared_and_global_rank(%base: !amdgpu.tdm_base<i32>) {
+ // expected-error@+1 {{'amdgpu.make_dma_descriptor' op tensor must have same rank as tile.}}
+ amdgpu.make_dma_descriptor %base globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2, 3] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return
}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 09134cb..651aff4 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -221,58 +221,58 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
func.return %ret : vector<2xbf16>
}
-// CHECK-LABEL: func.func @scaled_ext_packed816_fp4
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: amdgpu.scaled_ext_packed816
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp4
+func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
-// CHECK-LABEL: func.func @scaled_ext_packed816_fp8
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: amdgpu.scaled_ext_packed816
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp8
+func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
-// CHECK-LABEL: func.func @scaled_ext_packed816_bf8
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
- // CHECK: amdgpu.scaled_ext_packed816
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf8
+func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
}
-// CHECK-LABEL: func.func @scaled_ext_packed816_fp6
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
- // CHECK: amdgpu.scaled_ext_packed816
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp6
+func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
}
-// CHECK-LABEL: func.func @scaled_ext_packed816_bf16
-func.func @scaled_ext_packed816_bf16(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
- // CHECK: amdgpu.scaled_ext_packed816
- %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
- // CHECK: amdgpu.scaled_ext_packed816
- %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf6
+func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+ // CHECK: amdgpu.scaled_ext_packed_matrix
+ %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
}
@@ -671,17 +671,105 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
// CHECK-LABEL: func @memory_counter_wait
func.func @memory_counter_wait() {
- // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
- // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1)
+ // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
+ // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) tensor(0)
// CHECK: amdgpu.memory_counter_wait load(1)
// CHECK: amdgpu.memory_counter_wait store(2)
// CHECK: amdgpu.memory_counter_wait ds(3)
// CHECK: amdgpu.memory_counter_wait exp(4)
- amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
- amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4)
+ // CHECK: amdgpu.memory_counter_wait tensor(5)
+ amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
+ amdgpu.memory_counter_wait tensor(0) exp(1) store(2) ds(3) load(4)
amdgpu.memory_counter_wait load(1)
amdgpu.memory_counter_wait store(2)
amdgpu.memory_counter_wait ds(3)
amdgpu.memory_counter_wait exp(4)
+ amdgpu.memory_counter_wait tensor(5)
+ func.return
+}
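+
+// tensor(n) waits on the tensor-DMA counter that now sits alongside the
+// load/store/ds/exp counters; judging from the companion conversion tests,
+// only targets with tensor DMA support lower it.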
+
+// CHECK-LABEL: func @make_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space<workgroup>>)
+func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) {
+ // CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+ amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+ func.return
+}
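+
+// make_dma_base pairs a global memref position with an LDS (workgroup memory)
+// position into a !amdgpu.tdm_base value; make_dma_descriptor below then
+// decorates that base with sizes, strides, and optional transfer modifiers.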
+
+// CHECK-LABEL: func @make_dma_descriptor
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1, %[[BARRIER:.+]]: memref<8xi32, #gpu.address_space<workgroup>>, %[[IDX:.+]]: index)
+func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1, %barrier: memref<8xi32, #gpu.address_space<workgroup>>, %idx: index) {
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
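+ // The remaining cases each layer one optional modifier (padShared,
+ // workgroupMask, earlyTimeout, atomicBarrier, iterate) on top of the minimal
+ // globalSize/globalStride/sharedSize form checked above.
+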
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [64, 64]
+ // CHECK-SAME: padShared(%[[IDX]] every %[[IDX]])
+ padShared(%idx every %idx)
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [64, 64]
+ // CHECK-SAME: workgroupMask %[[WG_MASK]]
+ workgroupMask %wg_mask
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [64, 64]
+ // CHECK-SAME: workgroupMask %[[WG_MASK]]
+ workgroupMask %wg_mask
+ // CHECK-SAME: earlyTimeout %[[TIMEOUT]]
+ earlyTimeout %timeout
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [64, 64]
+ // CHECK-SAME: atomicBarrier(%[[BARRIER]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>)
+ atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu.address_space<workgroup>>)
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
+ // CHECK: amdgpu.make_dma_descriptor %[[BASE]]
+ amdgpu.make_dma_descriptor %base
+ // CHECK-SAME: globalSize [64, 64]
+ globalSize [64, 64]
+ // CHECK-SAME: globalStride [64, 1]
+ globalStride [64, 1]
+ // CHECK-SAME: sharedSize [64, 64]
+ sharedSize [64, 64]
+ // CHECK-SAME: iterate %[[IDX]], %[[IDX]], %[[IDX]]
+ iterate %idx, %idx, %idx
+ : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+
func.return
}
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index b616632..b062736 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -243,6 +243,106 @@ func.func @vecdim_reduction_ori(%in: memref<256x512xi32>, %out: memref<256xi32>)
// CHECK: affine.store %[[final_red]], %{{.*}} : memref<256xi32>
// CHECK: }
+// -----
+
+func.func @vecdim_reduction_xori(%in: memref<256x512xi32>, %out: memref<256xi32>) {
+ %cst = arith.constant 0 : i32
+ affine.for %i = 0 to 256 {
+ %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
+ %ld = affine.load %in[%i, %j] : memref<256x512xi32>
+ %xor = arith.xori %red_iter, %ld : i32
+ affine.yield %xor : i32
+ }
+ affine.store %final_red, %out[%i] : memref<256xi32>
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @vecdim_reduction_xori(
+// CHECK-SAME: %[[input:.*]]: memref<256x512xi32>,
+// CHECK-SAME: %[[output:.*]]: memref<256xi32>) {
+// CHECK: %[[cst:.*]] = arith.constant 0 : i32
+// CHECK: affine.for %{{.*}} = 0 to 256 {
+// CHECK: %[[vzero:.*]] = arith.constant dense<0> : vector<128xi32>
+// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xi32>) {
+// CHECK: %[[poison:.*]] = ub.poison : i32
+// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xi32>, vector<128xi32>
+// CHECK: %[[xor:.*]] = arith.xori %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK: affine.yield %[[xor]] : vector<128xi32>
+// CHECK: }
+// CHECK: %[[final_red:.*]] = vector.reduction <xor>, %[[vred]] : vector<128xi32> into i32
+// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xi32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
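+
+// Note: xor's identity is 0, so the accumulator starts at dense<0> and a
+// single <xor> vector.reduction yields the final value; no trailing combine
+// with the initial scalar is needed.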
+
+// -----
+
+func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFF800000 : f32
+ affine.for %i = 0 to 256 {
+ %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+ %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+ %min = arith.minnumf %red_iter, %ld : f32
+ affine.yield %min : f32
+ }
+ affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @vecdim_reduction_minnumf(
+// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
+// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
+// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: affine.for %{{.*}} = 0 to 256 {
+// CHECK: %[[vinit:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
+// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vinit]]) -> (vector<128xf32>) {
+// CHECK: %[[poison:.*]] = ub.poison : f32
+// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
+// CHECK: %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK: affine.yield %[[min]] : vector<128xf32>
+// CHECK: }
+// CHECK: %[[red_scalar:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK: %[[final_red:.*]] = arith.minnumf %[[red_scalar]], %[[cst]] : f32
+// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
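+
+// minnumf returns its non-NaN operand, so NaN is the identity: the vector
+// accumulator is seeded with quiet NaN (0x7FC00000) and the original scalar
+// init is folded back in by the trailing arith.minnumf. The maxnumf test
+// below uses the same pattern with 0xFFC00000.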
+
+// -----
+
+func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFF800000 : f32
+ affine.for %i = 0 to 256 {
+ %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+ %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+ %max = arith.maxnumf %red_iter, %ld : f32
+ affine.yield %max : f32
+ }
+ affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @vecdim_reduction_maxnumf(
+// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
+// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
+// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: affine.for %{{.*}} = 0 to 256 {
+// CHECK: %[[vinit:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
+// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vinit]]) -> (vector<128xf32>) {
+// CHECK: %[[poison:.*]] = ub.poison : f32
+// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
+// CHECK: %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK: affine.yield %[[max]] : vector<128xf32>
+// CHECK: }
+// CHECK: %[[red_scalar:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK: %[[final_red:.*]] = arith.maxnumf %[[red_scalar]], %[[cst]] : f32
+// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
// -----
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 3be14ea..6a82532 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -416,3 +416,31 @@ func.func @test_loops_do_not_get_coalesced() {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// CHECK-LABEL: func @inner_loop_has_iter_args
+// CHECK-SAME: %[[ALLOC:.*]]: memref<?xi64>)
+func.func @inner_loop_has_iter_args(%alloc : memref<?xi64>) {
+ %c17 = arith.constant 17 : index
+ affine.for %arg0 = 0 to 79 {
+ %0 = affine.for %arg1 = 0 to 64 iter_args(%arg2 = %alloc) -> (memref<?xi64>) {
+ %1 = arith.remui %arg1, %c17 : index
+ %2 = arith.index_cast %arg1 : index to i64
+ memref.store %2, %arg2[%1] : memref<?xi64>
+ affine.yield %arg2 : memref<?xi64>
+ }
+ }
+ return
+}
+
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 17 : index
+// CHECK: %[[APPLY_0:.*]] = affine.apply affine_map<() -> (79)>()
+// CHECK: %[[APPLY_1:.*]] = affine.apply affine_map<() -> (64)>()
+// CHECK: %[[APPLY_2:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%[[APPLY_0]]){{\[}}%[[APPLY_1]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[APPLY_2]] {
+// CHECK: %[[APPLY_3:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%[[IV]]){{\[}}%[[APPLY_1]]]
+// CHECK: %[[REMUI_0:.*]] = arith.remui %[[APPLY_3]], %[[CONSTANT_0]] : index
+// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[APPLY_3]] : index to i64
+// CHECK: memref.store %[[INDEX_CAST_0]], %[[ALLOC]]{{\[}}%[[REMUI_0]]] : memref<?xi64>
+// CHECK: }
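+
+// Coalescing remains legal because the iter_arg is pass-through: the inner
+// loop yields %arg2 unchanged on every iteration, so the memref threads
+// directly through the collapsed loop (the store targets %[[ALLOC]]).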
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 817614b..2e80102 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
// CHECK: "test.some_use"(%[[c5]])
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
scf.for %iv = %c0 to %ub step %c4 {
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
- %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+ %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 2fe0995..3ad1530 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2958,6 +2958,19 @@ func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 {
return %hi : i32
}
+// CHECK-LABEL: @truncIShrSIExactToTrunciShrUIExact
+// CHECK-SAME: (%[[A:.+]]: i64)
+// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64
+// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] exact : i64
+// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32
+// CHECK-NEXT: return %[[TRU]] : i32
+func.func @truncIShrSIExactToTrunciShrUIExact(%a: i64) -> i32 {
+ %c32 = arith.constant 32: i64
+ %sh = arith.shrsi %a, %c32 exact : i64
+ %hi = arith.trunci %sh: i64 to i32
+ return %hi : i32
+}
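+
+// The shrsi -> shrui rewrite is sound because trunci discards exactly the
+// bits in which the two shifts differ, and the CHECK lines confirm that the
+// `exact` flag (no set bits shifted out) carries over to the new shrui.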
+
// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1
// CHECK: arith.shrsi
func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 1e656e8..58eadfd 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -151,6 +151,12 @@ func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_divui_exact
+func.func @test_divui_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.divui %arg0, %arg1 exact : i64
+ return %0 : i64
+}
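+
+// `exact` mirrors LLVM's flag of the same name: the division leaves no
+// remainder (and the shift variants below shift out no set bits). These
+// tests only exercise the parse/print round-trip of the keyword.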
+
// CHECK-LABEL: test_divui_tensor
func.func @test_divui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.divui %arg0, %arg1 : tensor<8x8xi64>
@@ -175,6 +181,12 @@ func.func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_divsi_exact
+func.func @test_divsi_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.divsi %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_divsi_tensor
func.func @test_divsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.divsi %arg0, %arg1 : tensor<8x8xi64>
@@ -391,6 +403,12 @@ func.func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_shrui_exact
+func.func @test_shrui_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.shrui %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_shrui_tensor
func.func @test_shrui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.shrui %arg0, %arg1 : tensor<8x8xi64>
@@ -415,6 +433,12 @@ func.func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 {
return %0 : i64
}
+// CHECK-LABEL: test_shrsi_exact
+func.func @test_shrsi_exact(%arg0 : i64, %arg1 : i64) -> i64 {
+ %0 = arith.shrsi %arg0, %arg1 exact : i64
+ return %0 : i64
+}
+
// CHECK-LABEL: test_shrsi_tensor
func.func @test_shrsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%0 = arith.shrsi %arg0, %arg1 : tensor<8x8xi64>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 8249d59..3929f5b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
// -----
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
-// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
-func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_all_empty_tensors
+func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
- // CHECK: memref.alloc
- // CHECK: memref.alloc
- // CHECK: memref.alloc
+ // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+ // CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
- // CHECK: memref.copy
+ // CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// -----
-// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
-func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
+func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
- // CHECK: memref.alloc
// CHECK-NOT: memref.alloc
- %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+ %concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
- %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+ // CHECK-NOT: memref.copy
+ %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
+// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+// CHECK-NOT: memref.alloc
+// CHECK-NOT: memref.copy
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
+// CHECK-ELIM: linalg.fill
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
- // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
- // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
- // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
- // CHECK-NOT: memref.copy
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}
+
+// -----
+
+// Test that dependent pure operations are moved before the
+// insertion point to enable empty tensor elimination.
+
+// CHECK-LABEL: func.func @move_dependent_arith_op(
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-NOT: memref.alloc
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK: linalg.fill {{.*}} outs(%[[SV]]
+// CHECK: return %[[ARG0]]
+// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
+// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
+// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
+// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
+// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
+func.func @move_dependent_arith_op(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %arg1: index, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %c5 = arith.constant 5 : index
+ %offset = arith.addi %arg1, %c5 : index
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
+
+// -----
+
+// Test that side-effecting operations are not moved, preventing empty
+// tensor elimination.
+
+// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: memref.load
+// CHECK: memref.subview
+// CHECK: memref.copy
+// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: memref.load
+// CHECK-ELIM: tensor.insert_slice
+func.func @side_effecting_op_blocks_movement(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %mem: memref<index>, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %offset = memref.load %mem[] : memref<index>
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 2c8807b..9884b04 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
arith.constant {bufferization.manual_deallocation} 0 : index
}
+
+// -----
+
+func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x3xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %t = bufferization.to_tensor %b
+ : memref<1x2x3xf32> to tensor<1x2x3x4xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %t = bufferization.to_tensor %b
+ : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{element types do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
+ return
+}
+
+// -----
+
+func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{element types do not match}}
+ %t2 = bufferization.to_tensor %b
+ : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index fc6df4a..b0db1bb 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
bufferization.dealloc
return %0#0, %0#1 : i1, i1
}
+
+// CHECK: func.func @test_builtin_custom_builtin_type_conversion
+// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
+func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
+ -> tensor<42xf32> {
+ // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
+ // CHECK-SAME: to !test.test_memref<[42], f32>
+ %buffer = bufferization.to_buffer %t
+ : tensor<42xf32> to !test.test_memref<[42], f32>
+
+ // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
+ // CHECK-SAME: to tensor<42xf32>
+ %tensor = bufferization.to_tensor %buffer
+ : !test.test_memref<[42], f32> to tensor<42xf32>
+
+ // CHECK: return %[[tensor]]
+ return %tensor : tensor<42xf32>
+}
+
+// CHECK: func.func @test_custom_builtin_custom_type_conversion
+// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
+// CHECK-SAME: -> !test.test_tensor<[42], f32>
+func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
+ -> !test.test_tensor<[42], f32> {
+ // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
+ // CHECK-SAME: to memref<42xf32>
+ %buffer = bufferization.to_buffer %t
+ : !test.test_tensor<[42], f32> to memref<42xf32>
+
+ // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
+ // CHECK-SAME: to !test.test_tensor<[42], f32>
+ %tensor = bufferization.to_tensor %buffer
+ : memref<42xf32> to !test.test_tensor<[42], f32>
+
+ // CHECK: return %[[tensor]]
+ return %tensor : !test.test_tensor<[42], f32>
+}
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 17f7d28..21a1678 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -634,3 +634,25 @@ func.func @unsimplified_cycle_2(%c : i1) {
^bb7:
cf.br ^bb6
}
+
+// CHECK-LABEL: @drop_unreachable_branch_1
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: return
+func.func @drop_unreachable_branch_1(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ return
+^bb2:
+ ub.unreachable
+}
+
+// CHECK-LABEL: @drop_unreachable_branch_2
+// CHECK-NEXT: ub.unreachable
+func.func @drop_unreachable_branch_2(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ ub.unreachable
+^bb2:
+ ub.unreachable
+}
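+
+// In both cases the cond_br folds: a successor that immediately reaches
+// ub.unreachable can never be taken, so the branch is replaced by the live
+// successor, or by unreachable itself when neither side is live.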
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 5f594fb..d1601be 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -876,3 +876,57 @@ func.func @test_do(%arg0 : !emitc.ptr<i32>) {
return
}
+
+// -----
+
+func.func @test_for_none_block_argument(%arg0: index) {
+ // expected-error@+1 {{expected body to have a single block argument for the induction variable}}
+ "emitc.for"(%arg0, %arg0, %arg0) (
+ {
+ emitc.yield
+ }
+ ) : (index, index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @test_for_more_than_one_block_argument(%arg0: index) {
+ // expected-error@+1 {{expected body to have a single block argument for the induction variable}}
+ "emitc.for"(%arg0, %arg0, %arg0) (
+ {
+ ^bb0(%i0 : index, %i1 : index):
+ emitc.yield
+ }
+ ) : (index, index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @test_for_unmatch_type(%arg0: index) {
+ // expected-error@+1 {{expected induction variable to be same type as bounds}}
+ "emitc.for"(%arg0, %arg0, %arg0) (
+ {
+ ^bb0(%i0 : f32):
+ emitc.yield
+ }
+ ) : (index, index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @address_of(%arg0: !emitc.lvalue<i32>) {
+ // expected-error @+1 {{failed to verify that input and result reference the same type}}
+ %1 = "emitc.address_of"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i8>
+ return
+}
+
+// -----
+
+func.func @dereference(%arg0: !emitc.ptr<i32>) {
+ // expected-error @+1 {{failed to verify that input and result reference the same type}}
+ %1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> !emitc.lvalue<i8>
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 1259748..b2c8b84 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -355,3 +355,13 @@ func.func @do(%arg0 : !emitc.ptr<i32>) {
return
}
+
+func.func @address_of(%arg0: !emitc.lvalue<i32>) {
+ %1 = emitc.address_of %arg0 : !emitc.lvalue<i32>
+ return
+}
+
+func.func @dereference(%arg0: !emitc.ptr<i32>) {
+ %1 = emitc.dereference %arg0 : !emitc.ptr<i32>
+ return
+}
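+
+// address_of produces an !emitc.ptr to the referenced lvalue and dereference
+// is its inverse; both verify that the pointee and lvalue types agree, which
+// the invalid_ops.mlir cases earlier in this patch exercise.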
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 1f8da78..bc04e8f 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --split-input-file --duplicate-function-elimination | \
-// RUN: FileCheck %s
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-2,CHECK-3
func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 35381da..26bcf94 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){
func.func @mmamatrix_invalid_element_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
- // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
+ // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
return
}
@@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){
// -----
func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
- // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
+ // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}}
%0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}
diff --git a/mlir/test/Dialect/IRDL/variadics.mlir b/mlir/test/Dialect/IRDL/variadics.mlir
index a8871fc..873f248 100644
--- a/mlir/test/Dialect/IRDL/variadics.mlir
+++ b/mlir/test/Dialect/IRDL/variadics.mlir
@@ -133,7 +133,7 @@ func.func @testOptOperandFail(%x: i16) {
// Check that an operation with multiple variadics expects the segment size
// attribute
func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) {
- // expected-error@+1 {{'operand_segment_sizes' attribute is expected but not provided}}
+ // expected-error@+1 {{'operandSegmentSizes' attribute is expected but not provided}}
"testvar.var_and_opt_operand"(%x, %x, %z) : (i16, i16, i64) -> ()
return
}
@@ -143,8 +143,8 @@ func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) {
// Check that an operation with multiple variadics expects the segment size
// attribute of the right type
func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) {
- // expected-error@+1 {{'operand_segment_sizes' attribute is expected to be a dense i32 array}}
- "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = i32} : (i16, i16, i64) -> ()
+ // expected-error@+1 {{'operandSegmentSizes' attribute is expected to be a dense i32 array}}
+ "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = i32} : (i16, i16, i64) -> ()
return
}
@@ -153,12 +153,12 @@ func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) {
// Check that an operation with multiple variadics with the right segment size
// verifies.
func.func @testMultOperands(%x: i16, %y: i32, %z: i64) {
- "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> ()
- // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> ()
- "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> ()
- // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> ()
- "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> ()
- // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> ()
+ "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> ()
+ // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> ()
+ "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> ()
+ // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> ()
+ "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> ()
+ // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> ()
return
}
@@ -166,8 +166,8 @@ func.func @testMultOperands(%x: i16, %y: i32, %z: i64) {
// Check that the segment sizes expects non-negative values
func.func @testMultOperandsSegmentNegative() {
- // expected-error@+1 {{'operand_segment_sizes' attribute for specifying operand segments must have non-negative values}}
- "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 2, -1, 1>} : () -> ()
+ // expected-error@+1 {{'operandSegmentSizes' attribute for specifying operand segments must have non-negative values}}
+ "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 2, -1, 1>} : () -> ()
return
}
@@ -175,8 +175,8 @@ func.func @testMultOperandsSegmentNegative() {
// Check that the segment sizes expects 1 for single values
func.func @testMultOperandsSegmentWrongSingle() {
- // expected-error@+1 {{element 2 in 'operand_segment_sizes' attribute must be equal to 1}}
- "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 0, 0>} : () -> ()
+ // expected-error@+1 {{element 2 in 'operandSegmentSizes' attribute must be equal to 1}}
+ "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 0, 0>} : () -> ()
return
}
@@ -184,8 +184,8 @@ func.func @testMultOperandsSegmentWrongSingle() {
// Check that the segment sizes expects not more than 1 for optional values
func.func @testMultOperandsSegmentWrongOptional() {
- // expected-error@+1 {{element 1 in 'operand_segment_sizes' attribute must be equal to 0 or 1}}
- "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 2, 0>} : () -> ()
+ // expected-error@+1 {{element 1 in 'operandSegmentSizes' attribute must be equal to 0 or 1}}
+ "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 2, 0>} : () -> ()
return
}
@@ -193,8 +193,8 @@ func.func @testMultOperandsSegmentWrongOptional() {
// Check that the sum of the segment sizes should be equal to the number of operands
func.func @testMultOperandsSegmentWrongOptional(%y: i32, %z: i64) {
- // expected-error@+1 {{sum of elements in 'operand_segment_sizes' attribute must be equal to the number of operands}}
- "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 0, 1>} : (i32, i64) -> ()
+ // expected-error@+1 {{sum of elements in 'operandSegmentSizes' attribute must be equal to the number of operands}}
+ "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 0, 1>} : (i32, i64) -> ()
return
}
@@ -334,7 +334,7 @@ func.func @testOptResultFail() {
// Check that an operation with multiple variadics expects the segment size
// attribute
func.func @testMultResultsMissingSegment() {
- // expected-error@+1 {{'result_segment_sizes' attribute is expected but not provided}}
+ // expected-error@+1 {{'resultSegmentSizes' attribute is expected but not provided}}
"testvar.var_and_opt_result"() : () -> (i16, i16, i64)
return
}
@@ -344,8 +344,8 @@ func.func @testMultResultsMissingSegment() {
// Check that an operation with multiple variadics expects the segment size
// attribute of the right type
func.func @testMultResultsWrongSegmentType() {
- // expected-error@+1 {{'result_segment_sizes' attribute is expected to be a dense i32 array}}
- "testvar.var_and_opt_result"() {result_segment_sizes = i32} : () -> (i16, i16, i64)
+ // expected-error@+1 {{'resultSegmentSizes' attribute is expected to be a dense i32 array}}
+ "testvar.var_and_opt_result"() {resultSegmentSizes = i32} : () -> (i16, i16, i64)
return
}
@@ -354,12 +354,12 @@ func.func @testMultResultsWrongSegmentType() {
// Check that an operation with multiple variadics with the right segment size
// verifies.
func.func @testMultResults() {
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64)
- // CHECK: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64)
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64)
- // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64)
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64)
- // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64)
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64)
+ // CHECK: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64)
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64)
+ // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64)
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64)
+ // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64)
return
}
@@ -367,8 +367,8 @@ func.func @testMultResults() {
// Check that the segment sizes expects non-negative values
func.func @testMultResultsSegmentNegative() {
- // expected-error@+1 {{'result_segment_sizes' attribute for specifying result segments must have non-negative values}}
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, -1, 1>} : () -> ()
+ // expected-error@+1 {{'resultSegmentSizes' attribute for specifying result segments must have non-negative values}}
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, -1, 1>} : () -> ()
return
}
@@ -376,8 +376,8 @@ func.func @testMultResultsSegmentNegative() {
// Check that the segment sizes expects 1 for single values
func.func @testMultResultsSegmentWrongSingle() {
- // expected-error@+1 {{element 2 in 'result_segment_sizes' attribute must be equal to 1}}
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 0>} : () -> ()
+ // expected-error@+1 {{element 2 in 'resultSegmentSizes' attribute must be equal to 1}}
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 0>} : () -> ()
return
}
@@ -385,8 +385,8 @@ func.func @testMultResultsSegmentWrongSingle() {
// Check that the segment sizes expects not more than 1 for optional values
func.func @testMultResultsSegmentWrongOptional() {
- // expected-error@+1 {{element 1 in 'result_segment_sizes' attribute must be equal to 0 or 1}}
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 2, 0>} : () -> ()
+ // expected-error@+1 {{element 1 in 'resultSegmentSizes' attribute must be equal to 0 or 1}}
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 2, 0>} : () -> ()
return
}
@@ -394,7 +394,7 @@ func.func @testMultResultsSegmentWrongOptional() {
// Check that the sum of the segment sizes should be equal to the number of results
func.func @testMultResultsSegmentWrongOptional() {
- // expected-error@+1 {{sum of elements in 'result_segment_sizes' attribute must be equal to the number of results}}
- "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 1>} : () -> (i32, i64)
+ // expected-error@+1 {{sum of elements in 'resultSegmentSizes' attribute must be equal to the number of results}}
+ "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 1>} : () -> (i32, i64)
return
}
diff --git a/mlir/test/Dialect/Index/inliner-interface.mlir b/mlir/test/Dialect/Index/inliner-interface.mlir
new file mode 100644
index 0000000..4c3d106
--- /dev/null
+++ b/mlir/test/Dialect/Index/inliner-interface.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -inline | FileCheck %s
+
+// CHECK-LABEL: @main
+func.func @main(%arg0: i32) -> index {
+ // CHECK-NOT: call
+ // CHECK: index.castu
+ %0 = call @f(%arg0) : (i32) -> index
+ return %0 : index
+}
+
+// CHECK-LABEL: @f
+func.func @f(%arg0: i32) -> index {
+ %0 = index.castu %arg0 : i32 to index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir
index dfbf992..ffeb871 100644
--- a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir
+++ b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir
@@ -141,3 +141,22 @@ module {
llvm.func @func_callsiteloc() loc(callsite("foo" at "mysource.cc":10:8))
} loc(unknown)
+// -----
+
+// CHECK-LABEL: llvm.func @func_cross_file_op()
+// CHECK: #di_file = #llvm.di_file<"<unknown>" in "">
+// CHECK: #di_file1 = #llvm.di_file<"caller.py" in "">
+// CHECK: #di_file2 = #llvm.di_file<"callee.py" in "">
+// CHECK: #di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_normal>
+// CHECK: #di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, compileUnit = #di_compile_unit, scope = #di_file1, name = "func_cross_file_op", linkageName = "func_cross_file_op", file = #di_file1, line = 5, scopeLine = 5, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type>
+// CHECK: #di_lexical_block_file = #llvm.di_lexical_block_file<scope = #di_subprogram, file = #di_file2, discriminator = 0>
+
+#loc = loc("caller.py":5:1)
+#loc1 = loc("callee.py":10:5)
+
+module {
+ llvm.func @func_cross_file_op() {
+ llvm.return loc(#loc1)
+ } loc(#loc)
+} loc(unknown)
+
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index cec4586..094313c 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -210,8 +210,8 @@ module {
}
// CHECK-LABEL: llvm.func @memory_attr
- // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} {
- llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} {
+ // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
+ llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index 8e292f4..9a77c5e 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -422,7 +422,7 @@ llvm.func @test_byval(%ptr : !llvm.ptr) {
// -----
-llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} {
+llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.return
}
@@ -436,7 +436,7 @@ llvm.func @test_byval_read_only(%ptr : !llvm.ptr) {
// -----
-llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite>} {
+llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.return
}
@@ -451,7 +451,7 @@ llvm.func @test_byval_write_only(%ptr : !llvm.ptr) {
// -----
-llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.return
}
@@ -472,7 +472,7 @@ llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr
llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
-llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
llvm.return
}
@@ -496,7 +496,7 @@ module attributes {
llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
-llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
llvm.return
}
@@ -524,7 +524,7 @@ module attributes {
llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
-llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
llvm.return
}
@@ -550,7 +550,7 @@ llvm.func @test_alignment_exceeded_anyway() {
llvm.mlir.global private @unaligned_global(42 : i64) : i64
llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64
-llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} {
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir
new file mode 100644
index 0000000..bdc98ed
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate %s -mlir-to-llvmir | FileCheck %s
+// CHECK: !llvm.module.flags = !{![[CG_FLAG:[0-9]+]], ![[DBG_FLAG:[0-9]+]]}
+// CHECK: ![[CG_FLAG]] = !{i32 5, !"CG Profile", ![[CG_LIST:[0-9]+]]}
+// CHECK: ![[CG_LIST]] = distinct !{![[CG_ENTRY:[0-9]+]], ![[CG_ENTRY]], ![[CG_ENTRY]]}
+// CHECK: ![[CG_ENTRY]] = !{null, null, i64 222}
+// CHECK: ![[DBG_FLAG]] = !{i32 2, !"Debug Info Version", i32 3}
+
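+// Neither @from nor @to is defined in this module, so the symbol references
+// (including the omitted `to` in the second entry) lower to null, and all
+// three entries uniquify to the same !{null, null, i64 222} node; hence the
+// single CG_ENTRY captured above.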
+module {
+ llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
+ #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
+ #llvm.cgprofile_entry<from = @from, count = 222>,
+ #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
+ ]>]
+}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index aaf9f80..49b6342 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() {
// -----
func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
- // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+ // expected-error@+1 {{expected return type to be a two-element struct}}
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
}
// -----
func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
- // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+ // expected-error@+1 {{expected return type to be a two-element struct}}
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
}
// -----
func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
- // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+ // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead}}
%0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
new file mode 100644
index 0000000..ff3e91b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
@@ -0,0 +1,221 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for sparse MMA (mma.sp.sync) operations with KIND variants.
+// The kind::f8f6f4 variant was introduced in PTX ISA 8.7 for sm_90+ architectures.
+//
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// KIND::F8F6F4 enables:
+// - Additional narrow floating-point types: e3m2 and e2m3 (FP6), e2m1 (FP4)
+// - F16 accumulator for m16n8k64 FP8 operations
+// - Mixed-precision computations across the FP8/FP6/FP4 types
+//
+// Requirements:
+// - ONLY works with ordered metadata (sp::ordered_metadata)
+// - ONLY for shape m16n8k64
+// - ONLY for the FP8/FP6/FP4 types listed above (not integers or other floats)
+
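+// A sketch of a combination the op verifier is expected to reject, based on
+// the requirements above (assumed behavior, not exercised in this file):
+//   %r = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+//          sparseMetadata[%meta] selector[%sel]
+//          {kind = #nvvm.mma_kind<f8f6f4>, orderedMetadata,
+//           multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+//           multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+//           shape = #nvvm.shape<m = 16, n = 8, k = 32>}  // wrong shape for kind::f8f6f4
+//          : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+// Dropping orderedMetadata from an otherwise valid kind::f8f6f4 op should
+// fail for the same reason.
+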
+// =============================================================================
+// FP8 e4m3 Sparse MMA with KIND (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 e5m2 Sparse MMA with KIND (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP6 e3m2 Sparse MMA with KIND (m16n8k64)
+// NOTE: e3m2 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP6 e2m3 Sparse MMA with KIND (m16n8k64)
+// NOTE: e2m3 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP4 e2m1 Sparse MMA with KIND (m16n8k64)
+// NOTE: e2m1 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
new file mode 100644
index 0000000..a4e2812
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
@@ -0,0 +1,411 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for sparse MMA (mma.sp.sync) operations with ORDERED metadata.
+// The ordered metadata variant was introduced in PTX ISA 8.5 for sm_90+ architectures.
+//
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// Ordered metadata provides an alternative metadata ordering for 2:4 structured sparsity
+// that can offer better performance on newer architectures.
+
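+// Syntactically, these ops differ from the default-metadata forms in
+// nvvm-mma-sp.mlir only by the `orderedMetadata` unit attribute, e.g.:
+//   {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+// versus
+//   {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+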
+// =============================================================================
+// F16 Sparse MMA Operations with Ordered Metadata (m16n8k16)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16
+func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// F16 Sparse MMA Operations with Ordered Metadata (m16n8k32)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16
+func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// BF16 Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// TF32 Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32
+func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// Integer (s8) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite
+func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Integer (u8) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (s4) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (u4) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e4m3) Sparse MMA Operations with Ordered Metadata
+// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e5m2) Sparse MMA Operations with Ordered Metadata
+// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
new file mode 100644
index 0000000..e7122aa
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
@@ -0,0 +1,390 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for all sparse MMA (mma.sp.sync) operations in the NVVM dialect.
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// Sparse MMA operations follow 2:4 structured sparsity where 2 out of every 4 elements
+// in the A operand are non-zero. The A operand is provided in compressed form,
+// and sparseMetadata provides the sparsity indices.
+//
+// NOTE: These tests use the default (standard) metadata ordering.
+// For ordered metadata tests (PTX ISA 8.5+, sm_90+), see nvvm-mma-sp-ordered.mlir.
+
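+// Reading the operands (an interpretation of the PTX ISA reference above,
+// not something these tests verify): with 2:4 sparsity on 16-bit element
+// types, each 4-element group of a row of A keeps at most 2 non-zero
+// elements; the A fragments carry only the kept elements, sparseMetadata
+// packs a 2-bit in-group position index per kept element, and selector
+// picks which threads' metadata contributions are used.
+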
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k16)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16
+func.func @nvvm_mma_sp_m16n8k16_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32
+func.func @nvvm_mma_sp_m16n8k16_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k32)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16
+func.func @nvvm_mma_sp_m16n8k32_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32
+func.func @nvvm_mma_sp_m16n8k32_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// BF16 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32
+func.func @nvvm_mma_sp_m16n8k16_bf16_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32
+func.func @nvvm_mma_sp_m16n8k32_bf16_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// TF32 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32
+func.func @nvvm_mma_sp_m16n8k8_tf32_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32
+func.func @nvvm_mma_sp_m16n8k16_tf32_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// Integer (s8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32
+func.func @nvvm_mma_sp_m16n8k32_s8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite
+func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32
+func.func @nvvm_mma_sp_m16n8k64_s8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Integer (u8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32
+func.func @nvvm_mma_sp_m16n8k32_u8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32
+func.func @nvvm_mma_sp_m16n8k64_u8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (s4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32
+func.func @nvvm_mma_sp_m16n8k64_s4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32
+func.func @nvvm_mma_sp_m16n8k128_s4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (u4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32
+func.func @nvvm_mma_sp_m16n8k64_u4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32
+func.func @nvvm_mma_sp_m16n8k128_u4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e4m3) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e5m2) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
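+
+// Note: in the mma.sp.sync tests above, sparseMetadata carries the 2:4
+// structured-sparsity indices for operand A, and selector picks which threads
+// contribute that metadata (following the PTX sparse-MMA convention).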
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir
new file mode 100644
index 0000000..c2cfa76
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir
@@ -0,0 +1,11 @@
+// RUN: not mlir-opt %s 2>&1 | FileCheck %s
+// CHECK: 'nvvm.tcgen05.alloc' op is not supported on sm_90
+
+module {
+ gpu.module @mod [#nvvm.target<chip = "sm_90">] {
+ func.func @tcgen05_alloc(%arg0: !llvm.ptr<7>, %arg1: i32) {
+ nvvm.tcgen05.alloc %arg0, %arg1 : !llvm.ptr<7>, i32
+ return
+ }
+ }
+}
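+
+// For contrast, a sketch of a variant that should verify cleanly, assuming
+// tcgen05 ops are available on sm_100a-class targets:
+//
+// module {
+//   gpu.module @mod_ok [#nvvm.target<chip = "sm_100a">] {
+//     func.func @tcgen05_alloc_ok(%arg0: !llvm.ptr<7>, %arg1: i32) {
+//       nvvm.tcgen05.alloc %arg0, %arg1 : !llvm.ptr<7>, i32
+//       return
+//     }
+//   }
+// }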
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 2505e56..579f0ac 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -92,13 +92,6 @@ func.func @llvm_nvvm_cluster_wait() {
llvm.return
}
-// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster
-func.func @llvm_nvvm_fence_sc_cluster() {
- // CHECK: nvvm.fence.sc.cluster
- nvvm.fence.sc.cluster
- llvm.return
-}
-
// CHECK-LABEL: @nvvm_shfl
func.func @nvvm_shfl(
%arg0 : i32, %arg1 : i32, %arg2 : i32,
@@ -445,8 +438,8 @@ llvm.func private @mbarrier_arrive(%barrier: !llvm.ptr) {
}
llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
- // CHECK: nvvm.mbarrier.arrive.shared %{{.*}} : !llvm.ptr<3>
- %0 = nvvm.mbarrier.arrive.shared %barrier : !llvm.ptr<3> -> i64
+ // CHECK: nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<3>
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64
llvm.return
}
@@ -459,21 +452,8 @@ llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
%count = nvvm.read.ptx.sreg.ntid.x : i32
- // CHECK: nvvm.mbarrier.arrive.nocomplete.shared %{{.*}} : !llvm.ptr<3>
- %0 = nvvm.mbarrier.arrive.nocomplete.shared %barrier, %count : !llvm.ptr<3>, i32 -> i64
- llvm.return
-}
-
-llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
- // CHECK: nvvm.mbarrier.test.wait %{{.*}}
- %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
- llvm.return %isComplete : i1
-}
-
-llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
- %count = nvvm.read.ptx.sreg.ntid.x : i32
- // CHECK: nvvm.mbarrier.test.wait.shared %{{.*}}
- %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
+ // CHECK: nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr<3>
+ %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
new file mode 100644
index 0000000..506b81e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Test invalid target architecture (sm_100 instead of sm_100a)
+gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
+ func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+ // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ return
+ }
+}
+
+// -----
+
+// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
+llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E3M4)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+llvm.func @invalid_dst_type_f8x4_e8m0(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E8M0FNU)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination types for f6x4 (should only accept f6E2M3FN, f6E3M2FN)
+llvm.func @invalid_dst_type_f6x4_f8(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // expected-error@+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}}
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test invalid destination type for f4x4 (should only accept f4E2M1FN)
+llvm.func @invalid_dst_type_f4x4_f6(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // expected-error@+1 {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}}
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f6E2M3FN)
+ llvm.return %res : i16
+}
+
+// -----
+
+// Test that the stochastic (rs) rounding mode is rejected by ops that do not support it.
+llvm.func @convert_float_to_tf32_rs_not_supported(%src : f32) -> i32 {
+ // expected-error @below {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}}
+ %res = nvvm.convert.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rs>}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index e703600..40084bc 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -14,18 +14,24 @@ func.func @rocdl_special_regs() -> i32 {
%4 = rocdl.workgroup.id.y : i32
// CHECK: rocdl.workgroup.id.z : i32
%5 = rocdl.workgroup.id.z : i32
+ // CHECK: rocdl.cluster.id.x : i32
+ %6 = rocdl.cluster.id.x : i32
+ // CHECK: rocdl.cluster.id.y : i32
+ %7 = rocdl.cluster.id.y : i32
+ // CHECK: rocdl.cluster.id.z : i32
+ %8 = rocdl.cluster.id.z : i32
// CHECK: rocdl.workgroup.dim.x : i32
- %6 = rocdl.workgroup.dim.x : i32
+ %9 = rocdl.workgroup.dim.x : i32
// CHECK: rocdl.workgroup.dim.y : i32
- %7 = rocdl.workgroup.dim.y : i32
+ %10 = rocdl.workgroup.dim.y : i32
// CHECK: rocdl.workgroup.dim.z : i32
- %8 = rocdl.workgroup.dim.z : i32
+ %11 = rocdl.workgroup.dim.z : i32
// CHECK: rocdl.grid.dim.x : i32
- %9 = rocdl.grid.dim.x : i32
+ %12 = rocdl.grid.dim.x : i32
// CHECK: rocdl.grid.dim.y : i32
- %10 = rocdl.grid.dim.y : i32
+ %13 = rocdl.grid.dim.y : i32
// CHECK: rocdl.grid.dim.z : i32
- %11 = rocdl.grid.dim.z : i32
+ %14 = rocdl.grid.dim.z : i32
llvm.return %0 : i32
}
@@ -43,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4
llvm.return %0 : vector<4xf16>
}
+func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
+ // CHECK-LABEL: rocdl.math.ops
+ // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16
+ %tanh0 = rocdl.tanh %a f32 -> f32
+ %tanh1 = rocdl.tanh %b f16 -> f16
+ %tanh2 = rocdl.tanh %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16
+ %sin0 = rocdl.sin %a f32 -> f32
+ %sin1 = rocdl.sin %b f16 -> f16
+ %sin2 = rocdl.sin %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16
+ %cos0 = rocdl.cos %a f32 -> f32
+ %cos1 = rocdl.cos %b f16 -> f16
+ %cos2 = rocdl.cos %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16
+ %rcp0 = rocdl.rcp %a f32 -> f32
+ %rcp1 = rocdl.rcp %b f16 -> f16
+ %rcp2 = rocdl.rcp %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16
+ %exp2_0 = rocdl.exp2 %a f32 -> f32
+ %exp2_1 = rocdl.exp2 %b f16 -> f16
+ %exp2_2 = rocdl.exp2 %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16
+ %log0 = rocdl.log %a f32 -> f32
+ %log1 = rocdl.log %b f16 -> f16
+ %log2 = rocdl.log %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16
+ %sqrt0 = rocdl.sqrt %a f32 -> f32
+ %sqrt1 = rocdl.sqrt %b f16 -> f16
+ %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+ llvm.return
+}
+
func.func @rocdl.barrier() {
// CHECK: rocdl.barrier
rocdl.barrier
@@ -650,6 +709,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: @rocdl.load.tr.ops
+ // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_PTR:.+]]: !llvm.ptr<3>)
+ // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
+ // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
+ // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16>
+ // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16>
+ // CHECK: rocdl.ds.load.tr4.b64 %[[DS_PTR]] : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.load.tr8.b64 %[[DS_PTR]] : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.load.tr6.b96 %[[DS_PTR]] : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_PTR]] : !llvm.ptr<3> -> vector<8xi16>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_PTR]] : !llvm.ptr<3> -> vector<8xf16>
+ // CHECK: rocdl.ds.load.tr16.b128 %[[DS_PTR]] : !llvm.ptr<3> -> vector<8xbf16>
+ // CHECK: llvm.return
+
+ rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
+ rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
+ rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>
+
+ rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
+ rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
+ rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
+ llvm.return
+}
+
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
// CHECK-LABEL @rocdl.load.to.lds
//CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>
@@ -664,6 +756,33 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ // CHECK-LABEL: @rocdl.global.load.async.to.lds
+ // CHECK: rocdl.global.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0
+ // CHECK: rocdl.global.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0
+ // CHECK: rocdl.global.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0
+ // CHECK: rocdl.global.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0
+ rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ // CHECK-LABEL: @rocdl.cluster.load.async.to.lds
+ // CHECK: rocdl.cluster.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0, 0
+ // CHECK: rocdl.cluster.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0, 0
+ // CHECK: rocdl.cluster.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0, 0
+ // CHECK: rocdl.cluster.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0, 0
+ rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.cluster.load.async.to.lds.b32 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ llvm.return
+}
+
// CHECK-LABEL @rocdl.tensor.load.to.lds
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
@@ -1037,6 +1156,13 @@ llvm.func @rocdl.s.get.barrier.state() {
llvm.return
}
+llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: rocdl.s.get.named.barrier.state
+ // CHECK: rocdl.s.get.named.barrier.state %[[PTR:.+]]
+ %0 = rocdl.s.get.named.barrier.state %ptr : i32
+ llvm.return
+}
+
llvm.func @rocdl.s.wait.dscnt() {
// CHECK-LABEL: rocdl.s.wait.dscnt
// CHECK: rocdl.s.wait.dscnt 0
@@ -1292,6 +1418,26 @@ llvm.func @rocdl.cvt.scalef32.sr.pk16(%v16xf32: vector<16xf32>,
// -----
+// CHECK-LABEL: @rocdl_wmma_scale_ops
+llvm.func @rocdl_wmma_scale_ops(%a_f8: vector<8xi32>, %a_f4: vector<4xi32>, %c_f32: vector<4xf32>, %c16_f32: vector<16xf32>,
+ %scale_i32: i32, %scale_i64: i64) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r0 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i32, %scale_i32 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+ %r1 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i64, %scale_i64 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r2 = rocdl.wmma.scale.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i32, %scale_i32 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
+ %r3 = rocdl.wmma.scale16.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i64, %scale_i64 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
+
+ llvm.return
+}
+
+// -----
+
// expected-error@below {{attribute attached to unexpected op}}
func.func private @expected_llvm_func() attributes { rocdl.kernel }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 00e763a..afbf47e 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -122,8 +122,8 @@ func.func @ops(%arg0: i32, %arg1: f32,
// CHECK: llvm.call @baz() {will_return} : () -> ()
llvm.call @baz() {will_return} : () -> ()
-// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> ()
- llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> ()
+// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
+ llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
// Terminator operations and their successors.
//
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
new file mode 100644
index 0000000..4b2d42a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -0,0 +1,214 @@
+// The following tests exercise examples of linalg convolution named ops lowered to
+// linalg.generic and then lifted back up to the corresponding named op.
+// NOTE: Most tests in this file use dynamic shapes, since the underlying
+// transformations do not modify shapes. The one static-shape exception is
+// included as a smoke test.
+
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
+
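+// For reference, a rough sketch of the intermediate form produced by
+// -linalg-generalize-named-ops for the conv_1d case below (map names are
+// illustrative; the body is a plain multiply-accumulate):
+//
+//   #in  = affine_map<(w, kw) -> (w + kw)>
+//   #flt = affine_map<(w, kw) -> (kw)>
+//   #out = affine_map<(w, kw) -> (w)>
+//   linalg.generic {indexing_maps = [#in, #flt, #out],
+//                   iterator_types = ["parallel", "reduction"]}
+//     ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
+//     outs(%out : tensor<?xf32>) {
+//   ^bb0(%i: f32, %f: f32, %o: f32):
+//     %m = arith.mulf %i, %f : f32
+//     %a = arith.addf %o, %m : f32
+//     linalg.yield %a : f32
+//   } -> tensor<?xf32>
+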
+// -----------------------------
+// Convolution ops.
+// -----------------------------
+func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.conv_1d
+ ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
+ outs(%out : tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+// CHECK: @conv_1d
+// CHECK: linalg.conv_1d
+
+// -----
+
+func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv_1d_nwc_wcf
+ {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @conv_1d_nwc_wcf
+// CHECK: linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv_1d_ncw_fcw
+ {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @conv_1d_ncw_fcw
+// CHECK: linalg.conv_1d_ncw_fcw
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.conv_2d
+ ins(%in, %filter : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK: @conv_2d
+// CHECK: linalg.conv_2d
+
+// -----
+
+func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.conv_3d
+ ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @conv_3d
+// CHECK: linalg.conv_3d
+
+// -----
+
+// -----------------------------
+// Depthwise Convolution ops.
+// -----------------------------
+func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.depthwise_conv_1d_ncw_cw
+ {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?xf32>)
+ outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @depthwise_conv_1d_ncw_cw
+// CHECK: linalg.depthwise_conv_1d_ncw_cw
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> {
+ %0 = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: tensor<1x25x8xi8>, tensor<3x8xi8>)
+ outs (%output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32>
+ return %0 : tensor<1x10x8xi32>
+}
+// CHECK: @depthwise_conv_1d_nwc_wc_static
+// CHECK: linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_1d_nwc_wcm
+ {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_1d_nwc_wcm
+// CHECK: linalg.depthwise_conv_1d_nwc_wcm
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tensor<?x?x?xf16>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_2d_nchw_chw
+ {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf16>, tensor<?x?x?xf16>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nchw_chw
+// CHECK: linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm
+ {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+
+// -----
+
+// -----------------------------
+// Pooling ops.
+// -----------------------------
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_max
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_max
+// CHECK: linalg.pooling_nhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_min
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_sum
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_sum
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?xi8>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.pooling_nhwc_max_unsigned
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi8>, tensor<?x?xi8>)
+ outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_max_unsigned
+// CHECK: linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.pooling_nhwc_min_unsigned
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+ outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_min_unsigned_integer
+// CHECK: linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_min_unsigned
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min_unsigned_float
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21..77c7d7d 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x
// -----
+func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ %pad_val = arith.index_cast %arg1 : index to i32
+ tensor.yield %pad_val : i32
+ } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
+ return %padded_0 : tensor<8x3x4xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>)
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]]
+// CHECK: return %[[PAD]]
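+
+// Fusion is (correctly) blocked here: the padding value is computed from the
+// pad's index arguments, whose meaning would change if the pad were moved
+// across the expand_shape.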
+
+// -----
+
func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
%expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
%cst = arith.constant 0 : i32
@@ -640,6 +658,63 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK: return %[[EXPAND]]
// -----
+
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+ : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+ return %collapsed : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: return %[[PAD]]
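+
+// Note the pad mapping: padding is only sunk on dimensions that form singleton
+// reassociation groups, so the 8-D pads low[1,0,0,8,0,0,0,3] become the 5-D
+// pads low[1,0,8,0,3] under the grouping [0][1,2][3][4,5,6][7] (and likewise
+// for the high pads).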
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+ : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %collapsed : tensor<?x?x?x?xf32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> {
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ %pad_val = arith.index_cast %arg1 : index to i32
+ tensor.yield %pad_val : i32
+ } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
+ %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32>
+ return %collapsed : tensor<8x12xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]]
+// CHECK: return %[[COLLAPSED]]
+
+// -----
// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12..6f1a422 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// -----
-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 7.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
- %1 = tensor.empty(%0) : tensor<?xf16>
- %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
- %3 = tensor.empty(%0) : tensor<?xf16>
- %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
- ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
- %5 = arith.addf %arg1, %arg2 : f16
- linalg.yield %5 : f16
- } -> tensor<?xf16>
- return %4 : tensor<?xf16>
-}
-
-// -----
-
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) \ No newline at end of file
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 290c6c7..4526dc9 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
// -----
-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
- %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+ %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
}
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
// -----
-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
- linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+ linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fabc8e6..1f554e6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
// -----
+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+ // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+ %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+ // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+ %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+ return %0 : tensor<2xi32>
+}
+
+// -----
+
func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
{
// expected-error @+1 {{expected op with scalar input}}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b..3fb7225 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -822,6 +822,23 @@ func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<
// -----
+func.func @no_fuse_by_expanding_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xi32> into tensor<2x12xi32>
+ %padded_0 = tensor.pad %collapse low[1, 0] high[5, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ %pad_val = arith.index_cast %arg1 : index to i32
+ tensor.yield %pad_val : i32
+ } : tensor<2x12xi32> to tensor<8x12xi32>
+ return %padded_0 : tensor<8x12xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK: return %[[PAD]]
+
+// -----
+
func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32
@@ -863,6 +880,64 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// -----
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+ : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+ return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?xf32>,
+ %s0: index, %s1: index, %s2: index, %s3: index, %s4: index, %s5: index,
+ %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5]
+ : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ return %expanded : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: func @expand_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0:.+]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2:.+]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] output_shape [%[[DIM0]], %[[S1]], %[[S2]], %[[DIM2]], %[[S4]], %[[S5]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @expand_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> {
+ %padded_0 = tensor.pad %arg0 low[1, 0] high[5, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ %pad_val = arith.index_cast %arg1 : index to i32
+ tensor.yield %pad_val : i32
+ } : tensor<2x12xi32> to tensor<8x12xi32>
+ %expand = tensor.expand_shape %padded_0 [[0], [1, 2]] output_shape [8, 3, 4] : tensor<8x12xi32> into tensor<8x3x4xi32>
+ return %expand : tensor<8x3x4xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]]
+// CHECK: return %[[EXPAND]]
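+
+// As in the collapse-shape variant, non-constant padding blocks the swap: the
+// pad body reads its index arguments, whose meaning would change across the
+// reshape.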
+
+// -----
+
func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
%arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index e521608..ab38f9f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -253,6 +253,40 @@ module {
// -----
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 4)>
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op_no_dps
+ func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
+ %0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
+ %1 = tensor.empty() : tensor<4x4x4xf32>
+ // CHECK: scf.forall
+ %2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
+ %3 = affine.apply #map(%arg3)
+ %4 = affine.apply #map1(%arg4)
+ // CHECK: "test.tiling_no_dps_op"
+ // CHECK: "test.unregistered_op"
+ %extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
+ %5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
+ }
+ }
+ return %2 : tensor<4x4x4xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
module {
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 185fb9b..d72ab08 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,7 +170,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
- transform.test.fuse_consumer %slice_op in (%forall_op)
+ transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -231,7 +231,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
- // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+ // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
// is not a qualified consumer operation. Forcing this will yield a "could not fetch consumer
// to fuse" error.
transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index aa2c1da..95959fc 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {
///----------------------------------------------------------------------------------------
/// Tests for linalg.pack
+///
+/// TODO: Add similar tests for linalg.unpack
///----------------------------------------------------------------------------------------
// Note, see a similar test in:
@@ -1479,23 +1481,23 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @reduce_1d(
-// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
-func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
+// CHECK-LABEL: func @reduce_to_rank_0(
+// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32>
+func.func @reduce_to_rank_0(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%f0 = arith.constant 0.000000e+00 : f32
- // CHECK: %[[init:.*]] = tensor.empty() : tensor<f32>
+ // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<f32>
%0 = tensor.empty() : tensor<f32>
%1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
- // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+ // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
- // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0]
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[F0]] [0]
// CHECK-SAME: : vector<32xf32> to f32
- // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
- // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
+ // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32>
+ // CHECK: %[[RES:.*]] = vector.transfer_write %[[RED_V1]], %[[INIT]][]
// CHECK-SAME: : vector<f32>, tensor<f32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
@@ -1523,6 +1525,58 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @reduce_to_rank_1(
+// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32>
+func.func @reduce_to_rank_1(%arg0: tensor<32xf32>) -> tensor<1xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+ %f0 = arith.constant 0.000000e+00 : f32
+
+ // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
+ %0 = tensor.empty() : tensor<1xf32>
+
+ // CHECK: %[[INIT_ZERO:.*]] = vector.transfer_write %[[F0]], %[[INIT]][%[[C0]]]
+ // CHECK-SAME: : vector<1xf32>, tensor<1xf32>
+ %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]]
+ // CHECK-SAME: : tensor<32xf32>, vector<32xf32>
+ // CHECK: %[[INIT_ZERO_VEC:.*]] = vector.transfer_read %[[INIT_ZERO]][%[[C0]]]
+ // CHECK-SAME: : tensor<1xf32>, vector<f32>
+ // CHECK: %[[INIT_ZERO_SCL:.*]] = vector.extract %[[INIT_ZERO_VEC]][]
+ // CHECK-SAME: : f32 from vector<f32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[INIT_ZERO_SCL]] [0]
+ // CHECK-SAME: : vector<32xf32> to f32
+ // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[RED_V1]], %[[INIT_ZERO]][%[[C0]]]
+ // CHECK-SAME: : vector<f32>, tensor<1xf32>
+
+ %2 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (0)>],
+ iterator_types = ["reduction"]}
+ ins(%arg0 : tensor<32xf32>)
+ outs(%1 : tensor<1xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %3 = arith.addf %a, %b : f32
+ linalg.yield %3 : f32
+ } -> tensor<1xf32>
+
+ return %2 : tensor<1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// This test checks that vectorization does not occur when an input indexing map
// is not a projected permutation. In the future, this can be converted to a
// positive test when support is added.
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 1304a90..170bae6 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1335,7 +1335,7 @@ func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.pack"]} in %src : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1, 32, 16, 2] : !transform.any_op
transform.yield
}
}
@@ -1378,7 +1378,7 @@ func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [32, 4, 1, 16, 2] : !transform.any_op
transform.yield
}
}
@@ -1424,8 +1424,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @pack_with_dynamic_dims
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32>
-func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
- %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+func.func @pack_with_dynamic_dims(
+ %src: tensor<?x?xf32>,
+ %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+ %pack = linalg.pack %src
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 2]
+ into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
return %pack : tensor<?x?x16x2xf32>
}
@@ -1433,30 +1438,108 @@ func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2x
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+
+/// Compute mask for xfer_read
// CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1>
+
+/// --= read =---
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] {
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+
+/// --= shape_cast =---
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+
+/// --= transpose =---
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+
+/// Compute mask for xfer_write
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
// CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[C16]], %[[C2]] : vector<4x1x16x2xi1>
+
+/// --= write =---
// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] {
// CHECK-SAME: vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]]
// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+
// CHECK: return %[[WRITE]] : tensor<?x?x16x2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+/// Similar to the test above, but one of the inner tile sizes is dynamic. As a
+/// result, more output dims are dynamic (and, e.g., the output mask calculation differs slightly).
+
+// CHECK-LABEL: func @pack_with_dynamic_dims_and_dynamic_inner_tile
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x2xf32>
+func.func @pack_with_dynamic_dims_and_dynamic_inner_tile(
+ %src: tensor<?x?xf32>,
+ %dest: tensor<?x?x?x2xf32>) -> tensor<?x?x?x2xf32> {
+ %c16 = arith.constant 16 : index
+ %pack = linalg.pack %src
+ inner_dims_pos = [1, 0]
+ inner_tiles = [%c16, 2]
+ into %dest : tensor<?x?xf32> -> tensor<?x?x?x2xf32>
+ return %pack : tensor<?x?x?x2xf32>
+}
+
+// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+
+/// Compute mask for xfer_read
+// CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1>
+
+/// --= read =---
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+
+/// --= shape_cast =---
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+
+/// --= transpose =---
+// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+
+/// Compute mask for xfer_write
+// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32>
+// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32>
+// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x?x2xf32>
+// CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[D4]], %[[C2_2]] : vector<4x1x16x2xi1>
+
+/// --= write =---
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] {
+// CHECK-SAME: vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x?x2xf32>
+
+// CHECK: return %[[WRITE]] : tensor<?x?x?x2xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op
transform.yield
}
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3130902..e02717a 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -208,19 +208,6 @@ func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, stri
// -----
-// CHECK-LABEL: func @dim_of_sized_view
-// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref<?xi8>
-// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index
-// CHECK: return %[[SIZE]] : index
-func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
- %c0 = arith.constant 0 : index
- %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
- %1 = memref.dim %0, %c0 : memref<?xi8>
- return %1 : index
-}
-
-// -----
-
// CHECK-LABEL: func @no_fold_subview_negative_size
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 18cdfb7..4ed8d4b 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
// CHECK-NOT: memref.memory_space_cast
+
+// -----
+
+func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index {
+  // `extract_aligned_pointer_as_index` must not be folded, because `memref.view` can change the base pointer.
+ // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>)
+ // CHECK: %[[C10:.*]] = arith.constant 10 : index
+ // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref<?xi8> to memref<f32>
+ // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index
+ // CHECK: return %[[PTR]] : index
+
+ %c10 = arith.constant 10 : index
+ %0 = memref.view %arg0[%c10][] : memref<?xi8> to memref<f32>
+ %1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index
+ return %1 : index
+}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 1066526..ca91b01 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -992,6 +992,55 @@ func.func @fold_vector_maskedstore_expand_shape(
// -----
+func.func @fold_vector_transfer_read_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[PAD:.*]] = ub.poison : f32
+// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]}
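+// The expand_shape is folded away: the (outer, inner) indices are linearized
+// into the flat source via affine.linearize_index, and the read is issued
+// directly on the original memref<32xf32>.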
+
+// -----
+
+func.func @fold_vector_transfer_read_with_perm_map(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+
+// -----
+
+func.func @fold_vector_transfer_read_rank_mismatch(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
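+// The two cases above intentionally do not fold: a non-identity
+// permutation_map and a vector/memref rank mismatch both keep the
+// expand_shape in place.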
+
+// -----
+
func.func @fold_vector_load_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 5ff2920..d10651f 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -992,6 +992,22 @@ func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) {
// -----
+func.func @invalid_alloc_alignment() {
+ // expected-error @below {{'memref.alloc' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = memref.alloc() {alignment = 3} : memref<4xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_realloc_alignment(%src: memref<4xf32>) {
+ // expected-error @below {{'memref.realloc' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = memref.realloc %src {alignment = 7} : memref<4xf32> to memref<8xf32>
+ return
+}
+
+// -----
+
func.func @test_alloc_memref_map_rank_mismatch() {
^bb0:
// expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index d300699..dd68675 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -18,7 +18,7 @@ func.func @basic() -> i32 {
// CHECK-LABEL: func.func @basic_default
func.func @basic_default() -> i32 {
// CHECK-NOT: = memref.alloca
- // CHECK: %[[RES:.*]] = arith.constant 0 : i32
+ // CHECK: %[[RES:.*]] = ub.poison : i32
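+  // (The promoted default value of an uninitialized alloca is undefined,
+  // hence ub.poison rather than a zero constant.)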
// CHECK-NOT: = memref.alloca
%0 = arith.constant 5 : i32
%1 = memref.alloca() : memref<i32>
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index 3b37c62..7fc84d4 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -306,6 +306,23 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @dead_alloc_escaped
+func.func @dead_alloc_escaped() -> memref<8x64xf32, 3> {
+ // CHECK: %{{.+}} = memref.alloc
+ %0 = memref.alloc() : memref<8x64xf32, 3>
+ return %0 : memref<8x64xf32, 3>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @dead_alloc
func.func @dead_alloc() {
// CHECK-NOT: %{{.+}} = memref.alloc
@@ -378,6 +395,73 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @dead_store_through_subview
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_subview(%arg: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
+ %subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]}
+ : vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @dead_store_through_expand
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_expand(%arg: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
+ %expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32>
+ vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @dead_store_through_collapse
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_collapse(%arg: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32>
+ %collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32>
+ vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
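+
+// In all three tests above the alloc never escapes, so the transform treats
+// the store as dead even though it goes through a view-like op (subview,
+// expand_shape, collapse_shape).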
+
+// -----
+
// CHECK-LABEL: func @lower_to_llvm
// CHECK-NOT: memref.alloc
// CHECK: llvm.call @malloc
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir
new file mode 100644
index 0000000..fed0a4b
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -acc-implicit-data=enable-implicit-reduction-copy=true -split-input-file | FileCheck %s --check-prefix=COPY
+// RUN: mlir-opt %s -acc-implicit-data=enable-implicit-reduction-copy=false -split-input-file | FileCheck %s --check-prefix=FIRSTPRIVATE
+
+// Test case: scalar reduction variable in parallel loop
+// When enable-implicit-reduction-copy=true: expect copyin/copyout for reduction variable
+// When enable-implicit-reduction-copy=false: expect firstprivate for reduction variable
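+//
+// Hypothetical source this roughly models:
+//   #pragma acc parallel loop reduction(+:r)
+//   for (int i = 1; i <= 100; i++) r += 1;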
+
+acc.reduction.recipe @reduction_add_memref_i32 : memref<i32> reduction_operator <add> init {
+^bb0(%arg0: memref<i32>):
+ %c0_i32 = arith.constant 0 : i32
+ %alloc = memref.alloca() : memref<i32>
+ memref.store %c0_i32, %alloc[] : memref<i32>
+ acc.yield %alloc : memref<i32>
+} combiner {
+^bb0(%arg0: memref<i32>, %arg1: memref<i32>):
+ %0 = memref.load %arg0[] : memref<i32>
+ %1 = memref.load %arg1[] : memref<i32>
+ %2 = arith.addi %0, %1 : i32
+ memref.store %2, %arg0[] : memref<i32>
+ acc.yield %arg0 : memref<i32>
+}
+
+func.func @test_reduction_implicit_copy() {
+ %c1_i32 = arith.constant 1 : i32
+ %c100_i32 = arith.constant 100 : i32
+ %c0_i32 = arith.constant 0 : i32
+ %r = memref.alloca() : memref<i32>
+ memref.store %c0_i32, %r[] : memref<i32>
+
+ acc.parallel {
+ %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32> {name = "r"}
+ acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) {
+ %load = memref.load %red_var[] : memref<i32>
+ %add = arith.addi %load, %c1_i32 : i32
+ memref.store %add, %red_var[] : memref<i32>
+ acc.yield
+ } attributes {independent = [#acc.device_type<none>]}
+ acc.yield
+ }
+ return
+}
+
+// When enable-implicit-reduction-copy=true: expect copyin/copyout for reduction variable
+// COPY-LABEL: func.func @test_reduction_implicit_copy
+// COPY: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {dataClause = #acc<data_clause acc_reduction>, implicit = true, name = ""}
+// COPY: acc.copyout accPtr(%[[COPYIN]] : memref<i32>) to varPtr({{.*}} : memref<i32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// When enable-implicit-reduction-copy=false: expect firstprivate for reduction variable
+// FIRSTPRIVATE-LABEL: func.func @test_reduction_implicit_copy
+// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""}
+// FIRSTPRIVATE-NOT: acc.copyin
+// FIRSTPRIVATE-NOT: acc.copyout
+
+// -----
+
+// Test case: reduction variable used both in loop and outside
+// Should be firstprivate regardless of the flag setting
+
+acc.reduction.recipe @reduction_add_memref_i32_2 : memref<i32> reduction_operator <add> init {
+^bb0(%arg0: memref<i32>):
+ %c0_i32 = arith.constant 0 : i32
+ %alloc = memref.alloca() : memref<i32>
+ memref.store %c0_i32, %alloc[] : memref<i32>
+ acc.yield %alloc : memref<i32>
+} combiner {
+^bb0(%arg0: memref<i32>, %arg1: memref<i32>):
+ %0 = memref.load %arg0[] : memref<i32>
+ %1 = memref.load %arg1[] : memref<i32>
+ %2 = arith.addi %0, %1 : i32
+ memref.store %2, %arg0[] : memref<i32>
+ acc.yield %arg0 : memref<i32>
+}
+
+func.func @test_reduction_with_usage_outside_loop() {
+ %c1_i32 = arith.constant 1 : i32
+ %c100_i32 = arith.constant 100 : i32
+ %c0_i32 = arith.constant 0 : i32
+ %r = memref.alloca() : memref<i32>
+ %out = memref.alloca() : memref<i32>
+ memref.store %c0_i32, %r[] : memref<i32>
+
+ %out_create = acc.create varPtr(%out : memref<i32>) -> memref<i32> {dataClause = #acc<data_clause acc_copyout>, name = "out"}
+ acc.parallel dataOperands(%out_create : memref<i32>) {
+ %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32_2) -> memref<i32> {name = "r"}
+ acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) {
+ %load = memref.load %red_var[] : memref<i32>
+ %add = arith.addi %load, %c1_i32 : i32
+ memref.store %add, %red_var[] : memref<i32>
+ acc.yield
+ } attributes {independent = [#acc.device_type<none>]}
+ // out = r (usage of r outside the loop)
+ %final_r = memref.load %r[] : memref<i32>
+ memref.store %final_r, %out_create[] : memref<i32>
+ acc.yield
+ }
+ acc.copyout accPtr(%out_create : memref<i32>) to varPtr(%out : memref<i32>) {dataClause = #acc<data_clause acc_copyout>, name = "out"}
+ return
+}
+
+// In this case, r should be firstprivate regardless of the flag setting
+// because it's used outside the reduction context
+// COPY-LABEL: func.func @test_reduction_with_usage_outside_loop
+// COPY: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""}
+// COPY-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = ""
+
+// FIRSTPRIVATE-LABEL: func.func @test_reduction_with_usage_outside_loop
+// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""}
+// FIRSTPRIVATE-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = ""
+
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
new file mode 100644
index 0000000..6909fe6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
@@ -0,0 +1,224 @@
+// RUN: mlir-opt %s -acc-implicit-data -split-input-file | FileCheck %s
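+
+// These tests exercise the OpenACC implicit data-attribute rules: scalars in
+// parallel/serial regions default to firstprivate, while scalars in kernels
+// regions and aggregates everywhere default to copy.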
+
+// -----
+
+// Test scalar in serial construct - should generate firstprivate
+func.func @test_scalar_in_serial() {
+ %alloc = memref.alloca() : memref<i64>
+ acc.serial {
+ %load = memref.load %alloc[] : memref<i64>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_scalar_in_serial
+// CHECK: acc.firstprivate varPtr({{.*}} : memref<i64>) recipe({{.*}}) -> memref<i64> {implicit = true, name = ""}
+
+// -----
+
+// Test scalar in parallel construct - should generate firstprivate
+func.func @test_scalar_in_parallel() {
+ %alloc = memref.alloca() : memref<f32>
+ acc.parallel {
+ %load = memref.load %alloc[] : memref<f32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_scalar_in_parallel
+// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""}
+
+// -----
+
+// Test scalar in kernels construct - should generate copyin/copyout
+func.func @test_scalar_in_kernels() {
+ %alloc = memref.alloca() : memref<f64>
+ acc.kernels {
+ %load = memref.load %alloc[] : memref<f64>
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_scalar_in_kernels
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<f64>) -> memref<f64> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<f64>) to varPtr({{.*}} : memref<f64>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test scalar in parallel with default(none) - should NOT generate implicit data
+func.func @test_scalar_parallel_defaultnone() {
+ %alloc = memref.alloca() : memref<f32>
+ acc.parallel {
+ %load = memref.load %alloc[] : memref<f32>
+ acc.yield
+ } attributes {defaultAttr = #acc<defaultvalue none>}
+ return
+}
+
+// CHECK-LABEL: func.func @test_scalar_parallel_defaultnone
+// CHECK-NOT: acc.firstprivate
+// CHECK-NOT: acc.copyin
+
+// -----
+
+// Test array in parallel - should generate copyin/copyout
+func.func @test_array_in_parallel() {
+ %alloc = memref.alloca() : memref<10xf32>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<10xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_array_in_parallel
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xf32>) to varPtr({{.*}} : memref<10xf32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test array in kernels - should generate copyin/copyout
+func.func @test_array_in_kernels() {
+ %alloc = memref.alloca() : memref<20xi32>
+ acc.kernels {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<20xi32>
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_array_in_kernels
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<20xi32>) -> memref<20xi32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<20xi32>) to varPtr({{.*}} : memref<20xi32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test array with default(present) - should generate present
+func.func @test_array_parallel_defaultpresent() {
+ %alloc = memref.alloca() : memref<10xf32>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<10xf32>
+ acc.yield
+ } attributes {defaultAttr = #acc<defaultvalue present>}
+ return
+}
+
+// CHECK-LABEL: func.func @test_array_parallel_defaultpresent
+// CHECK: %[[PRESENT:.*]] = acc.present varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {acc.from_default, implicit = true, name = ""}
+// CHECK: acc.delete accPtr(%[[PRESENT]] : memref<10xf32>) {dataClause = #acc<data_clause acc_present>, implicit = true, name = ""}
+
+// -----
+
+// Test scalar with default(present) - should still generate firstprivate (scalars ignore default(present))
+func.func @test_scalar_parallel_defaultpresent() {
+ %alloc = memref.alloca() : memref<f32>
+ acc.parallel {
+ %load = memref.load %alloc[] : memref<f32>
+ acc.yield
+ } attributes {defaultAttr = #acc<defaultvalue present>}
+ return
+}
+
+// CHECK-LABEL: func.func @test_scalar_parallel_defaultpresent
+// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""}
+
+// -----
+
+// Test multidimensional array
+func.func @test_multidim_array_in_parallel() {
+ %alloc = memref.alloca() : memref<8x16xf32>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %load = memref.load %alloc[%c0, %c1] : memref<8x16xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_multidim_array_in_parallel
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<8x16xf32>) -> memref<8x16xf32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<8x16xf32>) to varPtr({{.*}} : memref<8x16xf32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test dynamic size array
+func.func @test_dynamic_array(%size: index) {
+ %alloc = memref.alloca(%size) : memref<?xf64>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<?xf64>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_dynamic_array
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<?xf64>) -> memref<?xf64> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<?xf64>) to varPtr({{.*}} : memref<?xf64>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test variable with an explicit data clause - the implicit data pass should recognize it
+func.func @test_with_explicit_copyin() {
+ %alloc = memref.alloca() : memref<100xf32>
+ %copyin = acc.copyin varPtr(%alloc : memref<100xf32>) -> memref<100xf32> {name = "explicit"}
+ acc.parallel dataOperands(%copyin : memref<100xf32>) {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<100xf32>
+ acc.yield
+ }
+ acc.copyout accPtr(%copyin : memref<100xf32>) to varPtr(%alloc : memref<100xf32>) {name = "explicit"}
+ return
+}
+
+// CHECK-LABEL: func.func @test_with_explicit_copyin
+// CHECK: acc.present varPtr({{.*}} : memref<100xf32>) -> memref<100xf32> {implicit = true, name = ""}
+
+// -----
+
+// Test multiple variables
+func.func @test_multiple_variables() {
+ %alloc1 = memref.alloca() : memref<f32>
+ %alloc2 = memref.alloca() : memref<10xi32>
+ acc.parallel {
+ %load1 = memref.load %alloc1[] : memref<f32>
+ %c0 = arith.constant 0 : index
+ %load2 = memref.load %alloc2[%c0] : memref<10xi32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_multiple_variables
+// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""}
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr({{.*}} : memref<10xi32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""}
+
+// -----
+
+// Test memref.view aliasing - a view of an explicitly copied buffer should generate present
+func.func @test_memref_view(%size: index) {
+ %c0 = arith.constant 0 : index
+ %buffer = memref.alloca(%size) : memref<?xi8>
+ %copyin = acc.copyin varPtr(%buffer : memref<?xi8>) -> memref<?xi8> {name = "buffer"}
+ %view = memref.view %buffer[%c0][] : memref<?xi8> to memref<8x64xf32>
+ acc.kernels dataOperands(%copyin : memref<?xi8>) {
+ %c0_0 = arith.constant 0 : index
+ %c0_1 = arith.constant 0 : index
+ %load = memref.load %view[%c0_0, %c0_1] : memref<8x64xf32>
+ acc.terminator
+ }
+ acc.copyout accPtr(%copyin : memref<?xi8>) to varPtr(%buffer : memref<?xi8>) {name = "buffer"}
+ return
+}
+
+// CHECK-LABEL: func.func @test_memref_view
+// CHECK: acc.present varPtr({{.*}} : memref<8x64xf32>) -> memref<8x64xf32> {implicit = true, name = ""}
+
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
new file mode 100644
index 0000000..74ff338
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
@@ -0,0 +1,175 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(acc-implicit-declare)" -split-input-file 2>&1 | FileCheck %s
+
+// -----
+
+// Test that non-constant scalar globals in compute regions are hoisted
+// instead of being marked with acc declare
+
+memref.global @gscalar : memref<f32> = dense<0.0>
+
+func.func @test_scalar_in_serial() {
+ acc.serial {
+ %addr = memref.get_global @gscalar : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.yield
+ }
+ return
+}
+
+// Expected to hoist this global access out of the acc region instead of marking
+// with `acc declare`.
+// CHECK-LABEL: func.func @test_scalar_in_serial
+// CHECK: memref.get_global @gscalar
+// CHECK: acc.serial
+// CHECK-NOT: acc.declare
+
+// -----
+
+// Test that constant globals are marked with acc declare
+
+memref.global constant @gscalarconst : memref<f32> = dense<1.0>
+
+func.func @test_constant_in_serial() {
+ acc.serial {
+ %addr = memref.get_global @gscalarconst : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.yield
+ }
+ return
+}
+
+// This is expected to be `acc declare`'d since it is a constant.
+// CHECK: memref.global constant @gscalarconst {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+// -----
+
+// Test globals referenced in acc routine functions
+
+memref.global @gscalar_routine : memref<f32> = dense<0.0>
+
+acc.routine @acc_routine_0 func(@test_scalar_in_accroutine)
+func.func @test_scalar_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} {
+ %addr = memref.get_global @gscalar_routine : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ return
+}
+
+// Global should be acc declare'd because it's in an acc routine
+// CHECK: memref.global @gscalar_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+// -----
+
+// Test constant globals in acc routine
+
+memref.global constant @gscalarconst_routine : memref<f32> = dense<1.0>
+
+acc.routine @acc_routine_0 func(@test_constant_in_accroutine)
+func.func @test_constant_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} {
+ %addr = memref.get_global @gscalarconst_routine : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ return
+}
+
+// CHECK: memref.global constant @gscalarconst_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+// -----
+
+// Test acc.private.recipe with global reference - referenced variant
+
+memref.global @global_for_private : memref<f32> = dense<0.0>
+
+acc.private.recipe @private_recipe_with_global : memref<f32> init {
+^bb0(%arg0: memref<f32>):
+ %0 = memref.alloc() : memref<f32>
+ %global_addr = memref.get_global @global_for_private : memref<f32>
+ %global_val = memref.load %global_addr[] : memref<f32>
+ memref.store %global_val, %0[] : memref<f32>
+ acc.yield %0 : memref<f32>
+} destroy {
+^bb0(%arg0: memref<f32>):
+ memref.dealloc %arg0 : memref<f32>
+ acc.terminator
+}
+
+func.func @test_private_recipe_referenced() {
+ %var = memref.alloc() : memref<f32>
+ %priv = acc.private varPtr(%var : memref<f32>) recipe(@private_recipe_with_global) -> memref<f32>
+ acc.parallel private(%priv : memref<f32>) {
+ %load = memref.load %var[] : memref<f32>
+ acc.yield
+ }
+ memref.dealloc %var : memref<f32>
+ return
+}
+
+// Global should be acc declare'd because the recipe is referenced
+// CHECK: memref.global @global_for_private {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+// -----
+
+// Test acc.private.recipe with global reference - unreferenced variant
+
+memref.global @global_for_private_unused : memref<f32> = dense<0.0>
+
+acc.private.recipe @private_recipe_unused : memref<f32> init {
+^bb0(%arg0: memref<f32>):
+ %0 = memref.alloc() : memref<f32>
+ %global_addr = memref.get_global @global_for_private_unused : memref<f32>
+ %global_val = memref.load %global_addr[] : memref<f32>
+ memref.store %global_val, %0[] : memref<f32>
+ acc.yield %0 : memref<f32>
+} destroy {
+^bb0(%arg0: memref<f32>):
+ memref.dealloc %arg0 : memref<f32>
+ acc.terminator
+}
+
+func.func @test_private_recipe_not_referenced() {
+ %var = memref.alloc() : memref<f32>
+ acc.parallel {
+ %load = memref.load %var[] : memref<f32>
+ acc.yield
+ }
+ memref.dealloc %var : memref<f32>
+ return
+}
+
+// Global should NOT be acc declare'd because the recipe is not referenced
+// CHECK-NOT: memref.global @global_for_private_unused {{.*}} {acc.declare
+
+// -----
+
+// Test globals in different compute constructs (parallel, kernels, serial)
+
+memref.global @global_parallel : memref<f32> = dense<0.0>
+memref.global @global_kernels : memref<f32> = dense<0.0>
+memref.global constant @global_serial_const : memref<f32> = dense<1.0>
+
+func.func @test_multiple_constructs() {
+ acc.parallel {
+ %addr = memref.get_global @global_parallel : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.yield
+ }
+ acc.kernels {
+ %addr = memref.get_global @global_kernels : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.terminator
+ }
+ acc.serial {
+ %addr = memref.get_global @global_serial_const : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.yield
+ }
+ return
+}
+
+// Non-constant globals ARE hoisted before their compute regions
+// Constant global should be marked with acc.declare
+// CHECK: memref.global constant @global_serial_const {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// CHECK-LABEL: func.func @test_multiple_constructs
+// CHECK: memref.get_global @global_parallel
+// CHECK-NEXT: acc.parallel
+// CHECK: memref.get_global @global_kernels
+// CHECK-NEXT: acc.kernels
+
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index fdc8e6b..38d3df3 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -219,3 +219,30 @@ func.func @update_unnecessary_computations(%x: memref<i32>) {
// CHECK-LABEL: func.func @update_unnecessary_computations
// CHECK-NOT: acc.atomic.update
// CHECK: acc.atomic.write
+
+// -----
+
+func.func @kernel_environment_canonicalization(%q1: i32, %q2: i32, %q3: i32) {
+ // Empty kernel_environment (no wait) - should be removed
+ acc.kernel_environment {
+ }
+
+ acc.kernel_environment wait({%q1 : i32, %q2 : i32}) {
+ }
+
+ acc.kernel_environment wait {
+ }
+
+ acc.kernel_environment wait({%q3 : i32} [#acc.device_type<nvidia>]) {
+ }
+
+ return
+}
+
+// CHECK-LABEL: func.func @kernel_environment_canonicalization
+// CHECK-SAME: ([[Q1:%.*]]: i32, [[Q2:%.*]]: i32, [[Q3:%.*]]: i32)
+// CHECK-NOT: acc.kernel_environment wait({{.*}}[#acc.device_type<none>])
+// CHECK: acc.wait([[Q1]], [[Q2]] : i32, i32)
+// CHECK: acc.wait{{$}}
+// CHECK: acc.kernel_environment wait({{.*}}[#acc.device_type<nvidia>])
+// CHECK: return
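+
+// Canonicalization removes empty acc.kernel_environment regions, turning
+// wait-only variants into standalone acc.wait ops; a wait clause tied to a
+// specific device_type (nvidia) is left untouched.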
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 26b63fb..d1a1c93 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -76,27 +76,65 @@ acc.loop {
// -----
-// expected-error@+1 {{'acc.loop' op duplicate device_type found in gang attribute}}
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in gang attribute}}
acc.loop {
acc.yield
} attributes {gang = [#acc.device_type<none>, #acc.device_type<none>]}
// -----
-// expected-error@+1 {{'acc.loop' op duplicate device_type found in worker attribute}}
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in worker attribute}}
acc.loop {
acc.yield
} attributes {worker = [#acc.device_type<none>, #acc.device_type<none>]}
// -----
-// expected-error@+1 {{'acc.loop' op duplicate device_type found in vector attribute}}
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vector attribute}}
acc.loop {
acc.yield
} attributes {vector = [#acc.device_type<none>, #acc.device_type<none>]}
// -----
+// expected-error@+1 {{'acc.loop' op duplicate device_type `nvidia` found in gang attribute}}
+acc.loop {
+ acc.yield
+} attributes {gang = [#acc.device_type<nvidia>, #acc.device_type<nvidia>]}
+
+// -----
+
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in collapseDeviceType attribute}}
+acc.loop {
+ acc.yield
+} attributes {collapse = [1, 1], collapseDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
+
+// -----
+
+%i64value = arith.constant 1 : i64
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in workerNumOperandsDeviceType attribute}}
+acc.loop worker(%i64value: i64, %i64value: i64) {
+ acc.yield
+} attributes {workerNumOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
+
+// -----
+
+%i64value = arith.constant 1 : i64
+// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vectorOperandsDeviceType attribute}}
+acc.loop vector(%i64value: i64, %i64value: i64) {
+ acc.yield
+} attributes {vectorOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
+
+// -----
+
+func.func @acc_routine_parallelism() -> () {
+ return
+}
+// expected-error@+1 {{only one of `gang`, `worker`, `vector`, `seq` can be present at the same time for device_type `nvidia`}}
+"acc.routine"() <{func_name = @acc_routine_parallelism, sym_name = "acc_routine_parallelism_rout", gang = [#acc.device_type<nvidia>], worker = [#acc.device_type<nvidia>]}> : () -> ()
+
+// -----
+
%1 = arith.constant 1 : i32
%2 = arith.constant 10 : i32
// expected-error@+1 {{only one of auto, independent, seq can be present at the same time}}
@@ -483,12 +521,12 @@ acc.loop gang({static=%i64Value: i64, ) control(%iv : i32) = (%1 : i32) to (%2 :
// -----
-func.func @fct1(%0 : !llvm.ptr) -> () {
- // expected-error@+1 {{expected symbol reference @privatization_i32 to point to a private declaration}}
- acc.serial private(@privatization_i32 -> %0 : !llvm.ptr) {
- }
- return
-}
+%i1 = arith.constant 1 : i32
+%i2 = arith.constant 10 : i32
+// expected-error@+1 {{unstructured acc.loop must not have induction variables}}
+acc.loop control(%iv : i32) = (%i1 : i32) to (%i2 : i32) step (%i1 : i32) {
+ acc.yield
+} attributes {independent = [#acc.device_type<none>], unstructured}
// -----
@@ -834,6 +872,76 @@ func.func @acc_loop_container() {
// -----
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ // expected-error @below {{expected symbol reference @privatization_i32 to point to a private declaration}}
+ %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr
+ return
+}
+
+// -----
+
+acc.private.recipe @privatization_i32 : !llvm.ptr init {
+^bb0(%arg0: !llvm.ptr):
+ %c1 = arith.constant 1 : i32
+ %c0 = arith.constant 0 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+ llvm.store %c0, %0 : i32, !llvm.ptr
+ acc.yield %0 : !llvm.ptr
+}
+
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr
+ // expected-error @below {{expected firstprivate as defining op}}
+ acc.serial firstprivate(%priv : !llvm.ptr) {
+ }
+ return
+}
+
+// -----
+
+acc.private.recipe @privatization_i32 : !llvm.ptr init {
+^bb0(%arg0: !llvm.ptr):
+ %c1 = arith.constant 1 : i32
+ %c0 = arith.constant 0 : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+ llvm.store %c0, %0 : i32, !llvm.ptr
+ acc.yield %0 : !llvm.ptr
+}
+
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr
+ // expected-error @below {{op private operand appears more than once}}
+ acc.serial private(%priv, %priv : !llvm.ptr, !llvm.ptr) {
+ }
+ return
+}
+
+// -----
+
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ // expected-error @below {{op recipe expected for private}}
+ %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr
+ return
+}
+
+// -----
+
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ // expected-error @below {{op recipe expected for firstprivate}}
+ %priv = acc.firstprivate varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr
+ return
+}
+
+// -----
+
+func.func @fct1(%0 : !llvm.ptr) -> () {
+ // expected-error @below {{op recipe expected for reduction}}
+ %priv = acc.reduction varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr
+ return
+}
+
+// -----
+
func.func @verify_declare_enter(%arg0 : memref<i32>) {
// expected-error @below {{expect valid declare data entry operation or acc.getdeviceptr as defining op}}
%0 = acc.declare_enter dataOperands(%arg0 : memref<i32>)
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 40604dc..c7ef47c 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -129,8 +129,8 @@ func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
%c10 = arith.constant 10 : index
- %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
- acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ acc.parallel private(%p1 : memref<10xf32>) {
acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.yield
@@ -142,8 +142,8 @@ func.func @test(%a: memref<10xf32>) {
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
-// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
-// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK: acc.parallel private(%[[PRIVATE]] : memref<10xf32>) {
// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
@@ -167,9 +167,9 @@ func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
%c10 = arith.constant 10 : index
- %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+ %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
acc.parallel {
- acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ acc.loop private(%p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.yield
} attributes {independent = [#acc.device_type<none>]}
@@ -180,9 +180,9 @@ func.func @test(%a: memref<10xf32>) {
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
-// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
// CHECK: acc.parallel {
-// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// CHECK: acc.loop private(%[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
// CHECK: } attributes {independent = [#acc.device_type<none>]}
@@ -205,8 +205,8 @@ func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
%c10 = arith.constant 10 : index
- %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
- acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ acc.serial private(%p1 : memref<10xf32>) {
acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.yield
@@ -218,8 +218,8 @@ func.func @test(%a: memref<10xf32>) {
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
-// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
-// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK: acc.serial private(%[[PRIVATE]] : memref<10xf32>) {
// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir
new file mode 100644
index 0000000..774c6b6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
+^bb0(%arg0: memref<10x10xf32>):
+ %0 = memref.alloc() : memref<10x10xf32>
+ acc.yield %0 : memref<10x10xf32>
+} destroy {
+^bb0(%arg0: memref<10x10xf32>):
+ memref.dealloc %arg0 : memref<10x10xf32>
+ acc.terminator
+}
+
+acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} copy {
+^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
+ acc.terminator
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
+^bb0(%0: i64):
+ %1 = arith.constant 0 : i64
+ acc.yield %1 : i64
+} combiner {
+^bb0(%0: i64, %1: i64):
+ %2 = arith.addi %0, %1 : i64
+ acc.yield %2 : i64
+}
+
+acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init {
+^bb0(%arg0: memref<i64>):
+ %0 = memref.alloca() : memref<i64>
+ %c0 = arith.constant 0 : i64
+ memref.store %c0, %0[] : memref<i64>
+ acc.yield %0 : memref<i64>
+} combiner {
+^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
+ %0 = memref.load %arg0[] : memref<i64>
+ %1 = memref.load %arg1[] : memref<i64>
+ %2 = arith.addi %0, %1 : i64
+ memref.store %2, %arg0[] : memref<i64>
+ acc.terminator
+}
+
+// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) {
+// CHECK: }
+// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
+// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64>
+// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64>
+// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) {
+// CHECK: }
+// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) {
+// CHECK: acc.yield
+// CHECK: } attributes {seq = [#acc.device_type<none>]}
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {selfAttr}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: acc.yield
+// CHECK: } attributes {selfAttr}
+// CHECK: return
+// CHECK: }
+
+func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+ %i64value = arith.constant 1 : i64
+ %i32value = arith.constant 1 : i32
+ %idxValue = arith.constant 1 : index
+ acc.serial async(%i64value: i64) {
+ }
+ acc.serial async(%i32value: i32) {
+ }
+ acc.serial async(%idxValue: index) {
+ }
+ acc.serial wait({%i64value: i64}) {
+ }
+ acc.serial wait({%i32value: i32}) {
+ }
+ acc.serial wait({%idxValue: index}) {
+ }
+ acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
+ }
+ %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+ %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+ acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) {
+ }
+ %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
+ acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) {
+ }
+ %i64mem = memref.alloca() : memref<i64>
+ memref.store %i64value, %i64mem[] : memref<i64>
+ %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+ acc.serial reduction(%i64reduction : memref<i64>) {
+ }
+ acc.serial combined(loop) {
+ acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) {
+ acc.yield
+ } attributes {seq = [#acc.device_type<none>]}
+ acc.terminator
+ }
+ acc.serial {
+ } attributes {defaultAttr = #acc<defaultvalue none>}
+ acc.serial {
+ } attributes {defaultAttr = #acc<defaultvalue present>}
+ acc.serial {
+ } attributes {asyncAttr}
+ acc.serial {
+ } attributes {waitAttr}
+ acc.serial {
+ } attributes {selfAttr}
+ acc.serial {
+ acc.yield
+ } attributes {selfAttr}
+ return
+}
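+
+// Note: as checked above, legalization maps each acc.serial to an
+// acc.parallel with num_gangs(1), num_workers(1) and vector_length(1),
+// carrying async/wait/private/firstprivate/reduction operands across.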
+
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 042ee25..d31397c 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-3
// Verify the printed output can be parsed.
-// RUN: mlir-opt -split-input-file %s | mlir-opt -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | mlir-opt -split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-3
// Verify the generic form can be parsed.
-// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt -split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-3
func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = arith.constant 0 : index
@@ -120,8 +120,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%pc = acc.present varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
%pd = acc.present varPtr(%d : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
- %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
- acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
+ %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(%private : memref<10xf32>) {
acc.loop gang control(%x : index) = (%lb : index) to (%c10 : index) step (%st : index) {
acc.loop worker control(%y : index) = (%lb : index) to (%c10 : index) step (%st : index) {
%axy = memref.load %a[%x, %y] : memref<10x10xf32>
@@ -157,8 +157,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
// CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
-// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
-// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
+// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(%[[P_ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.loop gang control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
// CHECK-NEXT: acc.loop worker control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
// CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
@@ -375,8 +375,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%c1 = arith.constant 1 : index
- %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
- acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) {
+ %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+ acc.loop firstprivate(%firstprivate : memref<10xf32>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) {
"test.openacc_dummy_op"() : () -> ()
acc.yield
} attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
@@ -385,8 +385,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
// CHECK-LABEL: func.func @testloopfirstprivate(
// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>, %[[ARG1:.*]]: memref<10xf32>)
-// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
-// CHECK: acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
+// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+// CHECK: acc.loop firstprivate(%[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
// CHECK: "test.openacc_dummy_op"() : () -> ()
// CHECK: acc.yield
// CHECK: } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
@@ -464,7 +464,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel vector_length(%idxValue: index) {
}
- acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@privatization_memref_10xf32 -> %b: memref<10xf32>) {
+ %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+ %firstprivate_b = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32>
+ acc.parallel private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate_b : memref<10xf32>) {
}
acc.parallel {
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -517,7 +520,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel vector_length([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel firstprivate(@privatization_memref_10xf32 -> [[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
+// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+// CHECK-NEXT: %[[FIRSTPRIVATE_B:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32>
+// CHECK-NEXT: acc.parallel firstprivate(%[[FIRSTPRIVATE_B]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.parallel {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
@@ -596,8 +602,10 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
}
acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
}
- %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
- acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) {
+ %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+ %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+ acc.serial private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) {
}
acc.serial {
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -633,8 +641,10 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK-NEXT: }
// CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
-// CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
+// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+// CHECK-NEXT: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+// CHECK-NEXT: acc.serial firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.serial {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
@@ -721,6 +731,59 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// -----
+// Test acc.kernels with private and firstprivate operands, similar to acc.serial.
+
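+// The recipes below mirror the ones used for acc.serial: init allocates the
+// private copy, destroy releases it, and the firstprivate recipe adds a copy
+// region that initializes the private copy from the original variable.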
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
+^bb0(%arg0: memref<10x10xf32>):
+ %1 = memref.alloc() : memref<10x10xf32>
+ acc.yield %1 : memref<10x10xf32>
+} destroy {
+^bb0(%arg0: memref<10x10xf32>):
+ memref.dealloc %arg0 : memref<10x10xf32>
+ acc.terminator
+}
+
+acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %2 = memref.alloca() : memref<10xf32>
+ acc.yield %2 : memref<10xf32>
+} copy {
+^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
+ memref.copy %arg0, %arg1 : memref<10xf32> to memref<10xf32>
+ acc.terminator
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ acc.terminator
+}
+
+func.func @testkernelspriv(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+ %priv_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+ %priv_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+ %firstp = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+ acc.kernels firstprivate(%firstp : memref<10xf32>) private(%priv_a, %priv_c : memref<10xf32>, memref<10x10xf32>) {
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @testkernelspriv(
+// CHECK: %[[PRIV_A:.*]] = acc.private varPtr(%{{.*}} : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
+// CHECK: %[[PRIV_C:.*]] = acc.private varPtr(%{{.*}} : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+// CHECK: acc.kernels firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIV_A]], %[[PRIV_C]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK-NEXT: }
+
+// -----
+
func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
%ifCond = arith.constant true
@@ -1511,32 +1574,43 @@ acc.private.recipe @privatization_struct_i32_i64 : !llvm.struct<(i32, i32)> init
// -----
-acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
-^bb0(%arg0: i64):
- %0 = arith.constant 0 : i64
- acc.yield %0 : i64
+acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init {
+^bb0(%arg0: memref<i64>):
+ %c0_i64 = arith.constant 0 : i64
+ %alloca = memref.alloca() : memref<i64>
+ memref.store %c0_i64, %alloca[] : memref<i64>
+ acc.yield %alloca : memref<i64>
} combiner {
-^bb0(%arg0: i64, %arg1: i64):
- %0 = arith.addi %arg0, %arg1 : i64
- acc.yield %0 : i64
+^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
+ %0 = memref.load %arg0[] : memref<i64>
+ %1 = memref.load %arg1[] : memref<i64>
+ %2 = arith.addi %0, %1 : i64
+ memref.store %2, %arg0[] : memref<i64>
+ acc.yield %arg0 : memref<i64>
}
-// CHECK-LABEL: acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator <add> init {
-// CHECK: ^bb0(%{{.*}}: i64):
+// CHECK-LABEL: acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init {
+// CHECK: ^bb0(%{{.*}}: memref<i64>):
// CHECK: %[[C0:.*]] = arith.constant 0 : i64
-// CHECK: acc.yield %[[C0]] : i64
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i64>
+// CHECK: memref.store %[[C0]], %[[ALLOCA]][] : memref<i64>
+// CHECK: acc.yield %[[ALLOCA]] : memref<i64>
// CHECK: } combiner {
-// CHECK: ^bb0(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64):
-// CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64
-// CHECK: acc.yield %[[RES]] : i64
+// CHECK: ^bb0(%[[ARG0:.*]]: memref<i64>, %[[ARG1:.*]]: memref<i64>):
+// CHECK: %[[LOAD0:.*]] = memref.load %[[ARG0]][] : memref<i64>
+// CHECK: %[[LOAD1:.*]] = memref.load %[[ARG1]][] : memref<i64>
+// CHECK: %[[RES:.*]] = arith.addi %[[LOAD0]], %[[LOAD1]] : i64
+// CHECK: memref.store %[[RES]], %[[ARG0]][] : memref<i64>
+// CHECK: acc.yield %[[ARG0]] : memref<i64>
// CHECK: }
-func.func @acc_reduc_test(%a : i64) -> () {
+func.func @acc_reduc_test(%a : memref<i64>) -> () {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%c1 = arith.constant 1 : index
- acc.parallel reduction(@reduction_add_i64 -> %a : i64) {
- acc.loop reduction(@reduction_add_i64 -> %a : i64) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) {
+ %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+ acc.parallel reduction(%reduction_a : memref<i64>) {
+ acc.loop reduction(%reduction_a : memref<i64>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) {
acc.yield
} attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
acc.yield
@@ -1545,31 +1619,68 @@ func.func @acc_reduc_test(%a : i64) -> () {
}
// CHECK-LABEL: func.func @acc_reduc_test(
-// CHECK-SAME: %[[ARG0:.*]]: i64)
-// CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64)
-// CHECK: acc.loop reduction(@reduction_add_i64 -> %[[ARG0]] : i64)
+// CHECK-SAME: %[[ARG0:.*]]: memref<i64>)
+// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+// CHECK-NEXT: acc.parallel reduction(%[[REDUCTION_A]] : memref<i64>)
+// CHECK: acc.loop reduction(%[[REDUCTION_A]] : memref<i64>)
// -----
-acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
-^bb0(%0: i64):
- %1 = arith.constant 0 : i64
- acc.yield %1 : i64
+acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init {
+^bb0(%arg0: memref<i64>):
+ %c0_i64 = arith.constant 0 : i64
+ %alloca = memref.alloca() : memref<i64>
+ memref.store %c0_i64, %alloca[] : memref<i64>
+ acc.yield %alloca : memref<i64>
} combiner {
-^bb0(%0: i64, %1: i64):
+^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
+ %0 = memref.load %arg0[] : memref<i64>
+ %1 = memref.load %arg1[] : memref<i64>
%2 = arith.addi %0, %1 : i64
- acc.yield %2 : i64
+ memref.store %2, %arg0[] : memref<i64>
+ acc.yield %arg0 : memref<i64>
}
-func.func @acc_reduc_test(%a : i64) -> () {
- acc.serial reduction(@reduction_add_i64 -> %a : i64) {
+func.func @acc_reduc_test(%a : memref<i64>) -> () {
+ %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+ acc.serial reduction(%reduction_a : memref<i64>) {
}
return
}
// CHECK-LABEL: func.func @acc_reduc_test(
-// CHECK-SAME: %[[ARG0:.*]]: i64)
-// CHECK: acc.serial reduction(@reduction_add_i64 -> %[[ARG0]] : i64)
+// CHECK-SAME: %[[ARG0:.*]]: memref<i64>)
+// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+// CHECK-NEXT: acc.serial reduction(%[[REDUCTION_A]] : memref<i64>)
+
+// -----
+
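+// Same add-reduction recipe as above: init zero-initializes a stack-allocated
+// accumulator; the combiner loads both operands, adds them, and stores the
+// result back into the first operand.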
+acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init {
+^bb0(%arg0: memref<i64>):
+ %c0_i64 = arith.constant 0 : i64
+ %alloca = memref.alloca() : memref<i64>
+ memref.store %c0_i64, %alloca[] : memref<i64>
+ acc.yield %alloca : memref<i64>
+} combiner {
+^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
+ %0 = memref.load %arg0[] : memref<i64>
+ %1 = memref.load %arg1[] : memref<i64>
+ %2 = arith.addi %0, %1 : i64
+ memref.store %2, %arg0[] : memref<i64>
+ acc.yield %arg0 : memref<i64>
+}
+
+func.func @acc_kernels_reduc_test(%a : memref<i64>) -> () {
+ %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+ acc.kernels reduction(%reduction_a : memref<i64>) {
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @acc_kernels_reduc_test(
+// CHECK-SAME: %[[ARG0:.*]]: memref<i64>)
+// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+// CHECK-NEXT: acc.kernels reduction(%[[REDUCTION_A]] : memref<i64>)
// -----
@@ -1699,6 +1810,59 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
// -----
+// Test acc.specialized_routine attribute for specialized device functions
+acc.routine @routine_seq func(@device_func_seq) seq
+acc.routine @routine_gang func(@device_func_gang) gang
+acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64)
+acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64)
+acc.routine @routine_worker func(@device_func_worker) worker
+acc.routine @routine_vector func(@device_func_vector) vector
+
+func.func @device_func_seq() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} {
+ return
+}
+
+func.func @device_func_gang() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} {
+ return
+}
+
+func.func @device_func_gang_dim2() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">} {
+ return
+}
+
+func.func @device_func_gang_dim3() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} {
+ return
+}
+
+func.func @device_func_worker() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} {
+ return
+}
+
+func.func @device_func_vector() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} {
+ return
+}
+
+// CHECK: acc.routine @routine_seq func(@device_func_seq) seq
+// CHECK: acc.routine @routine_gang func(@device_func_gang) gang
+// CHECK: acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64)
+// CHECK: acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64)
+// CHECK: acc.routine @routine_worker func(@device_func_worker) worker
+// CHECK: acc.routine @routine_vector func(@device_func_vector) vector
+// CHECK-LABEL: func.func @device_func_seq()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">}
+// CHECK-LABEL: func.func @device_func_gang()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">}
+// CHECK-LABEL: func.func @device_func_gang_dim2()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">}
+// CHECK-LABEL: func.func @device_func_gang_dim3()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">}
+// CHECK-LABEL: func.func @device_func_worker()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">}
+// CHECK-LABEL: func.func @device_func_vector()
+// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">}
+
+// -----
+
func.func @acc_func() -> () {
"test.openacc_dummy_op"() {acc.declare_action = #acc.declare_action<postAlloc = @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_alloc>} : () -> ()
return
@@ -2143,6 +2307,20 @@ func.func @acc_loop_container() {
// -----
+func.func @acc_unstructured_loop() {
+ acc.loop {
+ acc.yield
+ } attributes {independent = [#acc.device_type<none>], unstructured}
+ return
+}
+
+// CHECK-LABEL: func.func @acc_unstructured_loop
+// CHECK: acc.loop
+// CHECK: acc.yield
+// CHECK: } attributes {independent = [#acc.device_type<none>], unstructured}
+
+// -----
+
// Test private recipe with data bounds for array slicing
acc.private.recipe @privatization_memref_slice : memref<10x10xf32> init {
^bb0(%arg0: memref<10x10xf32>, %bounds0: !acc.data_bounds_ty, %bounds1: !acc.data_bounds_ty):
diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir
new file mode 100644
index 0000000..36df6a1
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=load}))" 2>&1 | FileCheck %s
+
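+// Each allocation tagged with test.ptr is asked to materialize a scalar load
+// via the pointer-like interface; statically shaped memrefs succeed, while
+// dynamically shaped ones are expected to fail.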
+func.func @test_memref_load_scalar() {
+ %ptr = memref.alloca() {test.ptr} : memref<f32>
+ // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32>
+ // CHECK: Loaded value type: f32
+ // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<f32>
+ return
+}
+
+// -----
+
+func.func @test_memref_load_int() {
+ %ptr = memref.alloca() {test.ptr} : memref<i64>
+ // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64>
+ // CHECK: Loaded value type: i64
+ // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<i64>
+ return
+}
+
+// -----
+
+func.func @test_memref_load_dynamic() {
+ %c10 = arith.constant 10 : index
+ %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32>
+ // CHECK: Failed to generate load for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32>
+ return
+}
+
diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir
new file mode 100644
index 0000000..0fee431
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=store}))" 2>&1 | FileCheck %s
+
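+// Mirror of the load tests: a type-appropriate constant (42) is materialized
+// and stored through the tagged pointer; dynamic shapes are expected to fail.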
+func.func @test_memref_store_scalar() {
+ %ptr = memref.alloca() {test.ptr} : memref<f32>
+ // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32>
+ // CHECK: Generated: %[[VAL:.*]] = arith.constant 4.200000e+01 : f32
+ // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+ return
+}
+
+// -----
+
+func.func @test_memref_store_int() {
+ %ptr = memref.alloca() {test.ptr} : memref<i32>
+ // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i32>
+ // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i32
+ // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i32>
+ return
+}
+
+// -----
+
+func.func @test_memref_store_i64() {
+ %ptr = memref.alloca() {test.ptr} : memref<i64>
+ // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64>
+ // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i64
+ // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i64>
+ return
+}
+
+// -----
+
+func.func @test_memref_store_dynamic() {
+ %c10 = arith.constant 10 : index
+ %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32>
+ // CHECK: Failed to generate store for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32>
+ return
+}
+
diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir
new file mode 100644
index 0000000..154d44e
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=private_from_firstprivate})" | FileCheck %s
+
+// Verify that we can create a private recipe using the convenience overload
+// that takes an existing firstprivate recipe as input. For a simple scalar
+// alloca-backed memref, only an init region is expected (no destroy).
+// CHECK: acc.private.recipe @private_from_firstprivate_scalar : memref<f32> init {
+// CHECK: ^bb0(%{{.*}}: memref<f32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
+// CHECK: acc.yield %[[ALLOC]] : memref<f32>
+// CHECK: }
+
+func.func @test_scalar_from_firstprivate() {
+ %0 = memref.alloca() {test.var = "scalar"} : memref<f32>
+ return
+}
+
+// -----
+
+// Verify that destroy regions are also present when creating a private recipe
+// from a firstprivate recipe that requires dynamic deallocation.
+// CHECK: acc.private.recipe @private_from_firstprivate_dynamic_d2 : memref<?x?xf32> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>):
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_d2">} : memref<?x?xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32>
+// CHECK: } destroy {
+// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>):
+// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+func.func @test_dynamic_from_firstprivate(%arg0: index, %arg1: index) {
+ %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_d2"} : memref<?x?xf32>
+ return
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 084c3fc..ac590fc 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
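+// The scf.if in the before-region feeds only scf.condition, so it is sunk into
+// the after-region: the before-region values it uses are forwarded as extra
+// scf.condition results, and the loop result carries the else-value for the
+// exit path.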
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+ %defined_outside = "test.get_some_value0" () : () -> (i32)
+ %0 = scf.while () : () -> (i32) {
+ %used_value = "test.get_some_value1" () : () -> (i32)
+ %used_by_subregion = "test.get_some_value2" () : () -> (i32)
+ %else_value = "test.get_some_value3" () : () -> (i32)
+ %condition = "test.condition"() : () -> i1
+ %res = scf.if %condition -> (i32) {
+ "test.use0" (%defined_outside) : (i32) -> ()
+ "test.use1" (%used_value) : (i32) -> ()
+ test.alloca_scope_region {
+ "test.use2" (%used_by_subregion) : (i32) -> ()
+ }
+ %then_value = "test.get_some_value4" () : () -> (i32)
+ scf.yield %then_value : i32
+ } else {
+ scf.yield %else_value : i32
+ }
+ scf.condition(%condition) %res : i32
+ } do {
+ ^bb0(%res_arg: i32):
+ "test.use3" (%res_arg) : (i32) -> ()
+ scf.yield
+ }
+ return %0 : i32
+}
+// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32
+// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) {
+// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1
+// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32):
+// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> ()
+// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> ()
+// CHECK: test.alloca_scope_region {
+// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> ()
+// CHECK: }
+// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32
+// CHECK: "test.use3"(%[[then_value]]) : (i32) -> ()
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return %[[WHILE_RES]]#0 : i32
+// CHECK: }
+
+// -----
+
// CHECK-LABEL: @while_cond_true
func.func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8e29ff6..b70bb40 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -795,6 +795,53 @@ func.func @selection(%cond: i1) -> () {
// -----
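+// spirv.Switch may act as the header terminator of a structured
+// spirv.mlir.selection region, with all cases branching to the merge block.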
+func.func @selection_switch(%selector: i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %two = spirv.Constant 2: i32
+ %three = spirv.Constant 3: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1,
+ // CHECK-NEXT: 0: ^bb2,
+ // CHECK-NEXT: 1: ^bb3
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1
+ ]
+ // CHECK: ^bb1
+ ^default:
+ spirv.Store "Function" %var, %one : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb2
+ ^case0:
+ spirv.Store "Function" %var, %two : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb3
+ ^case1:
+ spirv.Store "Function" %var, %three : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb4
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ spirv.Return
+}
+
+// -----
+
// CHECK-LABEL: @empty_region
func.func @empty_region() -> () {
// CHECK: spirv.mlir.selection
@@ -918,3 +965,171 @@ func.func @kill() {
// CHECK: spirv.Kill
spirv.Kill
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
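+// General form exercised below:
+//   spirv.Switch %sel : <int-type>, [ default: ^bb(args...), <literal>: ^bb(args...), ... ]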
+func.func @switch(%selector: i32) -> () {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1,
+ // CHECK-NEXT: 0: ^bb2,
+ // CHECK-NEXT: 1: ^bb3,
+ // CHECK-NEXT: 2: ^bb4
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1,
+ 2: ^case2
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case0:
+ spirv.Branch ^merge
+
+^case1:
+ spirv.Branch ^merge
+
+^case2:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+func.func @switch_only_default(%selector: i32) -> () {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1
+ spirv.Switch %selector : i32, [
+ default: ^default
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+func.func @switch_operands(%selector : i32, %operand : i32) {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1({{%.*}} : i32),
+ // CHECK-NEXT: 0: ^bb2({{%.*}} : i32),
+ // CHECK-NEXT: 1: ^bb3({{%.*}} : i32)
+ spirv.Switch %selector : i32, [
+ default: ^default(%operand : i32),
+ 0: ^case0(%operand : i32),
+ 1: ^case1(%operand : i32)
+ ]
+^default(%argd : i32):
+ spirv.Branch ^merge
+
+^case0(%arg0 : i32):
+ spirv.Branch ^merge
+
+^case1(%arg1 : i32):
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_float_selector(%selector: f32) -> () {
+ // expected-error@+1 {{expected builtin.integer, but found 'f32'}}
+ spirv.Switch %selector : f32, [
+ default: ^default
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_float_case_value(%selector: i32) -> () {
+ // expected-error@+3 {{expected integer value}}
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0.0: ^case0
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_default(%selector: i32) -> () {
+ // expected-error@+2 {{expected 'default'}}
+ spirv.Switch %selector : i32, [
+ 0: ^case0
+ ]
+^case0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_default_no_target(%selector: i32) -> () {
+ // expected-error@+2 {{expected block name}}
+ spirv.Switch %selector : i32, [
+ default:
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_case_no_target(%selector: i32) -> () {
+ // expected-error@+3 {{expected block name}}
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0:
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_operand_type(%selector: i32) -> () {
+ %0 = spirv.Constant 0 : i32
+ // expected-error@+2 {{expected ':'}}
+ spirv.Switch %selector : i32, [
+ default: ^default (%0),
+ 0.0: ^case0
+ ]
+^default(%argd : i32):
+ spirv.Branch ^merge
+
+^case0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5eb2360..be8ce20 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<8x?xf32>
-// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<?x?xf32>
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// -----
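+// When dynamic and static inputs are mixed, each later input's offset along
+// the concat dimension is the running sum of the preceding sizes, kept
+// symbolic through affine.apply.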
+// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+ %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+ return %0 : tensor<8x10xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index a05f423..6ef8b3e 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
// CHECK-LABEL: cast
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5a40f3f..84776c4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -362,6 +362,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
// -----
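+// Back-to-back clamps fold into a single clamp over the intersection of the
+// two ranges, e.g. [10, 240] followed by [5, 230] yields [10, 230]; disjoint
+// ranges must not fold.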
+// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
+func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
+func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
+// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
+// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
+func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
// CHECK-LABEL: @concat_fold
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@@ -643,6 +673,48 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
// -----
+// CHECK-LABEL: @select_broadcast_same_value_no_fold
+func.func @select_broadcast_same_value_no_fold(%arg0: tensor<2x2xi1>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> {
+ // CHECK: tosa.select %arg0, %arg1, %arg1
+ %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_broadcast_true_value_no_fold
+func.func @select_broadcast_true_value_no_fold(%arg0: tensor<1x1xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[CONST:.*]] = "tosa.const"
+ %0 = "tosa.const"() {values = dense<1> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
+ // CHECK: tosa.select %[[CONST]], %arg0, %arg1
+ %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_broadcast_false_value_no_fold
+func.func @select_broadcast_false_value_no_fold(%arg0: tensor<2x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> {
+ // CHECK: %[[CONST:.*]] = "tosa.const"
+ %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
+ // CHECK: tosa.select %[[CONST]], %arg0, %arg1
+ %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<1x1xf32>) -> tensor<2x2xf32>
+ return %1 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_broadcast_false_value_dynamic_operand_no_fold
+func.func @select_broadcast_false_value_dynamic_operand_no_fold(%arg0: tensor<2x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // CHECK: %[[CONST:.*]] = "tosa.const"
+ %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
+ // CHECK: tosa.select %[[CONST]], %arg0, %arg1
+ %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %1 : tensor<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: @reduce_all_fold
func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 119991ca..3d24928 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -306,6 +306,14 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
// -----
+func.func @test_concat_input_output_rank_mismatch(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect output rank to match inputs rank, got 1 vs 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
@@ -2036,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
// -----
+// CHECK-LABEL: test_scatter_duplicate_indices_int64
+func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+ %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+ // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+ %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+ return %0 : tensor<2x52x3xf32>
+}
+
+// -----
+
func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
// expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 68a9578..177192b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -563,13 +563,6 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
}
// -----
-func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
- // expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}}
- %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
-}
-
-// -----
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
// expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 22fde3b..652447bd 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -280,6 +280,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
}
// -----
+// CHECK-LABEL: clamp_quantized_unsigned
+func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
@@ -343,6 +350,13 @@ func.func @test_intdiv(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -
}
// -----
+// CHECK-LABEL: intdiv_i64
+func.func @test_intdiv_i64(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ %0 = tosa.intdiv %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
// CHECK-LABEL: logical_and
func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
%0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
@@ -750,10 +764,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
- return %0 : tensor<13x52x3xf32>
+// CHECK-LABEL: gather_int64
+func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
}
// -----
@@ -764,6 +778,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
}
// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_int64
+func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
+// -----
// CHECK-LABEL: scatter_unranked_indices
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
@@ -1277,6 +1305,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>,
}
// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e
+func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
+ %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>)
+ %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>)
+ %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
+ return %res : tensor<6x2x64xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e
+func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
+ %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>)
+ %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>)
+ %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
+ return %res : tensor<6x2x64xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e
+func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
+ %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>)
+ %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>)
+ %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
+ return %res : tensor<6x2x64xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e
+func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
+ %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>)
+ %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>)
+ %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
+ return %res : tensor<6x2x64xf32>
+}
+
+// -----
// CHECK-LABEL: test_cast_from_block_scaled_static
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
%0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
@@ -1307,7 +1371,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
// -----
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
- %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
}
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index f0ad4eb..88dffe7 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -1,14 +1,22 @@
// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
// -----
-// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+// CHECK-LABEL: test_build_qtype_unsigned
+func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
- %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
}
// -----
+// CHECK-LABEL: test_build_qtype_signed
+func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
+ // CHECK: tosa.negate
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+ return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+}
+
+// -----
// CHECK-LABEL: test_build_mult_and_shift
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
// CHECK: tosa.conv2d
diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
new file mode 100644
index 0000000..fc2d77ef
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s
+
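+// Tensor-valued arith.constant ops are rewritten to tosa.const; scalar and
+// index-typed constants are intentionally left to arith.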
+// CHECK-LABEL: func.func @rewrite_f32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: return %[[CST]]
+func.func @rewrite_f32_tensor() -> tensor<2xf32> {
+ %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+ return %c : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i32_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: return %[[CST]]
+func.func @rewrite_i32_tensor() -> tensor<3xi32> {
+ %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32>
+ return %c : tensor<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_i1_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1>
+func.func @rewrite_i1_tensor() -> tensor<2xi1> {
+ %c = arith.constant dense<[true, false]> : tensor<2xi1>
+ return %c : tensor<2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_rank0_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32>
+func.func @rewrite_rank0_tensor() -> tensor<f32> {
+ %c = arith.constant dense<1.234500e+00> : tensor<f32>
+ return %c : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_scalar_i32
+// CHECK: %[[CST:.*]] = arith.constant 42 : i32
+func.func @preserve_scalar_i32() -> i32 {
+ %c = arith.constant 42 : i32
+ return %c : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @preserve_index_tensor
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex>
+func.func @preserve_index_tensor() -> tensor<2xindex> {
+ %c = arith.constant dense<[0, 1]> : tensor<2xindex>
+ return %c : tensor<2xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_resource_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32>
+func.func @rewrite_resource_tensor() -> tensor<4xf32> {
+ %c = arith.constant dense_resource<"blob1"> : tensor<4xf32>
+ return %c : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_quant_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8>
+func.func @rewrite_quant_tensor() -> tensor<2xui8> {
+ %c = arith.constant dense<[10, 20]> : tensor<2xui8>
+ return %c : tensor<2xui8>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>
+func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> {
+ %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+ return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_fp8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN>
+func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> {
+ %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN>
+ return %c : tensor<2xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rewrite_mxint8_tensor
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8>
+func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> {
+ %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>
+ return %c : tensor<2x!tosa.mxint8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c7eeb52..d4c4595 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
+
+// -----
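+// A 1x1 depthwise conv with unit stride and dilation decomposes into
+// reshape + mul + reshape + add; the dynamic batch dimension propagates
+// as -1 through the generated const_shape values.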
+// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
+// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>,
+// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>,
+// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
+// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
+// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
+// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
+// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
+// CHECK: return %[[RES]]
+func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
+ return %0 : tensor<?x10x10x6xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 810135f..61ca0ae 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
}
+
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
+ return %0 : tensor<?x18x19x5xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
+ return %0 : tensor<?x35x47x5xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
new file mode 100644
index 0000000..1a36177
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
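+// DEFAULT checks the output when function boundaries are left as i64, with
+// casts inserted at function entry and exit; FUNCBOUND checks the variant
+// that also narrows function signatures to i32. COMMON lines must match
+// under both configurations.
+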
+// CHECK-LABEL: test_i64_argmax_large_axis_dim
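+// With aggressive-rewrite=1 the argmax result is narrowed to i32 even though
+// the axis dimension 2147483650 exceeds the i32 range; the non-aggressive
+// pass rejects this case (see tosa-narrow-i64-to-i32.mlir).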
+func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
+ // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_convert_input_parameters
+// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64>
+// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32>
+func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> {
+ // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+ // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+
+ // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ return %1 : tensor<1x513x513x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64>
+// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32>
+func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32>
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
+ // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64>
+func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32>
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) {
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[ADD]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ } else {
+ // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[SUB]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ }
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64>
+ // DEFAULT: return %[[OUT]] : tensor<i64>
+ // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32>
+ return %0 : tensor<i64>
+}
+
+// -----
+
+// CHECK-LABEL: test_const
+func.func @test_const() -> tensor<2xi64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
+ // DEFAULT: return %[[OUT]] : tensor<2xi64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
+ return %0 : tensor<2xi64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
new file mode 100644
index 0000000..a14483f
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
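+// As the cases below exercise: argmax is narrowed when its axis dimension
+// fits in i32, data-movement ops (concat, pad, reshape, slice, ...) propagate
+// the narrowing, and i64 arithmetic such as tosa.add fails to legalize
+// without the aggressive option.
+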
+// -----
+
+// CHECK-LABEL: test_i64_argmax
+func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
+
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64>
+ // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_i64_argmax_cast
+func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
+ // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32>
+ %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32>
+ return %1 : tensor<1x513x513xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_i64_argmax_large_axis_dim
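+// The axis dimension 2147483650 does not fit in i32 (max 2147483647), so the
+// argmax result cannot safely be narrowed and legalization fails.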
+func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}}
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor<i1>) -> tensor<1xi32> {
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32>
+ %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64>
+ // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
+ // COMMON: tosa.yield %[[CAST]] : tensor<1xi32>
+ tosa.yield %2 : tensor<1xi32>
+ } else {
+ tosa.yield %arg1 : tensor<1xi32>
+ }
+ // COMMON: return %[[IF_RESULT]] : tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concat
+func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> {
+ // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32>
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64>
+ return %0 : tensor<26x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_pad
+func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> {
+ %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32>
+ %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64>
+ return %1 : tensor<15x23x5xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape
+func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> {
+ %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64>
+ return %0 : tensor<1x819xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_reverse
+func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice
+func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> {
+ %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64>
+ return %2 : tensor<4x11x1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_tile
+func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> {
+ %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32>
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64>
+ return %0 : tensor<39x21x6xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose
+func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> {
+ // COMMON: tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64>
+ return %1 : tensor<3x13x21xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transition_to_i64
+func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> {
+ // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
+ %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64>
+ // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64>
+ // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64>
+ // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32>
+ return %2 : tensor<1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transition_from_i64
+func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> {
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32>
+ // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
+ // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
+ %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
+ // COMMON: return %[[OUT_CAST]] : tensor<1xi32>
+ return %2 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 9bd7aa8..f6b1edc 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -2,6 +2,7 @@
// -----
+// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
@@ -53,14 +54,6 @@ func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> {
// -----
-// CHECK-LABEL: test_cast_f4e2m1
-func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
- %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
-}
-
-// -----
-
// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
%0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
@@ -109,14 +102,6 @@ func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> {
// -----
-// CHECK-LABEL: test_cast_f4e2m1
-func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
- %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
-}
-
-// -----
-
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
@@ -130,3 +115,28 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
}
+
+// -----
+
+// CHECK-LABEL: test_argmax_fp8_i64
+func.func @test_argmax_fp8_i64(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> {
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64>
+ return %0 : tensor<12x16xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_argmax_bf16_i64
+func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64> {
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
+ return %0 : tensor<12x16xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_scatter_const_indices_int64
+func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+ %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+ %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+ return %0 : tensor<2x52x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cd..ea64d46 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
+ %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
diff --git a/mlir/test/Dialect/Transform/include-failure-propagation.mlir b/mlir/test/Dialect/Transform/include-failure-propagation.mlir
new file mode 100644
index 0000000..94e9d8f
--- /dev/null
+++ b/mlir/test/Dialect/Transform/include-failure-propagation.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --verify-diagnostics
+
+module attributes { transform.with_named_sequence } {
+ // Callee returns a silenceable failure when given a module instead of func.func.
+ transform.named_sequence @callee(%root: !transform.any_op {transform.consumed}) -> (!transform.any_op) {
+ transform.test_consume_operand_of_op_kind_or_fail %root, "func.func" : !transform.any_op
+ transform.yield %root : !transform.any_op
+ }
+
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %res = transform.sequence %root : !transform.any_op -> !transform.any_op failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ // This include returns a silenceable failure; it must not remap results.
+ %included = transform.include @callee failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op)
+ transform.yield %included : !transform.any_op
+ }
+
+ %count = transform.num_associations %res : (!transform.any_op) -> !transform.param<i64>
+ // expected-remark @below {{0}}
+ transform.debug.emit_param_as_remark %count : !transform.param<i64>
+
+ // If the include incorrectly forwarded mappings on failure, this would run
+ // and produce an unexpected remark under --verify-diagnostics.
+ transform.foreach %res : !transform.any_op {
+ ^bb0(%it: !transform.any_op):
+ transform.debug.emit_remark_at %it, "include result unexpectedly populated" : !transform.any_op
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+module {
+ func.func @payload() {
+ return
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index ce8f69c..4806daf7 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -386,7 +386,7 @@ module attributes {transform.with_named_sequence} {
// -----
module attributes {transform.with_named_sequence} {
- // expected-error @below {{trying to schedule a pass on an unsupported operation}}
+ // expected-error @below {{trying to schedule pass 'DuplicateFunctionEliminationPass' on an unsupported operation}}
// expected-note @below {{target op}}
func.func @invalid_target_op_type() {
return
diff --git a/mlir/test/Dialect/UB/ops.mlir b/mlir/test/Dialect/UB/ops.mlir
index 724b6b4..730c1bd 100644
--- a/mlir/test/Dialect/UB/ops.mlir
+++ b/mlir/test/Dialect/UB/ops.mlir
@@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> {
%0 = ub.poison : tensor<8x?xf64>
return %0 : tensor<8x?xf64>
}
+
+// CHECK-LABEL: func @unreachable()
+// CHECK: ub.unreachable
+func.func @unreachable() {
+ ub.unreachable
+}
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 887fb94..70adefd 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
// -----
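+// vector.scatter on a tensor bufferizes out of place: the base tensor is
+// copied into a fresh allocation, the scatter writes into that buffer, and
+// the buffer is wrapped back into a tensor.
+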
+// CHECK-LABEL: func @scatter(
+// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
+// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
+// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
+// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
+// CHECK: return %[[tensor]] : tensor<16x16xf32>
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+ : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @gather(
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5f035e3..79b09e1 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
- : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index da9a1a8..de62022 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1160,3 +1160,17 @@ func.func @step() {
%1 = vector.step : vector<[4]xindex>
return
}
+
+// CHECK-LABEL: func @scatter_tensor(
+// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32>
+func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]]
+ %0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value
+ : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
+ // CHECK: return %[[RESULT]] : tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
index 1d8f440..27a3653 100644
--- a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s
+// RUN: mlir-opt %s -split-input-file --test-vector-scan-lowering | FileCheck %s
// CHECK-LABEL: func @scan1d_inc
// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
@@ -18,6 +18,20 @@ func.func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi
return %0#0, %0#1 : vector<2xi32>, vector<i32>
}
+// -----
+
+// Reducing scalable dims is not yet supported!
+
+// CHECK-LABEL: func @scan1d_inc_scalable
+// CHECK: vector.scan
+func.func @scan1d_inc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) {
+ %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+ vector<[2]xi32>, vector<i32>
+ return %0#0, %0#1 : vector<[2]xi32>, vector<i32>
+}
+
+// -----
+
// CHECK-LABEL: func @scan1d_exc
// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
@@ -36,6 +50,20 @@ func.func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi
return %0#0, %0#1 : vector<2xi32>, vector<i32>
}
+// -----
+
+// Reducing scalable dims is not yet supported!
+
+// CHECK-LABEL: func @scan1d_exc_scalable
+// CHECK: vector.scan
+func.func @scan1d_exc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) {
+ %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} :
+ vector<[2]xi32>, vector<i32>
+ return %0#0, %0#1 : vector<[2]xi32>, vector<i32>
+}
+
+// -----
+
// CHECK-LABEL: func @scan2d_mul_dim0
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<3xi32>
@@ -53,6 +81,27 @@ func.func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (v
return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32>
}
+// -----
+
+// CHECK-LABEL: func @scan2d_mul_dim0_scalable
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x[3]xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<[3]xi32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x[3]xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32>
+// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32>
+// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32>
+// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x[3]xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32>
+// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x[3]xi32> to vector<[3]xi32>
+// CHECK: return %[[F]], %[[G]] : vector<2x[3]xi32>, vector<[3]xi32>
+func.func @scan2d_mul_dim0_scalable(%arg0 : vector<2x[3]xi32>, %arg1 : vector<[3]xi32>) -> (vector<2x[3]xi32>, vector<[3]xi32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+ vector<2x[3]xi32>, vector<[3]xi32>
+ return %0#0, %0#1 : vector<2x[3]xi32>, vector<[3]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @scan2d_mul_dim1
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<2xi32>
@@ -73,6 +122,30 @@ func.func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (v
return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32>
}
+// -----
+
+// CHECK-LABEL: func @scan2d_mul_dim1_scalable
+// CHECK-SAME: %[[ARG0:.*]]: vector<[2]x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<[2]xi32>
+// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<[2]x3xi32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
+// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
+// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
+// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<[2]x1xi32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
+// CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
+// CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<[2]x1xi32>
+// CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
+// CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<[2]x1xi32> to vector<[2]xi32>
+// CHECK: return %[[I]], %[[J]] : vector<[2]x3xi32>, vector<[2]xi32>
+func.func @scan2d_mul_dim1_scalable(%arg0 : vector<[2]x3xi32>, %arg1 : vector<[2]xi32>) -> (vector<[2]x3xi32>, vector<[2]xi32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} :
+ vector<[2]x3xi32>, vector<[2]xi32>
+ return %0#0, %0#1 : vector<[2]x3xi32>, vector<[2]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @scan3d_mul_dim1
// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>
@@ -89,3 +162,22 @@ func.func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) -
vector<4x2x3xf32>, vector<4x3xf32>
return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @scan3d_mul_dim1_scalable
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x[3]xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<4x[3]xf32>
+// CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x[3]xf32>
+// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x[3]xf32> to vector<4x1x[3]xf32>
+// CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x[3]xf32> to vector<4x1x[3]xf32>
+// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32>
+// CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x[3]xf32>
+// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32>
+// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x[3]xf32> to vector<4x[3]xf32>
+// CHECK: return %[[F]], %[[G]] : vector<4x2x[3]xf32>, vector<4x[3]xf32>
+func.func @scan3d_mul_dim1_scalable(%arg0 : vector<4x2x[3]xf32>, %arg1 : vector<4x[3]xf32>) -> (vector<4x2x[3]xf32>, vector<4x[3]xf32>) {
+ %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} :
+ vector<4x2x[3]xf32>, vector<4x[3]xf32>
+ return %0#0, %0#1 : vector<4x2x[3]xf32>, vector<4x[3]xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 577b06d..69fba88 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -382,6 +382,21 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
return %r : vector<2x[4]xi32>
}
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_cast_non_vector_result
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
+// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>>
+/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not
+/// attempt to reorder a cast operation that produces a non-vector result type.
+func.func @negative_broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {
+ %0 = vector.broadcast %arg0 : i64 to vector<26x7xi64>
+ %1 = builtin.unrealized_conversion_cast %0 : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
+ return %1 : !llvm.array<26 x vector<7xi64>>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
@@ -780,7 +795,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) ->
}
//-----------------------------------------------------------------------------
-// [Pattern: StoreOpFromSplatOrBroadcast]
+// [Pattern: StoreOpFromBroadcast]
//-----------------------------------------------------------------------------
// CHECK-LABEL: @store_splat
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5..805e66f 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,137 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return
+
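+// Unrolling a 16x16 mask into four 8x8 tiles: for each tile the requested
+// size is shifted by the tile offset (arith.subi) and clamped to [0, 8]
+// (arith.maxsi/arith.minsi) before vector.create_mask builds the sub-mask.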
+func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
+ %0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
+ return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
+// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
+// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
+// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
+// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
+// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
+// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
+// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
+// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
+// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
+// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
+// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
+// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
+// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
+// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
+// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
+// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[INS11]] : vector<16x16xi1>
+
+func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
+ %cst16 = arith.constant 16 : index
+ %0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1>
+ return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1>
+// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[S3]] : vector<16x16xi1>
+
+
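+// The unrolled shape_cast extracts contiguous 8-element slices of the source,
+// casts each slice to the 2x4 target shape, and inserts it into the result.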
+func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
+ %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
+ return %0 : vector<2x2x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1D
+// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: return %[[I1]] : vector<2x2x4xf32>
+
+
+func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: return %[[I1]] : vector<4x4xf32>
+
+
+// This is a negative test case to ensure that such shape casts are not unrolled
+// because the targetShape (2x4) is not contiguous in the result vector.
+func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
+ %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
+// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
+// CHECK: return %[[SC]] : vector<8x8xf32>
+
+
+// This is a negative test case to ensure that such shape casts are not unrolled
+// because an extractShape that yields the contiguous targetShape (2x4) cannot
+// be determined from the source vector (8x3).
+func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
+ return %0 : vector<6x4xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
+// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
+// CHECK: return %[[SC]] : vector<6x4xf32>
+
+
+// TargetShape is [1x16]
+func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
+ %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32>
+ return %0 : vector<1x32xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_leading_unit_dim
+// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+// CHECK: return %[[I1]] : vector<1x32xf32>
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
new file mode 100644
index 0000000..e506b16
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -0,0 +1,344 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
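+// A (1x1) x (1x64) contraction is an outer product with a single scalar on
+// the LHS: the pattern broadcasts that scalar to a 64-wide vector and folds
+// the multiply-accumulate into a single vector.fma.
+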
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma
+// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32>
+// CHECK: vector.fma{{.*}}vector<64xf32>
+// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<64x1xf32>
+!vecB = vector<1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<1x64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<3x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_non_unit_batch_dim(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// The batch dimension should have been simplified earlier.
+
+// CHECK-LABEL: @negative_non_unit_batch_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_non_unit_batch_reduce_dim(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// The batch-reduce dimension should have been simplified earlier.
+
+// CHECK-LABEL: @negative_non_unit_batch_reduce_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_invalid_kind(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_invalid_kind
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xi32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_accumulator_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
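+// vector.fma implies floating-point accumulation; the i32 accumulator must not match.
+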
+// CHECK-LABEL: @negative_accumulator_type
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
new file mode 100644
index 0000000..65676cb
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -0,0 +1,706 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
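+// These tests exercise rewriting of VNNI-packed vector.contract ops into x86
+// dot-product ops: as the checks below expect, bf16 inputs (packing factor 2)
+// lower to x86vector.avx512.dot and i8 inputs (packing factor 4) lower to
+// x86vector.avx.dot.i8, with the A operand broadcast first.
+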
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x1x2xbf16>
+!vecB = vector<1x1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x8x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x8x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<16x1x2xbf16>
+!vecB = vector<1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
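+// Dot-product lowering applies only to additive accumulation; kind <mul> must be rejected.
+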
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_bf16(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
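+// The innermost packed dimension is 4, but bf16 VNNI packing requires a factor of 2.
+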
+// CHECK-LABEL: @negative_false_vnni_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xi8>
+!vecB = vector<1x1x8x2xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_int8(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
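+// The innermost packed dimension is 2, but i8 VNNI packing requires a factor of 4.
+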
+// CHECK-LABEL: @negative_false_vnni_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1x2xbf16>
+!vecB = vector<3x1x16x2xbf16>
+!vecC = vector<3x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_dimension(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
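+// The non-unit batch dimension should have been unrolled earlier; the pattern must not fire.
+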
+// CHECK-LABEL: @negative_batch_dimension
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<2x1x1x4xi8>
+!vecB = vector<2x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_dimension(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
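+// The non-unit batch-reduce dimension should have been unrolled earlier.
+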
+// CHECK-LABEL: @negative_brgemm_dimension
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_float_acc_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
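+// bf16 dot products accumulate into f32; a bf16 accumulator must be rejected.
+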
+// CHECK-LABEL: @negative_float_acc_type
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi8>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_int_acc_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
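+// i8 dot products accumulate into i32; an i8 accumulator must be rejected.
+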
+// CHECK-LABEL: @negative_int_acc_type
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vnni_blocking_factor_bf16(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
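+// A VNNI blocking factor of 4 is invalid for bf16 (the factor must be 2).
+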
+// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x32xbf16>
+!vecC = vector<1x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_not_vnni(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
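+// The operands are not VNNI-packed (no innermost packing dimension), so no rewrite applies.
+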
+// CHECK-LABEL: @negative_brgemm_not_vnni
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_int8(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
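+// A 16-lane i32 result does not match the 8-lane shape expected for avx.dot.i8.
+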
+// CHECK-LABEL: @negative_wrong_vector_shape_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x32x2xbf16>
+!vecC = vector<1x1x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_bf16(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
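+// A 32-lane f32 result does not match the 16-lane shape expected for avx512.dot.
+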
+// CHECK-LABEL: @negative_wrong_vector_shape_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce..67faa60 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{Mask should match value except the chunk size dim}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() {
// -----
func.func @create_mem_desc_non_slm() {
%m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
- // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+ // expected-error@+1 {{operand #0 must reside in shared memory and be a statically shaped 1D memref}}
%mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
return
}
@@ -871,14 +871,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
}
// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
- return
-}
-
-
-// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
@@ -900,16 +892,25 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
}
// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>, %arg1: vector<2x16xf32>) {
+ // expected-error@+1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>
return
}
// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>, %arg1: vector<16x2xf32>) {
+ // expected-error@+1 {{With subgroup_block_io, the distributed dimensions must be contiguous}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>
return
}
+// -----
+func.func @simt_store_matrix_vector_wrong_block_shape(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>, %arg1: vector<16x2xf32>) {
+ // expected-error@+1 {{With subgroup_block_io, the block shape must match the lane layout}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 0a10f68..1e9738f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -278,6 +278,15 @@ gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : in
gpu.return
}
+// CHECK: func @subgroup_load_nd_offset_2(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index) {
+gpu.func @subgroup_load_nd_offset_2(%src: memref<24x32xf32>, %x : index) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1[%x, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ gpu.return
+}
+
// CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -825,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_from_2d_memref() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_with_stride_from_2d_memref() {
+ //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
+ //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
+ %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
new file mode 100644
index 0000000..24a0de6
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -0,0 +1,284 @@
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \
+// RUN: --xegpu-optimize-block-loads --canonicalize --cse --split-input-file %s | FileCheck %s
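+
+// These tests exercise --xegpu-optimize-block-loads: as the checks expect, a
+// transposed f16/i8 B-operand block load is rewritten into a narrower i32
+// block load (the source memref is re-viewed with its row width divided by
+// the packing factor), followed by a vector.bitcast back to the original type.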
+
+// CHECK-LABEL: gpu.func @no_scf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ gpu.return %6 : vector<8x16xf32>
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @no_scf_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
+  %1 = xegpu.load_nd %0[%c0, %c64] { layout_result_0 = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ gpu.return %6 : vector<8x16xi32>
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @gemm_b_transpose(
+// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1]
+// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @nested_scf(
+// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] {layout_result_0 = #xegpu.layout<
+// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ scf.for %arg8 = %c0 to %c256 step %c16 {
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ }
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @large_loads(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+ %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @array_length(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 ->
+// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
+ -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b }
+ : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
+ %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+ %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+ %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 58461b8..32fb317 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,18 +1,45 @@
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @load_store_no_array_len(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[TDESC_SRC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: %[[TDESC_DST:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.prefetch_nd %[[TDESC_SRC]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC_SRC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x32xf32>
+// CHECK: xegpu.store_nd %[[LOADED]], %[[TDESC_DST]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+gpu.module @test {
+// Although the uArch allows 8x32 inst data via block count (array_length),
+// it is up to later optimization passes to decide whether to use it.
+func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf32>) {
+ %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
+ %1 = xegpu.create_nd_tdesc %arg1 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
+ xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x32xf32>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32>
+ xegpu.store_nd %2, %1 : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>
+ return
+}
+}
+
+// -----
+
// CHECK-LABEL: func.func @dpas_f16(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf16>
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<inst_data = [8, 16]>, layout_b = #xegpu.layout<inst_data = [16, 16]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>, layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.module @test {
func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
@@ -46,18 +73,18 @@ gpu.module @test_kernel {
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
-> (!xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>) {
- //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
- //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x32xf16>
+ //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+ //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x32xf16>
%a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16>
%b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16>
- //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x32xf16>
+ //CHECK: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x32xf16>
%c = arith.addf %a, %b : vector<16x32xf16>
- //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>>
+ //CHECK: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>>
xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16>
- //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ //CHECK-COUNT-3: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>>
%a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
%b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
%c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
@@ -85,18 +112,18 @@ gpu.module @test_kernel {
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
-> (!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) {
- //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
- //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<12x32xf16>
+ //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [4, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} :
+ //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> -> vector<12x32xf16>
%a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
%b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
- //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<12x32xf16>
+ //CHECK: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : vector<12x32xf16>
%c = arith.addf %a, %b : vector<12x32xf16>
- //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>>
+ //CHECK: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>>
xegpu.store_nd %c, %arg2: vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16>
- //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ //CHECK-COUNT-3: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>>
%a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
%b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
%c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
@@ -113,9 +140,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8], lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 543e119..48e77d8 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
gpu.module @test {
// CHECK-LABEL: func.func @dpas_f16(
@@ -6,14 +6,14 @@ gpu.module @test {
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -32,7 +32,8 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
gpu.module @test {
// CHECK-LABEL: func.func @dpas_i8(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16],
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+
func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
@@ -46,8 +47,8 @@ func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memre
gpu.module @test {
// CHECK-LABEL: func.func @load_with_transpose_effect(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
+// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -108,7 +109,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
@@ -135,7 +136,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -183,9 +184,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
@@ -204,7 +205,7 @@ gpu.module @test {
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
@@ -215,10 +216,50 @@ func.func @scatter_ops(%src: memref<256xf16>) {
}
// -----
gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_custom_perm_layout(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
+// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
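+// This test checks that an explicit user-provided layout on the store
+// (lane_layout = [8]) is kept and propagated backwards to the addf and the
+// mask/offset constants, while the load result keeps lane_layout = [16].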
+func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ %4 = arith.addf %3, %3 : vector<16xf16>
+ xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ return
+}
+}
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_preserve_load_perm_layout(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
+// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
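+// Same as above, except the load also carries its own explicit layout; the
+// pass is expected to preserve the load's lane_layout = [16] result layout
+// instead of overwriting it with the store's lane_layout = [8].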
+func.func @scatter_ops_preserve_load_perm_layout(%src: memref<256xf16>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ %4 = arith.addf %3, %3 : vector<16xf16>
+ xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ return
+}
+}
+// -----
+gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
-// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: !xegpu.tensor_desc<8x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xi16>
-// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
// CHECK-SAME: !xegpu.tensor_desc<16x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xi16>
// CHECK: %{{.*}} = vector.bitcast %[[LOAD0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: vector<8x16xi16> to vector<8x16xf16>
@@ -241,7 +282,7 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
// CHECK-SAME: vector<16x8xi32> to vector<16x16xf16>
@@ -262,7 +303,7 @@ func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_i16_to_i32(
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
// CHECK-SAME: !xegpu.tensor_desc<8x32xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xi16>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: vector<8x32xi16> to vector<8x16xi32>
@@ -299,9 +340,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>,
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16>
func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
@@ -322,9 +363,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
// CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -345,11 +386,11 @@ gpu.module @test {
// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) ->
// CHECK-SAME: (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) {
-// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
// CHECK-NEXT: %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NEXT: %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -357,7 +398,7 @@ gpu.module @test {
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>
// CHECK-NEXT: } {layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
@@ -385,11 +426,11 @@ gpu.module @test {
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>,
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
// CHECK: %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
-// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16>
// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16>
// CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
@@ -415,11 +456,11 @@ gpu.module @test {
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
// CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
-// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16>
// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16>
// CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
@@ -499,7 +540,7 @@ gpu.module @test {
// CHECK-LABEL: func.func @prefetch_2d(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
func.func @prefetch_2d(%arg0: memref<256x256xf16>){
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -512,7 +553,7 @@ gpu.module @test {
// CHECK-LABEL: func.func @prefetch_1d(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
func.func @prefetch_1d(%arg0: memref<256xf16>){
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
@@ -559,7 +600,7 @@ gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
@@ -581,7 +622,7 @@ gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
@@ -599,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
return
}
}
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
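+// The dim-0 reduction yields a slice of the 2D layout; broadcasting it back
+// to 2D is expected to restore the full 2D lane layout on the result.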
+func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+ %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16>
+ xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column(
+// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+// CHECK-NEXT: vector.broadcast %[[SHAPECAST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+
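+// Broadcasting along the column goes through a shape_cast to vector<16x1xf16>;
+// both the shape_cast and the broadcast are expected to receive the 2D lane
+// layout.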
+func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+ %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+ %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
+ xegpu.store_nd %6, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector(
+// CHECK: %[[CST:.*]] = arith.constant 0.{{.*}} : f16
+// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+
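+// A scalar source carries no layout of its own; the broadcast result is still
+// expected to receive the 2D lane layout required by the store.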
+func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) {
+ %cst = arith.constant 0.0000 : f16
+ %6 = vector.broadcast %cst : f16 to vector<16x16xf16>
+ xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff..216f3d1 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \
-// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s
-
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute \
+// RUN: -allow-unregistered-dialect -canonicalize -cse %s | FileCheck %s
+gpu.module @xevm_module{
// CHECK-LABEL: gpu.func @store_nd_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -11,20 +11,17 @@
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @xevm_module{
- gpu.func @store_nd_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %cst = "some_op"() : () -> vector<16xf32>
- xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- }
- gpu.return
+gpu.func @store_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %cst = "some_op"() : () -> vector<16xf32>
+ xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
}
+ gpu.return
}
-// -----
// CHECK-LABEL: gpu.func @store_nd_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -37,22 +34,18 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @store_nd_2d(%laneid : index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %cst = "some_op"() : () -> vector<16x16xf16>
- xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- gpu.return
+gpu.func @store_nd_2d(%laneid : index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %cst = "some_op"() : () -> vector<16x16xf16>
+ xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ gpu.return
}
-
-// -----
// CHECK-LABEL: gpu.func @load_nd_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>,
@@ -63,21 +56,19 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.load_nd %[[T1]][%[[W]]#2] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-gpu.module @xevm_module{
- gpu.func @load_nd_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
- !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
- gpu.yield %1 : vector<16xf32>
- }
- "some_user_op"(%r) : (vector<1xf32>) -> ()
- gpu.return
+gpu.func @load_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+ !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+ gpu.yield %1 : vector<16xf32>
}
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @load_nd_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
@@ -89,21 +80,19 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_2d(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- gpu.yield %1 : vector<16x16xf16>
- }
- "some_user_op"(%r) : (vector<16x1xf16>) -> ()
- gpu.return
+gpu.func @load_nd_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ gpu.yield %1 : vector<16x16xf16>
}
+ "some_user_op"(%r) : (vector<16x1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @load_nd_array_length
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>,
@@ -118,23 +107,21 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
// CHECK-NEXT: vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_array_length(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
- #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
- #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
- gpu.yield %1 : vector<2x16x16xf16>
- }
- "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
- gpu.return
+gpu.func @load_nd_array_length(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+ gpu.yield %1 : vector<2x16x16xf16>
}
+ "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @dpas
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] ->
@@ -146,29 +133,27 @@ gpu.module @xevm_module{
// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32>
// CHECK-NEXT: %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
// CHECK-NEXT: vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-gpu.module @xevm_module{
- gpu.func @dpas(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
- %0 = "some_op"() : () -> vector<8x16xf16>
- %1 = "some_op"() : () -> vector<16x16xf16>
- %2 = "some_op"() : () -> vector<8x16xf32>
- %3 = xegpu.dpas %0, %1, %2
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- gpu.yield %3 : vector<8x16xf32>
- }
- "some_user_op"(%r) : (vector<8x1xf32>) -> ()
- gpu.return
+gpu.func @dpas(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_op"() : () -> vector<8x16xf16>
+ %1 = "some_op"() : () -> vector<16x16xf16>
+ %2 = "some_op"() : () -> vector<8x16xf32>
+ %3 = xegpu.dpas %0, %1, %2
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ gpu.yield %3 : vector<8x16xf32>
}
+ "some_user_op"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -178,21 +163,19 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK-NEXT: builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch}
-gpu.module @xevm_module{
- gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
- %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
- !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- "some_user_op"(%r)
- : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
- gpu.return
+gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+ %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
+ !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ "some_user_op"(%r)
+ : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @prefetch_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -204,21 +187,19 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2]
// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_2d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : ()
- -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.prefetch_nd %0[%c0, %c0]
- <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- gpu.return
+gpu.func @prefetch_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %0[%c0, %c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @prefetch_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16,
@@ -229,44 +210,40 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : ()
- -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.prefetch_nd %0[%c0]
- <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- }
- gpu.return
+gpu.func @prefetch_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ xegpu.prefetch_nd %0[%c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
// CHECK: gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
// CHECK: gpu.yield %{{.*}}
// CHECK: }
// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
// CHECK: gpu.barrier
-gpu.module @xevm_module{
- gpu.func @gpu_barrier(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0[%c0]
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
- gpu.barrier
- gpu.yield %1 : vector<16xf16>
- }
- "some_user_op"(%r) : (vector<1xf16>) -> ()
- gpu.return
+gpu.func @gpu_barrier(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0[%c0]
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
+ gpu.barrier
+ gpu.yield %1 : vector<16xf16>
}
+ "some_user_op"(%r) : (vector<1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
@@ -285,7 +262,6 @@ gpu.module @xevm_module{
// CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
// CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32>
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
// CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
@@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
// CHECK-NEXT: %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
// CHECK-NEXT: gpu.yield %[[T6]] : vector<2xf32>
// CHECK-NEXT: }
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
// CHECK: %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32
// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
@@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
// CHECK: gpu.yield %[[T7]] : vector<2xf32>
// CHECK: }
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1>: vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
- }
- : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- }
- : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
- }
- gpu.return
+gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1>: vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ }
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -475,156 +442,144 @@ gpu.module @xevm_module{
// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3
// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1> : vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
- xegpu.store %3, %src[%offset], %1
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- }
- : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1> : vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset], %1
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
}
- gpu.return
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) {
// CHECK: gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16>
// CHECK-NEXT: }
// CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index
// CHECK-NEXT: arith.index_cast %[[INTPTR]] : index to i64
-gpu.module @xevm_module{
- gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
- %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
- gpu.yield %ptr : index
- }
- %ptr_i64 = arith.index_cast %r : index to i64
- "some_user_op"(%ptr_i64) : (i64) -> ()
- gpu.return
+gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+ %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+ gpu.yield %ptr : index
}
+ %ptr_i64 = arith.index_cast %r : index to i64
+ "some_user_op"(%ptr_i64) : (i64) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_transpose(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32>
// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-gpu.module @xevm_module{
- gpu.func @vector_transpose(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
- : () -> (vector<16x2xf32>)
- %transpose = vector.transpose %cst, [1, 0]
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16x2xf32> to vector<2x16xf32>
- gpu.yield %transpose : vector<2x16xf32>
- }
- "some_user_op"(%r) : (vector<2x1xf32>) -> ()
- gpu.return
+gpu.func @vector_transpose(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<16x2xf32>)
+ %transpose = vector.transpose %cst, [1, 0]
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x2xf32> to vector<2x16xf32>
+ gpu.yield %transpose : vector<2x16xf32>
}
+ "some_user_op"(%r) : (vector<2x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_bitcast(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) {
// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8>
// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8>
// CHECK: }
// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
-gpu.module @xevm_module{
- gpu.func @vector_bitcast(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
- : () -> (vector<4x32xi8>)
- %bitcast = vector.bitcast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<4x32xi8> to vector<4x16xi16>
- gpu.yield %bitcast : vector<4x16xi16>
- }
- "some_user_op"(%r) : (vector<4x1xi16>) -> ()
- gpu.return
+gpu.func @vector_bitcast(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+ : () -> (vector<4x32xi8>)
+ %bitcast = vector.bitcast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<4x32xi8> to vector<4x16xi16>
+ gpu.yield %bitcast : vector<4x16xi16>
}
+ "some_user_op"(%r) : (vector<4x1xi16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
- : () -> (vector<16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16xf32> to vector<1x16xf32>
- gpu.yield %cast : vector<1x16xf32>
- }
- "some_user_op"(%r) : (vector<1x1xf32>) -> ()
- gpu.return
+gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
}
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : () -> (vector<1x16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
- }
- : vector<1x16xf32> to vector<16xf32>
- gpu.yield %cast : vector<16xf32>
- }
- "some_user_op"(%r) : (vector<1xf32>) -> ()
- gpu.return
+gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<1x16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }
+ : vector<1x16xf32> to vector<16xf32>
+ gpu.yield %cast : vector<16xf32>
}
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
}
-// -----
+
// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
//
// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
@@ -634,21 +589,400 @@ gpu.module @xevm_module {
// CHECK: }
// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
// CHECK: gpu.return
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_unsupported(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
- : () -> (vector<16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16xf32> to vector<1x16xf32>
- gpu.yield %cast : vector<1x16xf32>
+gpu.func @vector_shapecast_unsupported(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
+ }
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x16xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x16xf32> to vector<8x16xf32>
+ gpu.yield %1 : vector<8x16xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x1xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x1xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x1xf32> to vector<8x1xf32>
+ gpu.yield %1 : vector<8x1xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x64xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x64xf32> to vector<8x16xf32>
+ gpu.yield %1 : vector<8x16xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
+// CHECK: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
+gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
+ %0 = "some_def"() : () -> (vector<32x16xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }
+ : vector<32x16xf32> to vector<16x16xf32>
+ gpu.yield %1 : vector<16x16xf32>
+ }
+ "some_use"(%r) : (vector<1x16xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
+// CHECK: %[[S:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<2xf32>) -> ()
+gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<64xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_offset
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.extract_strided_slice
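+// Note: offset 3 is not a multiple of the 16-lane distribution, so the op is expected
+// to stay inside the warp region undistributed.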
+gpu.func @vector_extract_strided_slice_unsupported_offset(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<64xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_source
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.extract_strided_slice
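+// Note: the 54-element source does not divide evenly across 16 lanes, so distribution
+// is expected to bail out here as well.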
+gpu.func @vector_extract_strided_slice_unsupported_source(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<54xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<54xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<64x16xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<64x16xf32>
+ gpu.yield %2 : vector<64x16xf32>
+ }
+ "some_use"(%r) : (vector<64x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x1xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x1xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16x1xf32>)
+ %1 = "some_def"() : () -> (vector<64x1xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
- "some_user_op"(%r) : (vector<1x1xf32>) -> ()
- gpu.return
+ : vector<16x1xf32> into vector<64x1xf32>
+ gpu.yield %2 : vector<64x1xf32>
}
+ "some_use"(%r) : (vector<64x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x2xf32>) -> ()
+gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<64x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<64x32xf32>
+ gpu.yield %2 : vector<64x32xf32>
+ }
+ "some_use"(%r) : (vector<64x2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_outer_distributed
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3x32xf32>) -> ()
+gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<48x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<48x32xf32>
+ gpu.yield %2 : vector<48x32xf32>
+ }
+ "some_use"(%r) : (vector<3x32xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3xf32>) -> ()
+gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<16xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<8xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<8xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_offset
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<16xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [3], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane
+// CHECK-SAME: (%[[ARG0:.*]]: index) {
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]]
+// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+    %1 = "some_def"() : () -> vector<16xf16>
+    %2 = vector.broadcast %1 {
+      layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    } : vector<16xf16> to vector<16x16xf16>
+    gpu.yield %2 : vector<16x16xf16>
+ }
+ "some_use"(%r) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]]
+// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16>
+// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> ()
+gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) {
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+ %1 = "some_def"() : () -> vector<16x1xf16>
+ %2 = vector.broadcast %1 {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ } : vector<16x1xf16> to vector<16x16xf16>
+ gpu.yield %2: vector<16x16xf16>
+ }
+ "some_use"(%0) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16>
+// CHECK: "some_use"(%[[RESULT]])
+gpu.func @vector_shape_cast_scalar_to_vector(%arg0: index) {
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+ %1 = "some_def"() : () -> f16
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+ gpu.yield %2 : vector<16x16xf16>
+ }
+ "some_use"(%0) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc3..e5e3d2a 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,129 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C8]]
+// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C8]]
+// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = arith.remui %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C8]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
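+// Note: with lane_layout [2, 8] and zero base offsets, lane i appears to load the
+// single f32 at row (i / 8) mod 2 and column i mod 8, matching the index math above.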
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C4]]
+// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C4]]
+// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C4]]
+// CHECK: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = arith.remui %[[MUL]], %[[C8]]
+// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C4]]
+// CHECK: %[[ADD:.*]] = arith.addi %[[REMU4]], %[[C1]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
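+// Note: with lane_layout [4, 4] and lane_data [2, 1], lane i appears to own a 2x1
+// slice at row ((i / 4) mod 4) * 2 and column (i mod 4) + 1 (the +1 being the base offset).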
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+ !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
+ %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16>
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1>
+ %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16>
+
+ %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+ %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+ // CHECK-NOT: vector.broadcast
+ // CHECK-NOT: vector.shape_cast
+
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}]
+ // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %9 = gpu.block_id x
+ %10 = arith.index_cast %9 : index to i16
+ %11 = arith.bitcast %10 : i16 to f16
+ // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16>
+ %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
new file mode 100644
index 0000000..dce4a41
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
+
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index // expected-note {{target op}}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op results}}
+ transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op operands}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index // expected-note {{target op}}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Expected threads argument to consist of three values (got 2)}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ // expected-note@below {{load op}}
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+ // expected-error@below {{Load op is not contained in a scf.for loop.}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
new file mode 100644
index 0000000..561034f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -0,0 +1,509 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @get_desc_op_a
+func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // expected-remark @below {{found desc op}}
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @get_desc_op_c
+func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ // expected-remark @below {{found desc op}}
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
+ // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_minimal
+func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
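+// Layout entries can be supplied as transform params (here sg_layout[0] = 8) instead of literals.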
+// CHECK-LABEL: @set_desc_layout_param
+func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
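+// slice_dims wraps the layout in a #xegpu.slice attribute, as needed for a 1-D descriptor.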
+// CHECK-LABEL: @set_desc_layout_slice
+func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
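+// With "result" and no explicit index, the layout attribute defaults to result 0 (layout_result_0).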
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_slice
+func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
+ // CHECK: = arith.extf
+ // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
+ %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
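+// Without "result" the attribute targets an operand; index = 1 selects the second operand (layout_operand_1).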
+// CHECK-LABEL: @set_op_layout_attr_operand1
+func.func @set_op_layout_attr_operand1(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.addf %1, %3
+ // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %6 = arith.addf %1, %3 : vector<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
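+// Replaces the threads(...) bounds of the matched gpu.launch with the requested constants (8, 4, 1).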
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
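+// Expects a prefetch of the A tile before the loop and one per iteration, one step (32) ahead of the load.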
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ // CHECK: %[[C32:.+]] = arith.constant 32 : index
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]]
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
+// -----
+
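+// With nb_prefetch = 2 (given as a transform param), two tiles are prefetched before the loop
+// and the in-loop prefetch runs two steps (64) ahead of the load.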
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ // CHECK: %[[C64:.+]] = arith.constant 64 : index
+ // CHECK: %[[C32:.+]] = arith.constant 32 : index
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]]
+ %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
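+// Expects an xegpu.convert_layout to be inserted between the load and the DPAS A operand.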
+// CHECK-LABEL: @convert_layout_a
+func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+ // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+ // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+ // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+ // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas %[[V2]]
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ // CHECK: transform.xegpu.convert_layout %{{.*}}
+ transform.xegpu.convert_layout %1
+ input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
+ target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
+ : (!transform.any_value) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a_sg_param
+func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+ // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+ // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+ // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+ // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas %[[V2]]
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ // CHECK: transform.xegpu.convert_layout %{{.*}}
+ transform.xegpu.convert_layout %1
+ input_sg_layout = [%layout0, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
+ target_sg_layout = [%layout0, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
+ : (!transform.any_value, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index b73bc69..8ce6d4d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -1,33 +1,30 @@
// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
-//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test {
gpu.func @slice_attr() -> vector<128xindex> {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
- //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+ // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return %step : vector<128xindex>
}
gpu.func @nested_slice_attr() -> vector<128xindex> {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
- //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[DIVU2:.*]] = arith.divui %[[SGID]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU2]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+ // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
gpu.return %0 : vector<128xindex>
}
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 09df1e4..9580769 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
- // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: arith.negf
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
- // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: math.powf
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d2d250c..4829af3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,14 +1,10 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-#map = affine_map<()[s0] -> (s0 floordiv 4)>
-#map1 = affine_map<()[s0] -> (s0 mod 4)>
-
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -16,22 +12,23 @@ gpu.module @test_round_robin_assignment {
}
// CHECK-LABEL: create_nd_tdesc_with_shared_data
- // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
- //CHECK: [[C16:%.+]] = arith.constant 16 : index
- //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
- //CHECK: [[C64:%.+]] = arith.constant 64 : index
- //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
- //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
- //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
- //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[IDX:.*]] = arith.remui %[[SGID]], %[[C4]]
+ // CHECK: %[[IDY_DIV:.*]] = arith.divui %[[SGID]], %[[C4]]
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[IDY:.*]] = arith.remui %[[IDY_DIV]], %[[C8]]
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[LY:.*]] = arith.muli %[[IDY]], %[[C16]]
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[LX:.*]] = arith.muli %[[IDX]], %[[C64]]
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[OFFY:.*]] = arith.remui %[[LY]], %[[C128]]
+ // CHECK: %[[C64_1:.*]] = arith.constant 64 : index
+ // CHECK: %[[OFFX:.*]] = arith.remui %[[LX]], %[[C64_1]]
+ // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
gpu.return
@@ -42,9 +39,7 @@ gpu.module @test_round_robin_assignment {
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,9 +52,8 @@ gpu.module @test_round_robin_assignment {
gpu.func @store_nd(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT : xegpu.store_nd
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
@@ -73,8 +67,7 @@ gpu.module @test_round_robin_assignment {
gpu.func @update_nd(%src: memref<256x128xf32>){
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -84,15 +77,9 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: dpas
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -113,8 +100,7 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -131,9 +117,7 @@ gpu.module @test_round_robin_assignment {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
-> vector<128x1xf32>
- // CHECK-COUNT-2: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
@@ -171,10 +155,10 @@ gpu.module @test_round_robin_assignment {
%0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
%2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
- //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+ // CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
%3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
%4 = arith.cmpi slt, %arg3, %c10_i32 : i32
- //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+ // CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
} do {
// CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
@@ -195,16 +179,16 @@ gpu.module @test_round_robin_assignment {
%2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%3 = arith.cmpi eq, %0, %c10 : index
// CHECK-LABEL: scf.if
- // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
+ // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
%4 = scf.if %3 -> (vector<256xf32>) {
%5 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: vector<16xf32>, vector<16xf32>
+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
scf.yield %5 : vector<256xf32>
} else {
%5 = xegpu.load_nd %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: vector<16xf32>, vector<16xf32>
+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
scf.yield %5 : vector<256xf32>
} {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [16]>}
xegpu.store_nd %4, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -220,16 +204,16 @@ gpu.module @test_round_robin_assignment {
%0 = arith.cmpi eq, %id, %c10 : index
// CHECK-LABEL: scf.if
- // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+ // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
%1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
%2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
} else {
%3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
}
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -238,8 +222,8 @@ gpu.module @test_round_robin_assignment {
gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
- //CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
- //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
+ // CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
+ // CHECK-COUNT-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
%2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>,
target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 86a021b..c95c640 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -14,13 +14,11 @@ gpu.module @test_distribution {
// CHECK-LABEL: load_nd_tdesc_with_offset
gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- %load = xegpu.load_nd %tdesc[0, 0]
+ %load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
gpu.return
@@ -28,8 +26,7 @@ gpu.module @test_distribution {
// CHECK-LABEL: store_nd_with_offset
gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.store_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -42,10 +39,8 @@ gpu.module @test_distribution {
}
// CHECK-LABEL: prefetch_nd_tdesc_with_offset
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,15 +52,11 @@ gpu.module @test_distribution {
// CHECK-LABEL: dpas
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+ // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
%tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -99,30 +90,57 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: non_splat_constant
gpu.func @non_splat_constant() {
- // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{.*}}0{{.*}}, {{.*}}16{{.*}}> : vector<2x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
- // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
- // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
- // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
- // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
- // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
- // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
- // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
- // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
- // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
- // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
- // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
+ // CHECK-DAG: %[[T1:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[T2:.*]] = arith.muli %[[T1]], %[[C2:.*]] : index
+ // CHECK-DAG: %[[T3:.*]] = arith.remui %[[T2]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[T4:.*]] = arith.addi %[[T2]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[T5:.*]] = arith.remui %[[T4]], %[[C32_6:.*]] : index
+ // CHECK-DAG: %[[T6:.*]] = arith.muli %[[T3]], %[[C16_10:.*]] : index
+ // CHECK-DAG: %[[T7:.*]] = arith.addi %[[C0_11:.*]], %[[T6]] : index
+ // CHECK-DAG: %[[T8:.*]] = arith.muli %[[C0_4:.*]], %[[C0_9:.*]] : index
+ // CHECK-DAG: %[[T9:.*]] = arith.addi %[[T7]], %[[T8]] : index
+ // CHECK-DAG: %[[T10:.*]] = vector.broadcast %[[T9]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[T11:.*]] = arith.addi %[[CST]], %[[T10]] : vector<2x1xindex>
+ // CHECK-DAG: %[[T12:.*]] = arith.muli %[[T5]], %[[C16_10:.*]] : index
+ // CHECK-DAG: %[[T13:.*]] = arith.addi %[[C0_12:.*]], %[[T12]] : index
+ // CHECK-DAG: %[[T14:.*]] = arith.muli %[[C0_8:.*]], %[[C0_9:.*]] : index
+ // CHECK-DAG: %[[T15:.*]] = arith.addi %[[T13]], %[[T14]] : index
+ // CHECK-DAG: %[[T16:.*]] = vector.broadcast %[[T15]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[T17:.*]] = arith.addi %[[CST]], %[[T16]] : vector<2x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
+
+ // CHECK-LABEL: vector_transpose
+ gpu.func @vector_transpose(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>>
+ -> vector<256x128xf32>
+ // CHECK-COUNT-2: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<32x16xf32> to vector<16x32xf32>
+ // CHECK-NOT: vector.transpose
+ %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_mask_2D
+ gpu.func @vector_mask_2D() {
+ // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+ // CHECK-NOT: vector.create_mask
+ %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_create_mask_2D
+ gpu.func @vector_create_mask_2D() {
+ // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+ // CHECK-NOT: vector.create_mask
+ %cst16 = arith.constant 16 : index
+ %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 52acde4..69eb8ce 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,8 +1,5 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
-//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -26,13 +23,23 @@ gpu.module @test_distribution {
}
// CHECK-LABEL: load_nd_tdesc_with_offset
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
- %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ //CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ //CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ //CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ //CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8]]
+ //CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ //CHECK-DAG: %[[L_OFF_Y:.*]] = arith.muli %[[SGIDY]], %[[C32]] : index
+ //CHECK-DAG: %[[L_OFF_X:.*]] = arith.muli %[[SGIDX]], %[[C32_1:.*]] : index
+ //CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+ //CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[L_OFF_Y]], %[[C256]] : index
+ //CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+ //CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[L_OFF_X]], %[[C128]] : index
+ //CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -43,9 +50,6 @@ gpu.module @test_distribution {
// CHECK-LABEL: store_nd_with_offsets
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
//CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -60,9 +64,6 @@ gpu.module @test_distribution {
// CHECK-LABEL: prefetch_nd_tdesc_with_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
//CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%cst0 = arith.constant 0 : index
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
@@ -285,14 +286,14 @@ gpu.module @test_distribution {
// CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
- // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>, layout = #xegpu.layout<inst_data = [8]>}>
// CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
// CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
- xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
l1_hint = #xegpu.cache_hint<cached>}
@@ -319,21 +320,19 @@ gpu.module @test_distribution {
gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
- //CHECK: [[c2:%.+]] = arith.constant 2 : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
- //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
- //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index
+ //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
- //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index
+ //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index
//CHECK: [[c64:%.+]] = arith.constant 64 : index
- //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+ //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+ //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index
//CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
%0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
%1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -346,21 +345,19 @@ gpu.module @test_distribution {
//CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
- //CHECK: [[c2:%.+]] = arith.constant 2 : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
- //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
- //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index
+ //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
- //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index
+ //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index
//CHECK: [[c64:%.+]] = arith.constant 64 : index
- //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+ //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+ //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index
//CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
%mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
@@ -411,14 +408,17 @@ gpu.module @test_distribution {
// CHECK-LABEL: vector_step_op
gpu.func @vector_step_op_slice_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
- //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
- //CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
- //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
- //CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
- //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c8]] : index
+ //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgId]], [[c8]] : index
+ //CHECK: [[c4:%.+]] = arith.constant 4 : index
+ //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c4]] : index
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LY:%.+]] = arith.muli [[sgidy]], [[c32]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = arith.remui [[LY]], [[c128]] : index
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return
@@ -426,14 +426,14 @@ gpu.module @test_distribution {
gpu.func @vector_step_op_layout_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
- //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
- //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
- //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
- //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
- //CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
- //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+ //CHECK: [[c16:%.+]] = arith.constant 16 : index
+ //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c16]] : index
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[LOCALY:%.+]] = arith.muli [[sgidx]], [[c8]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = arith.remui [[LOCALY]], [[c128]] : index
+ //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
%step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
gpu.return
@@ -464,40 +464,49 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: vector_transpose
+ gpu.func @vector_transpose(%src: memref<256x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
+ -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>>
+ -> vector<256x32xf32>
+ //CHECK: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<64x32xf32> to vector<32x64xf32>
+ %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<256x32xf32> to vector<32x256xf32>
+ gpu.return
+ }
+
// CHECK-LABEL: non_splat_constant_2D
gpu.func @non_splat_constant_2D() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
- // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
- // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
- // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
- // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index
- // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
- // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
+ // CHECK-DAG: %[[T0:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[T1:.*]] = arith.remui %[[T0]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[T2:.*]] = arith.remui %[[T1]], %[[C32_4:.*]] : index
+ // CHECK-DAG: %[[T3:.*]] = arith.muli %[[T2]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[T4:.*]] = arith.addi %[[C0_8:.*]], %[[T3]] : index
+ // CHECK-DAG: %[[T5:.*]] = arith.muli %[[C0_6:.*]], %[[C0_7:.*]] : index
+ // CHECK-DAG: %[[T6:.*]] = arith.addi %[[T4]], %[[T5]] : index
+ // CHECK-DAG: %[[T7:.*]] = vector.broadcast %[[T6]] : index to vector<1x1xindex>
+ // CHECK-DAG: %[[T8:.*]] = arith.addi %[[CST]], %[[T7]] : vector<1x1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
// CHECK-LABEL: non_splat_constant_2D_non_unit_dim
gpu.func @non_splat_constant_2D_non_unit_dim() {
- // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
+ // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{\[}}{{\[}}0, 16{{\]}}, {{\[}}8, 24{{\]}}{{\]}}> : vector<2x2xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
- // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
- // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
- // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
- // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %{{.*}}
+ // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[SGIDY]], %[[C2:.*]] : index
+ // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[SGIDX]], %{{.*}} : index
+ // CHECK-DAG: %[[REMU_Y:.*]] = arith.remui %[[MULY]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[REMU_X:.*]] = arith.remui %[[MULX]], %{{.*}} : index
+ // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %{{.*}} : index
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
// CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
+ // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
// CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
%cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
@@ -517,13 +526,14 @@ gpu.module @test_distribution {
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
- // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[REMU]], %{{.*}}
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
- // CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
+ // CHECK: arith.constant dense<{{\[}}{{\[}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{\]}}{{\]}}> : vector<1x16xindex>
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
gpu.return
}
@@ -534,4 +544,106 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
gpu.return
}
+
+ // CHECK-LABEL: vector_mask_1D
+ gpu.func @vector_mask_1D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+ // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+ %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_mask_2D
+ gpu.func @vector_mask_2D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]]
+ // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+ // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+ // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index
+ // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+ %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_create_mask_1D
+ gpu.func @vector_create_mask_1D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]]
+ // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]]
+ // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+ // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+ %cst8 = arith.constant 8 : index
+ %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_create_mask_2D
+ gpu.func @vector_create_mask_2D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]]
+ // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]]
+ // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]]
+ // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+ // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+ // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+ %cst16 = arith.constant 16 : index
+ %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: distribute_load_slice_attr
+ gpu.func @distribute_load_slice_attr() {
+ %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]>} dense<1> : vector<256xi1>
+
+ // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} :
+ // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+ %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>} : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>
+
+ // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32>
+ %4 = vector.broadcast %3 {layout_result_0 =
+ #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: load_nd_tdesc_with_anchor_layout
+ gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) {
+ //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+ %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ gpu.return
+ }
}
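
Note on the checks in this file: the distribution pass now emits plain arith ops instead of affine maps, so every updated test follows the same delinearization recipe — `gpu.subgroup_id` is split into x/y coordinates with `arith.remui`/`arith.divui`, scaled by the subgroup tile size with `arith.muli`, and wrapped into the workgroup shape with a final `arith.remui`. A minimal Python sketch of that arithmetic, assuming the `sg_layout = [8, 4]`, `sg_data = [32, 32]` configuration over the 256x128 shape used throughout this file (an illustration of the check pattern, not the pass implementation):

    def subgroup_offsets(sgid: int) -> tuple[int, int]:
        """Mirror of the arith-op sequence the new CHECK lines expect."""
        sgidx = sgid % 4             # arith.remui %sgid, %c4
        sgidy = (sgid // 4) % 8      # arith.divui, then arith.remui by %c8
        off_y = (sgidy * 32) % 256   # arith.muli, wrapped by arith.remui
        off_x = (sgidx * 32) % 128
        return off_y, off_x

    # Subgroup 5 sits at tile (y=1, x=1), i.e. element offsets (32, 32).
    assert subgroup_offsets(5) == (32, 32)
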
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index e83229e..a8015cc 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,47 +1,35 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: [[C32:%.+]] = arith.constant 32 : index
- //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
- //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
- //CHECK: [[C256:%.+]] = arith.constant 256 : index
- //CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
- //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]]
+ // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]]
+ // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]]
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[MODY]], %[[MODX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref
- // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<3x256x128xf32>
gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: [[C32:%.+]] = arith.constant 32 : index
- //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
- //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_2:%.+]] = arith.constant 0 : index
- //CHECK: [[C256:%.+]] = arith.constant 256 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
- //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
- //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]]
+ // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]]
+ // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]]
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][1, %[[MODY]], %[[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
@@ -81,25 +69,24 @@ gpu.module @test_1_1_assignment {
xegpu.store_nd %load, %tdesc
: vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
-}
+ }
-// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
-gpu.func @update_nd(%src: memref<256x128xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
- -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- %update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
-}
+ // CHECK-LABEL: update_nd
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %update = xegpu.update_nd_offset %tdesc, [0, 16]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
-// CHECK-LABEL: dpas
-gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ // CHECK-LABEL: dpas
+ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
@@ -110,16 +97,15 @@ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-> vector<128x128xf16>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
%dpas = xegpu.dpas %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
-
-// CHECK-LABEL: dpas_no_sg_data
-gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-LABEL: dpas_no_sg_data
+ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
@@ -134,6 +120,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
-> vector<128x128xf16>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
%dpas = xegpu.dpas %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
@@ -196,9 +183,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
}
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
@@ -211,15 +198,15 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
- // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
- // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+ // CHECK: %[[SCF:.*]]:3 = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C1024]] step %[[C128]]
+ // CHECK-SAME: iter_args(%[[ARG4:.*]] = {{.*}}, %[[ARG5:.*]] = {{.*}}, %[[ARG6:.*]] = {{.*}}) ->
// CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
- // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
- // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
- // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
- // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
- // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
- // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+ // CHECK: %[[A:.*]] = xegpu.load_nd %[[ARG4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+ // CHECK: %[[B:.*]] = xegpu.load_nd %[[ARG5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+ // CHECK: %[[C:.*]] = xegpu.dpas %[[A]], %[[B]], %[[ARG6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+ // CHECK: %[[AT:.*]] = xegpu.update_nd_offset %[[ARG4]], [%[[C0]], %[[C128]]] : !xegpu.tensor_desc<16x128xf16>
+ // CHECK: %[[BT:.*]] = xegpu.update_nd_offset %[[ARG5]], [%[[C128]], %[[C0]]] : !xegpu.tensor_desc<128x16xf16>
+ // CHECK: scf.yield %[[AT]], %[[BT]], %[[C]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
%6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
-> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
@@ -252,7 +239,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
// CHECK: scf.condition{{.*}} : vector<16xf32>, i32
scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
} do {
- // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32)
+ // CHECK: (%[[ARG2:.*]]: vector<16xf32>, %[[ARG3:.*]]: i32)
^bb0(%arg2: vector<256xf32>, %arg3: i32):
xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
%4 = arith.addi %arg3, %c1_i32 : i32
@@ -344,9 +331,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%cond4 = arith.cmpi slt, %sg_id, %c31 : index
%cond5 = arith.andi %cond3, %cond4 : i1
scf.if %cond5 {
- // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
%tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
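
For orientation, the distributed `scf.for` checked above computes, per subgroup, a 16x16 accumulator from 16x128 and 128x16 tiles over a K loop from 0 to 1024 with step 128. A NumPy sketch of that per-subgroup computation, matching the tile shapes in the CHECK lines (an illustration, not the lowering itself):

    import numpy as np

    a = np.random.randn(16, 1024).astype(np.float16)   # one subgroup's A rows
    b = np.random.randn(1024, 16).astype(np.float16)   # one subgroup's B cols
    c = np.zeros((16, 16), np.float32)                  # iter_arg accumulator
    for k in range(0, 1024, 128):                       # %c0 to %c1024 step %c128
        a_tile = a[:, k:k + 128]                        # xegpu.load_nd -> 16x128
        b_tile = b[k:k + 128, :]                        # xegpu.load_nd -> 128x16
        c += (a_tile @ b_tile).astype(np.float32)       # xegpu.dpas accumulate
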
diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
index 8f60088..e09720a 100644
--- a/mlir/test/Examples/NVGPU/Ch0.py
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 0 : Hello World
@@ -33,7 +37,7 @@ def main(alpha):
# + operator generates arith.addi
myValue = alpha + tidx
# Print from a GPU thread
- gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+ gpu.printf("GPU thread %llu has %llu\n", tidx, myValue)
# 3. Call the GPU kernel
kernel()
@@ -43,8 +47,24 @@ alpha = 100
# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
main(alpha)
-
# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
# CHECK: GPU thread 3 has 103
+
+# DUMPIR: func.func @main(%arg0: index) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[C0_I32:.*]] = arith.constant 0 : i32
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_0:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C4:.*]] = arith.constant 4 : index
+# DUMPIR: %[[C1_2:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR: gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] {
+# DUMPIR: %[[TIDX:.*]] = gpu.thread_id x
+# DUMPIR: %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index
+# DUMPIR: gpu.printf "GPU thread %llu has %llu\0A", %[[TIDX]], %[[MYVAL]] : index, index
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: return
+# DUMPIR: }
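
The RUN-line pattern introduced here, and repeated in Chapters 1-4 below, selects between two modes: when SM90 execution is enabled, the example runs on the GPU and FileCheck matches the runtime output; otherwise `MLIR_NVDSL_PRINT_IR=1` makes the script print the generated IR, and the `DUMPIR` prefix is checked instead. A sketch of the matching gate the later chapters add on the Python side (`maybe_verify` is a hypothetical helper name, not part of the examples):

    import os

    def maybe_verify(check) -> None:
        # IR-dump mode produces no GPU results to compare, so the numeric
        # verification (and its "PASS" marker) only runs in execute mode.
        if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
            check()
            print("PASS")
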
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index cfb48d5..6e44e4d 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 1 : 2D Saxpy
@@ -24,12 +28,12 @@ import numpy as np
def saxpy(x, y, alpha):
# 1. Use MLIR GPU dialect to allocate and copy memory
token_ty = gpu.AsyncTokenType.get()
- t1 = gpu.wait(token_ty, [])
+ t1 = gpu.wait([])
x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
- t6 = gpu.wait(token_ty, [t5])
+ t6 = gpu.wait([t5])
# 2. Compute 2D SAXPY kernel
@NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
@@ -47,7 +51,7 @@ def saxpy(x, y, alpha):
saxpy_kernel()
t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
- gpu.wait(token_ty, [t7])
+ gpu.wait([t7])
# 3. Pass numpy arrays to MLIR
@@ -56,11 +60,32 @@ N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
+
saxpy(x, y, alpha)
-# 4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+ # 4. Verify MLIR with reference computation
+ ref = np.ones((M, N), np.float32)
+ ref += x * alpha
+ np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+ print("PASS")
# CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @saxpy(%[[ARG0:.*]]: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
+# DUMPIR: %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %[[ARG0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %[[ARG1]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
+# DUMPIR: %[[LD0:.*]] = memref.load %[[MEMREF]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR: %[[LD1:.*]] = memref.load %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR: %[[MUL:.*]] = arith.mulf %[[LD0]], %[[ARG2]] : f32
+# DUMPIR: %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
+# DUMPIR: memref.store %[[ADD]], %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR: gpu.terminator
+# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %[[ARG1]], %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR: return
+# DUMPIR: }
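
The Ch1 `DUMPIR` block above encodes the familiar saxpy dataflow: load an element of x and of y, `arith.mulf` by alpha, `arith.addf`, and store back to y. The whole-array equivalent in NumPy, matching the 256x32 f32 shapes and alpha = 2.0 from the chapter (shown for orientation only):

    import numpy as np

    M, N, alpha = 256, 32, 2.0
    x = np.random.randn(M, N).astype(np.float32)
    y = np.ones((M, N), np.float32)
    y += x * alpha          # per element: mulf, then addf, stored back to y
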
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index 729913c..aba610c 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 2 : 2D Saxpy with TMA
@@ -28,12 +32,12 @@ import numpy as np
@NVDSL.mlir_func
def saxpy(x, y, alpha):
token_ty = gpu.AsyncTokenType.get()
- t1 = gpu.wait(token_ty, [])
+ t1 = gpu.wait([])
x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
- t6 = gpu.wait(token_ty, [t5])
+ t6 = gpu.wait([t5])
x_tma = TMA([1, N], x.type)
y_tma = TMA([1, N], y.type)
@@ -74,7 +78,7 @@ def saxpy(x, y, alpha):
saxpy_tma_kernel()
t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
- gpu.wait(token_ty, [t7])
+ gpu.wait([t7])
# 3. Pass numpy arrays to MLIR
@@ -85,9 +89,46 @@ x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)
-# 4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+ # 4. Verify MLIR with reference computation
+ ref = np.ones((M, N), np.float32)
+ ref += x * alpha
+ np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+ print("PASS")
# CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @saxpy(%{{.*}}: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR: %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32>
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[C32:.*]] = arith.constant 32 : index
+# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
+# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_10:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C1_11:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_10]]], %[[C1_11]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_12]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR: %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%{{.*}}, %{{.*}}], %[[MB]][%{{.*}}] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%{{.*}}], %{{.*}}, predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_20:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR: %[[FALSE:.*]] = arith.constant false
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR: %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR: %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG1]], %{{.*}} : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR: %{{.*}} = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR: return
+# DUMPIR: }
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index eb96b11..fe11575 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 3 : GEMM 128x128x64 with Tensor Core
@@ -60,13 +64,13 @@ def tma_load(
@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
token_ty = gpu.AsyncTokenType.get()
- t1 = gpu.wait(token_ty, [])
+ t1 = gpu.wait([])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
- t7 = gpu.wait(token_ty, [t6])
+ t7 = gpu.wait([t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -111,7 +115,7 @@ def gemm_128_128_64(a, b, d):
gemm_tma_kernel()
t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
- gpu.wait(None, [t8])
+ gpu.wait([t8])
# Python pass arguments to MLIR
@@ -123,7 +127,73 @@ b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+ # Verify MLIR program with reference computation in python
+ ref_d = a.astype(np.float16) @ b.astype(np.float16)
+ np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+ print("PASS")
# CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @gemm_128_128_64(%{{.*}}: memref<128x64xf16>, %{{.*}}: memref<64x128xf16>, %[[ARG2:.*]]: memref<128x128xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR: %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR: %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %{{.*}} box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[CAST1:.*]] = memref.cast %{{.*}} : memref<64x128xf16> to memref<*xf16>
+# DUMPIR: %[[C64_5:.*]] = arith.constant 64 : index
+# DUMPIR: %[[C64_6:.*]] = arith.constant 64 : index
+# DUMPIR: %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x
+# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C1_13:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_12]]], %[[C1_13]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA0]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA1]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_14:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_14]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[VIEW_15:.*]] = memref.view %[[DSM1]][%[[C16384]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_16:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW_17:.*]] = memref.view %[[DSM2]][%[[C0_16]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C16384_18:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[VIEW_19:.*]] = memref.view %[[DSM3]][%[[C16384_18]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DSM4:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C24576:.*]] = arith.constant 24576 : index
+# DUMPIR: %[[VIEW_20:.*]] = memref.view %[[DSM4]][%[[C24576]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_21]]], %[[C32768]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_23:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_24:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%[[C0_23]], %[[C0_24]]], %[[MB]][%[[C0_22]]] to %[[VIEW_17]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_25:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_26:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C0_27:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C0_26]], %[[C0_27]]], %[[MB]][%[[C0_25]]] to %[[VIEW_19]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_28:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C64_29:.*]] = arith.constant 64 : index
+# DUMPIR: %[[C0_30:.*]] = arith.constant 0 : index
+# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C64_29]], %[[C0_30]]], %[[MB]][%[[C0_28]]] to %[[VIEW_20]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_31:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR: %[[FALSE:.*]] = arith.constant false
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_31]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR: %[[WG_ACC:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR: nvgpu.warpgroup.mma.store %[[MMA]], %{{.*}} : <fragmented = vector<128x128xf32>> to memref<128x128xf32>
+# DUMPIR: gpu.terminator
+# DUMPIR: }
+# DUMPIR: %[[CPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG2]], %{{.*}} : memref<128x128xf32>, memref<128x128xf32>
+# DUMPIR: gpu.wait async [%[[CPY3]]]
+# DUMPIR: return
+# DUMPIR: }
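
As in the earlier chapters, Ch3's numeric check now runs only in execute mode. For orientation, a self-contained NumPy sketch of the reference computation it performs, using the chapter's M=128, N=128, K=64 shapes; here the `d` produced by the GPU kernel is stood in by a CPU matmul:

    import numpy as np

    M, N, K = 128, 128, 64
    a = np.random.randn(M, K).astype(np.float16)
    b = np.random.randn(K, N).astype(np.float16)
    d = (a @ b).astype(np.float32)      # stand-in for the warpgroup MMA result
    ref_d = a.astype(np.float16) @ b.astype(np.float16)
    np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
    print("PASS")
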
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index 0e3460f..dffafda 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 4 : Multistage GEMM with Tensor Core
@@ -259,13 +263,13 @@ def epilogue(D: WGMMAMatrix, d_dev):
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
token_ty = gpu.AsyncTokenType.get()
- t1 = gpu.wait(token_ty, [])
+ t1 = gpu.wait([])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
- t7 = gpu.wait(token_ty, [t6])
+ t7 = gpu.wait([t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -297,7 +301,7 @@ def gemm_multistage(a, b, d, num_stages):
gemm_multistage_kernel()
t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
- gpu.wait(None, [t8])
+ gpu.wait([t8])
# Python pass arguments to MLIR
@@ -313,11 +317,153 @@ d = np.zeros((M, N), np.float32)
gemm_multistage(a, b, d, num_stages=7)
-# Verify MLIR with reference computation
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-
-
-print("PASS")
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+ # Verify MLIR with reference computation
+ ref_d = a.astype(np.float16) @ b.astype(np.float16)
+ np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+ print("PASS")
# CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: func.func @gemm_multistage(%{{.*}}: memref<512x1024xf16>, %{{.*}}: memref<1024x256xf16>, %{{.*}}: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR: scf.if %{{.*}} {
+# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C7:.*]] = arith.constant 7 : index
+# DUMPIR: %[[C1_INIT:.*]] = arith.constant 1 : index
+# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] {
+# DUMPIR: %[[C1_MBAR:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %{{.*}}[%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: }
+# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: }
+# DUMPIR: %[[C0_PROLOGUE:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C6:.*]] = arith.constant 6 : index
+# DUMPIR: %[[C1_PROLOGUE:.*]] = arith.constant 1 : index
+# DUMPIR: scf.for %arg15 = %[[C0_PROLOGUE]] to %[[C6]] step %[[C1_PROLOGUE]] {
+# DUMPIR: %[[BID_X_P:.*]] = gpu.block_id x
+# DUMPIR: %[[BID_Y_P:.*]] = gpu.block_id y
+# DUMPIR: %[[C128_P1:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMX_P:.*]] = arith.muli %[[BID_X_P]], %[[C128_P1]] : index
+# DUMPIR: %[[C128_P2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMY_P:.*]] = arith.muli %[[BID_Y_P]], %[[C128_P2]] : index
+# DUMPIR: %{{.*}} = gpu.thread_id x
+# DUMPIR: %[[TID_X_P:.*]] = gpu.thread_id x
+# DUMPIR: %[[C0_P:.*]] = arith.constant 0 : index
+# DUMPIR: %[[PRED_P:.*]] = arith.cmpi eq, %[[TID_X_P]], %[[C0_P]] : index
+# DUMPIR: %[[C16384_P1:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_A_P:.*]] = arith.muli %arg15, %[[C16384_P1]] : index
+# DUMPIR: %[[C16384_P2:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_B_BASE_P:.*]] = arith.muli %arg15, %[[C16384_P2]] : index
+# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index
+# DUMPIR: %[[OFF_B1_P:.*]] = arith.addi %[[OFF_B_BASE_P]], %[[C114688]] : index
+# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index
+# DUMPIR: %[[OFF_B2_P:.*]] = arith.addi %[[OFF_B1_P]], %[[C8192]] : index
+# DUMPIR: %[[SMEM_A_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_A_P:.*]] = memref.view %[[SMEM_A_P]][%[[OFF_A_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_B1_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B1_P:.*]] = memref.view %[[SMEM_B1_P]][%[[OFF_B1_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_B2_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B2_P:.*]] = memref.view %[[SMEM_B2_P]][%[[OFF_B2_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C64_K_P:.*]] = arith.constant 64 : index
+# DUMPIR: %[[K_COORD_P:.*]] = arith.muli %arg15, %[[C64_K_P]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_P]], %[[DIMX_P]]], %{{.*}}[%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index
+# DUMPIR: %[[DIMY_P_OFF:.*]] = arith.addi %[[DIMY_P]], %[[C64_OFF]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P_OFF]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: }
+# DUMPIR: %[[TID_X_LOOP:.*]] = gpu.thread_id x
+# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[FALSE_LOOP:.*]] = arith.constant false
+# DUMPIR: %[[C0_LOOP:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C16_LOOP:.*]] = arith.constant 16 : index
+# DUMPIR: %[[C1_LOOP:.*]] = arith.constant 1 : index
+# DUMPIR: %[[LOOP_RES:.*]]:2 = scf.for %arg15 = %[[C0_LOOP]] to %[[C16_LOOP]] step %[[C1_LOOP]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE_LOOP]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+# DUMPIR: %[[C7_L:.*]] = arith.constant 7 : index
+# DUMPIR: %[[STAGE_L:.*]] = arith.remui %arg15, %[[C7_L]] : index
+# DUMPIR: %[[C10M:.*]] = arith.constant 10000000 : index
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %{{.*}}[%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C16384_L:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_A_L:.*]] = arith.muli %[[STAGE_L]], %[[C16384_L]] : index
+# DUMPIR: %[[C114688_L:.*]] = arith.constant 114688 : index
+# DUMPIR: %[[OFF_B_L:.*]] = arith.addi %[[OFF_A_L]], %[[C114688_L]] : index
+# DUMPIR: %[[SMEM_A_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_A_L:.*]] = memref.view %[[SMEM_A_L]][%[[OFF_A_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_B_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B_L:.*]] = memref.view %[[SMEM_B_L]][%[[OFF_B_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[ACC_L:.*]] = nvgpu.warpgroup.mma %[[DESC_A_L]], %[[DESC_B_L]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[C6_NEXT:.*]] = arith.constant 6 : index
+# DUMPIR: %[[ITER_NEXT:.*]] = arith.addi %arg15, %[[C6_NEXT]] : index
+# DUMPIR: %[[C16_CMP:.*]] = arith.constant 16 : index
+# DUMPIR: %[[IN_RANGE:.*]] = arith.cmpi ult, %[[ITER_NEXT]], %[[C16_CMP]] : index
+# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_THREAD0_L:.*]] = arith.cmpi eq, %[[TID_X_LOOP]], %[[C0_CMP]] : index
+# DUMPIR: %[[DO_LOAD:.*]] = arith.andi %[[IN_RANGE]], %[[IS_THREAD0_L]] : i1
+# DUMPIR: %[[C6_STAGE:.*]] = arith.constant 6 : index
+# DUMPIR: %[[STAGE_NEXT_L:.*]] = arith.addi %arg15, %[[C6_STAGE]] : index
+# DUMPIR: %[[C7_MOD:.*]] = arith.constant 7 : index
+# DUMPIR: %[[STAGE_LOAD:.*]] = arith.remui %[[STAGE_NEXT_L]], %[[C7_MOD]] : index
+# DUMPIR: %[[BID_X_L:.*]] = gpu.block_id x
+# DUMPIR: %[[BID_Y_L:.*]] = gpu.block_id y
+# DUMPIR: %[[C128_L1:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMX_L:.*]] = arith.muli %[[BID_X_L]], %[[C128_L1]] : index
+# DUMPIR: %[[C128_L2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMY_L:.*]] = arith.muli %[[BID_Y_L]], %[[C128_L2]] : index
+# DUMPIR: %[[TID_X_L1:.*]] = gpu.thread_id x
+# DUMPIR: %[[TID_X_L2:.*]] = gpu.thread_id x
+# DUMPIR: %[[C16384_LA1:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_A_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA1]] : index
+# DUMPIR: %[[C16384_LA2:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_B_BASE_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA2]] : index
+# DUMPIR: %[[C114688_LOAD:.*]] = arith.constant 114688 : index
+# DUMPIR: %[[OFF_B1_LOAD:.*]] = arith.addi %[[OFF_B_BASE_LOAD]], %[[C114688_LOAD]] : index
+# DUMPIR: %[[C8192_LOAD:.*]] = arith.constant 8192 : index
+# DUMPIR: %[[OFF_B2_LOAD:.*]] = arith.addi %[[OFF_B1_LOAD]], %[[C8192_LOAD]] : index
+# DUMPIR: %[[SMEM_A_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_A_LOAD:.*]] = memref.view %[[SMEM_A_LOAD]][%[[OFF_A_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_B1_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B1_LOAD:.*]] = memref.view %[[SMEM_B1_LOAD]][%[[OFF_B1_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_B2_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B2_LOAD:.*]] = memref.view %[[SMEM_B2_LOAD]][%[[OFF_B2_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C32768_LOAD:.*]] = arith.constant 32768 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C64_K_LOAD:.*]] = arith.constant 64 : index
+# DUMPIR: %[[K_COORD_LOAD:.*]] = arith.muli %[[STAGE_NEXT_L]], %[[C64_K_LOAD]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_LOAD]], %[[DIMX_L]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C64_OFF_LOAD:.*]] = arith.constant 64 : index
+# DUMPIR: %[[DIMY_L_OFF:.*]] = arith.addi %[[DIMY_L]], %[[C64_OFF_LOAD]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C6_FLIP:.*]] = arith.constant 6 : index
+# DUMPIR: %[[IS_STAGE6:.*]] = arith.cmpi eq, %[[STAGE_L]], %[[C6_FLIP]] : index
+# DUMPIR: %[[TRUE:.*]] = arith.constant true
+# DUMPIR: %[[PARITY_FLIP:.*]] = arith.xori %arg17, %[[TRUE]] : i1
+# DUMPIR: %[[NEW_PARITY:.*]] = arith.select %[[IS_STAGE6]], %[[PARITY_FLIP]], %arg17 : i1
+# DUMPIR: scf.yield %[[ACC_L]], %[[NEW_PARITY]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+# DUMPIR: }
+# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0
+# DUMPIR: %[[TID_X_EPI:.*]] = gpu.thread_id x
+# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x
+# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y
+# DUMPIR: %[[C128_EPI1:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMX_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI1]] : index
+# DUMPIR: %[[C128_EPI2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIMY_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI2]] : index
+# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_VIEW:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_VIEW]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SUBVIEW_EPI:.*]] = memref.subview %{{.*}}[%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR: nvgpu.warpgroup.mma.store %[[LOOP_RES]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: gpu.barrier
+# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index
+# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index
+# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] {
+# DUMPIR: %[[VAL_LOAD:.*]] = memref.load %[[VIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: memref.store %[[VAL_LOAD]], %[[SUBVIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index f98cfd7..b725e50 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -1,5 +1,9 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: sh -c 'if [ "%mlir_run_cuda_sm90_tests" = "1" ]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
# ===----------------------------------------------------------------------===//
# Chapter 5 : Warp Specialized GEMM with Tensor Core
@@ -156,7 +160,7 @@ def producer_loop(
):
phase = const(True, ty=T.bool())
- for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+ for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]):
stage = iv % num_stages
# Wait MMA to be done
mbar_mma[stage].try_wait(phase)
@@ -253,13 +257,13 @@ def epilogue(D: WGMMAMatrix, d_dev):
@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
token_ty = gpu.AsyncTokenType.get()
- t1 = gpu.wait(token_ty, [])
+ t1 = gpu.wait([])
a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
- t7 = gpu.wait(token_ty, [t6])
+ t7 = gpu.wait([t6])
sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -295,7 +299,7 @@ def gemm_warp_specialized(a, b, d, num_stages):
gemm_warp_specialized_kernel()
t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
- gpu.wait(None, [t8])
+ gpu.wait([t8])
# Python pass arguments to MLIR
@@ -311,11 +315,166 @@ d = np.zeros((M, N), np.float32)
gemm_warp_specialized(a, b, d, num_stages=7)
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
+ # Verify MLIR with reference computation
+ ref_d = a.astype(np.float16) @ b.astype(np.float16)
+ np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-# Verify MLIR with reference computation
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-
-
-print("PASS")
+ print("PASS")
# CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: %[[TID_X:.*]] = gpu.thread_id x
+# DUMPIR: %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR: %[[REM1:.*]] = arith.remui %[[TID_X]], %[[C128]] : index
+# DUMPIR: %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_PRIMARY:.*]] = arith.cmpi eq, %[[REM1]], %[[C0]] : index
+# DUMPIR: %[[C128_1:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIV1:.*]] = arith.divui %[[TID_X]], %[[C128_1]] : index
+# DUMPIR: %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR: %[[IS_PRODUCER:.*]] = arith.cmpi eq, %[[DIV1]], %[[C1]] : index
+# DUMPIR: %[[TID_X_2:.*]] = gpu.thread_id x
+# DUMPIR: %[[C128_2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[REM2:.*]] = arith.remui %[[TID_X_2]], %[[C128_2]] : index
+# DUMPIR: %[[C0_2:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_PRIMARY_2:.*]] = arith.cmpi eq, %[[REM2]], %[[C0_2]] : index
+# DUMPIR: %[[C128_3:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIV2:.*]] = arith.divui %[[TID_X_2]], %[[C128_3]] : index
+# DUMPIR: %[[C0_3:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_CONSUMER:.*]] = arith.cmpi eq, %[[DIV2]], %[[C0_3]] : index
+# DUMPIR: %[[TID_X_3:.*]] = gpu.thread_id x
+# DUMPIR: %[[MBAR_MMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[MBAR_TMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C0_4:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X_3]], %[[C0_4]] : index
+# DUMPIR: scf.if %[[IS_THREAD0]] {
+# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C7:.*]] = arith.constant 7 : index
+# DUMPIR: %[[C1_INIT:.*]] = arith.constant 1 : index
+# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] {
+# DUMPIR: %[[C1_INIT_VAL:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MBAR_MMA]][%arg15], %[[C1_INIT_VAL]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C1_INIT_VAL_2:.*]] = arith.constant 1 : index
+# DUMPIR: nvgpu.mbarrier.init %[[MBAR_TMA]][%arg15], %[[C1_INIT_VAL_2]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: }
+# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR: }
+# DUMPIR: scf.if %[[IS_PRODUCER]] {
+# DUMPIR: nvvm.setmaxregister decrease 40
+# DUMPIR: %[[TRUE:.*]] = arith.constant true
+# DUMPIR: %[[C0_PROD:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C16:.*]] = arith.constant 16 : index
+# DUMPIR: %[[C1_PROD:.*]] = arith.constant 1 : index
+# DUMPIR: %[[PROD_LOOP:.*]] = scf.for %arg15 = %[[C0_PROD]] to %[[C16]] step %[[C1_PROD]] iter_args(%arg16 = %[[TRUE]]) -> (i1) {
+# DUMPIR: %[[C7_PROD:.*]] = arith.constant 7 : index
+# DUMPIR: %[[SLOT:.*]] = arith.remui %arg15, %[[C7_PROD]] : index
+# DUMPIR: %[[TIMEOUT:.*]] = arith.constant 10000000 : index
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_MMA]][%[[SLOT]]], %arg16, %[[TIMEOUT]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C6:.*]] = arith.constant 6 : index
+# DUMPIR: %[[IS_LAST:.*]] = arith.cmpi eq, %[[SLOT]], %[[C6]] : index
+# DUMPIR: %[[TRUE_2:.*]] = arith.constant true
+# DUMPIR: %[[FLIP:.*]] = arith.xori %arg16, %[[TRUE_2]] : i1
+# DUMPIR: %[[PHASE:.*]] = arith.select %[[IS_LAST]], %[[FLIP]], %arg16 : i1
+# DUMPIR: %[[BID_X:.*]] = gpu.block_id x
+# DUMPIR: %[[BID_Y:.*]] = gpu.block_id y
+# DUMPIR: %[[C128_TILE:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIM_X:.*]] = arith.muli %[[BID_X]], %[[C128_TILE]] : index
+# DUMPIR: %[[C128_TILE_2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIM_Y:.*]] = arith.muli %[[BID_Y]], %[[C128_TILE_2]] : index
+# DUMPIR: %[[TID_PROD:.*]] = gpu.thread_id x
+# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_A:.*]] = arith.muli %[[SLOT]], %[[C16384]] : index
+# DUMPIR: %[[C16384_2:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_B_BASE:.*]] = arith.muli %[[SLOT]], %[[C16384_2]] : index
+# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index
+# DUMPIR: %[[OFF_B1:.*]] = arith.addi %[[OFF_B_BASE]], %[[C114688]] : index
+# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index
+# DUMPIR: %[[OFF_B2:.*]] = arith.addi %[[OFF_B1]], %[[C8192]] : index
+# DUMPIR: %[[SMEM:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_A:.*]] = memref.view %[[SMEM]][%[[OFF_A]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B1:.*]] = memref.view %[[SMEM_2]][%[[OFF_B1]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B2:.*]] = memref.view %[[SMEM_3]][%[[OFF_B2]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[TX_COUNT:.*]] = arith.constant 32768 : index
+# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MBAR_TMA]][%[[SLOT]]], %[[TX_COUNT]], predicate = %[[IS_PRIMARY]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C128_WG:.*]] = arith.constant 128 : index
+# DUMPIR: %[[TID_MOD:.*]] = arith.remui %[[TID_PROD]], %[[C128_WG]] : index
+# DUMPIR: %[[C0_TMA:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_TMA_THREAD:.*]] = arith.cmpi eq, %[[TID_MOD]], %[[C0_TMA]] : index
+# DUMPIR: %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR: %[[K_COORD:.*]] = arith.muli %arg15, %[[C64]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD]], %[[DIM_X]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_A]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B1]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index
+# DUMPIR: %[[DIM_Y_OFF:.*]] = arith.addi %[[DIM_Y]], %[[C64_OFF]] : index
+# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y_OFF]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B2]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: scf.yield %[[PHASE]] : i1
+# DUMPIR: }
+# DUMPIR: }
+# DUMPIR: scf.if %[[IS_CONSUMER]] {
+# DUMPIR: nvvm.setmaxregister increase 232
+# DUMPIR: %[[FALSE:.*]] = arith.constant false
+# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[C0_CONS:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C16_CONS:.*]] = arith.constant 16 : index
+# DUMPIR: %[[C1_CONS:.*]] = arith.constant 1 : index
+# DUMPIR: %[[CONS_LOOP:.*]]:2 = scf.for %arg15 = %[[C0_CONS]] to %[[C16_CONS]] step %[[C1_CONS]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+# DUMPIR: %[[C7_CONS:.*]] = arith.constant 7 : index
+# DUMPIR: %[[SLOT_CONS:.*]] = arith.remui %arg15, %[[C7_CONS]] : index
+# DUMPIR: %[[TIMEOUT_CONS:.*]] = arith.constant 10000000 : index
+# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_TMA]][%[[SLOT_CONS]]], %arg17, %[[TIMEOUT_CONS]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR: %[[C16384_CONS:.*]] = arith.constant 16384 : index
+# DUMPIR: %[[OFF_A_CONS:.*]] = arith.muli %[[SLOT_CONS]], %[[C16384_CONS]] : index
+# DUMPIR: %[[C114688_CONS:.*]] = arith.constant 114688 : index
+# DUMPIR: %[[OFF_B_CONS:.*]] = arith.addi %[[OFF_A_CONS]], %[[C114688_CONS]] : index
+# DUMPIR: %[[SMEM_CONS:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_A_CONS:.*]] = memref.view %[[SMEM_CONS]][%[[OFF_A_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SMEM_CONS_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[VIEW_B_CONS:.*]] = memref.view %[[SMEM_CONS_2]][%[[OFF_B_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR: %[[DESC_A:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_CONS]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[DESC_B:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_CONS]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR: %[[ACC:.*]] = nvgpu.warpgroup.mma %[[DESC_A]], %[[DESC_B]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_NOT_FIRST:.*]] = arith.cmpi ugt, %arg15, %[[C0_CMP]] : index
+# DUMPIR: %[[ARRIVE_PRED:.*]] = arith.andi %[[IS_NOT_FIRST]], %[[IS_PRIMARY_2]] : i1
+# DUMPIR: scf.if %[[ARRIVE_PRED]] {
+# DUMPIR: %[[C0_ARR:.*]] = arith.constant 0 : index
+# DUMPIR: %[[IS_ZERO:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C0_ARR]] : index
+# DUMPIR: %[[C6_WRAP:.*]] = arith.constant 6 : index
+# DUMPIR: %[[C1_SUB:.*]] = arith.constant 1 : index
+# DUMPIR: %[[PREV_SLOT:.*]] = arith.subi %[[SLOT_CONS]], %[[C1_SUB]] : index
+# DUMPIR: %[[BARR_ID:.*]] = arith.select %[[IS_ZERO]], %[[C6_WRAP]], %[[PREV_SLOT]] : index
+# DUMPIR: %{{.*}} = nvgpu.mbarrier.arrive %[[MBAR_MMA]][%[[BARR_ID]]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token
+# DUMPIR: }
+# DUMPIR: %[[C6_LAST:.*]] = arith.constant 6 : index
+# DUMPIR: %[[IS_LAST_CONS:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C6_LAST]] : index
+# DUMPIR: %[[TRUE_CONS:.*]] = arith.constant true
+# DUMPIR: %[[FLIP_CONS:.*]] = arith.xori %arg17, %[[TRUE_CONS]] : i1
+# DUMPIR: %[[PHASE_CONS:.*]] = arith.select %[[IS_LAST_CONS]], %[[FLIP_CONS]], %arg17 : i1
+# DUMPIR: scf.yield %[[ACC]], %[[PHASE_CONS]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+# DUMPIR: }
+# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0
+# DUMPIR: %[[TID_EPI:.*]] = gpu.thread_id x
+# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x
+# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y
+# DUMPIR: %[[C128_EPI:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIM_X_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI]] : index
+# DUMPIR: %[[C128_EPI_2:.*]] = arith.constant 128 : index
+# DUMPIR: %[[DIM_Y_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI_2]] : index
+# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR: %[[C0_EPI:.*]] = arith.constant 0 : index
+# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_EPI]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%[[DIM_X_EPI]], %[[DIM_Y_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR: nvgpu.warpgroup.mma.store %[[CONS_LOOP]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: gpu.barrier
+# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index
+# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index
+# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index
+# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] {
+# DUMPIR: %{{.*}} = memref.load %[[VIEW_EPI]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR: memref.store %{{.*}}, %[[SUBVIEW]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR: }
+# DUMPIR: }
+# DUMPIR: gpu.terminator
diff --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg
index 689cd25..af44b2e 100644
--- a/mlir/test/Examples/NVGPU/lit.local.cfg
+++ b/mlir/test/Examples/NVGPU/lit.local.cfg
@@ -1,4 +1,4 @@
config.unsupported = False
-if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+if not config.enable_cuda_runner or not config.enable_bindings_python:
config.unsupported = True
\ No newline at end of file
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index 90dbb23..8561072 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -9,6 +9,7 @@ from mlir import runtime as rt
from tools import nvgpucompiler
MLIR_DYNAMIC = -9223372036854775808
+DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1"
def const(value: int, ty=None):
@@ -84,9 +85,7 @@ class Mbarriers:
self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
)
else:
- nvgpu.mbarrier_arrive(
- ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
- )
+ nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op)
def try_wait(self, phase: bool = False, ticks: int = 10000000):
ticks_op = const(ticks)
@@ -144,7 +143,9 @@ class TMA:
device_ptr,
)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
- tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
+ tma_descriptor_ty,
+ device_unranked_memref,
+ list(map(const, self.tma_box_shape)),
)
return self.tma_descriptor.result
@@ -156,7 +157,7 @@ class TMA:
dest,
mbarrier.mbar_group_op,
self.tma_descriptor,
- coordinates=map(const, coords),
+ coordinates=list(map(const, coords)),
mbarId=mbarrier.id_op,
predicate=predicate,
)
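Both list(map(...)) fixes address the same Python 3 pitfall: map() returns a one-shot iterator, and if the op builder (or anything downstream) traverses the operand sequence more than once, the second traversal is silently empty. A self-contained sketch of the failure mode:

    ops = map(str, (0, 0))  # a one-shot iterator, not a sequence
    print(list(ops))        # ['0', '0'] -- the first traversal consumes it
    print(list(ops))        # []         -- the second traversal sees nothing

Wrapping the map() call in list() materializes the operands once, so repeated iteration is safe.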
@@ -310,13 +311,10 @@ class NVDSL:
@functools.wraps(func)
def wrapper(*args, **kwargs):
launch_op = gpu.LaunchOp(
- None,
- [],
- *map(const, grid),
- *map(const, block),
- dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+ grid_size=grid,
+ block_size=block,
+ dynamic_shared_memory_size=arith.constant(T.i32(), smem),
)
- launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
result = func(*args, **kwargs)
gpu.terminator()
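The rewritten launch uses the keyword builder and no longer appends the twelve index block arguments by hand, which suggests the builder now creates the entry block itself. A usage sketch mirroring this hunk (an assumption-laden sketch: it presumes the keyword builder shown here and nvdsl.py's existing imports):

    launch_op = gpu.LaunchOp(
        grid_size=(1, 1, 1),
        block_size=(128, 1, 1),
        dynamic_shared_memory_size=arith.constant(T.i32(), smem),
    )
    with ir.InsertionPoint(launch_op.body.blocks[0]):
        ...  # emit the kernel body; ids come from the block's arguments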
@@ -334,13 +332,11 @@ class NVDSL:
def saveIR(module):
"""Save generated IR"""
- if True: # self.saveIR:
- # print(mlir_nvgpu_module)
- original_stdout = sys.stdout
- with open("nvdsl.mlir", "w") as f:
- sys.stdout = f
- print(module)
- sys.stdout = original_stdout
+ original_stdout = sys.stdout
+ with open("nvdsl.mlir", "w") as f:
+ sys.stdout = f
+ print(module)
+ sys.stdout = original_stdout
def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
"""Generate MLIR's Arith dialects binary operations."""
@@ -429,6 +425,9 @@ class NVDSL:
# Save IR in a file
# saveIR(module)
+ if DUMP_ONLY:
+ print(module)
+ return 0
# Verify the module
module.operation.verify()
diff --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
index 1c9cc74..4b661f8 100644
--- a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
+++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
@@ -35,9 +35,11 @@ class NvgpuCompiler:
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
- return execution_engine.ExecutionEngine(
+ ee = execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
+ ee.initialize()
+ return ee
def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 0c5fec8c..2f5dd28 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -145,3 +145,11 @@ func.func @verify_fail_3() {
%r = "arith.constant"() {value = -3 : si32} : () -> si32
return
}
+
+// -----
+
+// Verify that symbols with results are rejected
+module {
+ // expected-error@+1 {{'test.symbol_with_result' op symbols must not have results}}
+ %0 = "test.symbol_with_result"() <{sym_name = "test_symbol"}> : () -> i32
+}
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index b725307..2e23746 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -105,3 +105,10 @@ func.func @dialect_location() {
test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
return
}
+
+// CHECK-LABEL: @location_attr
+// CHECK: test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)} loc({{.*}}locations.mlir":[[# @LINE+2]]:3)
+func.func @location_attr() {
+ test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)}
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
new file mode 100644
index 0000000..5f8b2f4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
@@ -0,0 +1,26 @@
+// REQUIRES: system-linux
+// TODO: Run only on Linux until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// All floating-point arithmetic is lowered through APFloat.
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \
+// RUN: --convert-scf-to-cf --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Put the rhs into a separate function so that it won't be constant-folded.
+func.func @foo_vec() -> (vector<4xf8E4M3FN>, vector<4xf32>) {
+ %cst1 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf8E4M3FN>
+ %cst2 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf32>
+ return %cst1, %cst2 : vector<4xf8E4M3FN>, vector<4xf32>
+}
+
+func.func @entry() {
+ // CHECK: ( 3.5, 3.5, 3.5, 3.5 )
+ %a1_vec = arith.constant dense<[1.4, 1.4, 1.4, 1.4]> : vector<4xf8E4M3FN>
+ %b1_vec, %b2_vec = func.call @foo_vec() : () -> (vector<4xf8E4M3FN>, vector<4xf32>)
+ %c1_vec = arith.addf %a1_vec, %b1_vec : vector<4xf8E4M3FN> // not supported by LLVM
+ vector.print %c1_vec : vector<4xf8E4M3FN>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
new file mode 100644
index 0000000..7f72dd5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -0,0 +1,82 @@
+// REQUIRES: system-linux
+// TODO: Run only on Linux until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// Case 1: All floating-point arithmetic is lowered through APFloat.
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Case 2: Only unsupported arithmetic (f8E4M3FN) is lowered through APFloat.
+// Arithmetic on f32 is lowered directly to LLVM.
+// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \
+// RUN: --convert-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Put the rhs into a separate function so that it won't be constant-folded.
+func.func @foo() -> (f8E4M3FN, f32) {
+ %cst1 = arith.constant 2.2 : f8E4M3FN
+ %cst2 = arith.constant 2.2 : f32
+ return %cst1, %cst2 : f8E4M3FN, f32
+}
+
+func.func @entry() {
+ %a1 = arith.constant 1.4 : f8E4M3FN
+ %a2 = arith.constant 1.4 : f32
+ %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
+
+ // CHECK: 2.2
+ vector.print %b2 : f32
+
+ // CHECK-NEXT: 3.5
+ %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
+ vector.print %c1 : f8E4M3FN
+
+ // CHECK-NEXT: 3.6
+ %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
+ vector.print %c2 : f32
+
+ // CHECK-NEXT: 2.25
+ %cvt = arith.truncf %b2 : f32 to f8E4M3FN
+ vector.print %cvt : f8E4M3FN
+
+ // CHECK-NEXT: -2.25
+ %negated = arith.negf %cvt : f8E4M3FN
+ vector.print %negated : f8E4M3FN
+
+ // CHECK-NEXT: -2.25
+ %min = arith.minimumf %cvt, %negated : f8E4M3FN
+ vector.print %min : f8E4M3FN
+
+ // CHECK-NEXT: 1
+ %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
+ vector.print %cmp1 : i1
+
+ // CHECK-NEXT: 1
+ // Bit pattern: 01, interpreted as signed integer: 1
+ %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2
+ vector.print %cvt_int_signed : i2
+
+ // CHECK-NEXT: -2
+ // Bit pattern: 10 (the unsigned value 2); vector.print reads i2 as signed: -2
+ %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
+ vector.print %cvt_int_unsigned : i2
+
+ // CHECK-NEXT: -6
+ // Bit pattern: 1...11110111, interpreted as signed: -9
+ // Closest f4E2M1FN value: -6.0
+ %c9 = arith.constant -9 : i16
+ %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_signed_int : f4E2M1FN
+
+ // CHECK-NEXT: 6
+ // Bit pattern: 1...11110111, interpreted as unsigned: 65527
+ // Closest f4E2M1FN value: 6.0
+ %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_unsigned_int : f4E2M1FN
+
+ return
+}
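The integer prints at the end follow two's-complement reinterpretation: fptoui of 2.25 yields the unsigned value 2, whose 2-bit pattern 10 prints as -2, and -9 round-trips through its 16-bit pattern (65527 when read as unsigned). A standalone reference computation (a sketch, independent of the test):

    def as_signed(value, bits):
        # Reinterpret the low `bits` of `value` as two's-complement.
        value &= (1 << bits) - 1
        return value - (1 << bits) if value >> (bits - 1) else value

    print(as_signed(2, 2))             # -2: bit pattern 10 read as signed i2
    print(as_signed(-9 & 0xFFFF, 16))  # -9: 65527 reinterpreted as signed i16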
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index 9d04357..d26853d 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -22,7 +22,7 @@ func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : t
}
func.func @main() {
- %c0 = arith.constant 0 : i32
+ %c0 = arith.constant 0.0 : f32
%c7 = arith.constant 7 : index
%A = arith.constant dense<[
@@ -44,7 +44,7 @@ func.func @main() {
%A_dyn = tensor.cast %A : tensor<13x7xf32> to tensor<?x?xf32>
%C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
- %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
// CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index ad7dbb9..e2c0f1d2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -16,7 +16,7 @@ func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf3
}
func.func @main() {
- %c0 = arith.constant 0 : i32
+ %c0 = arith.constant 0.0 : f32
%c7 = arith.constant 7 : index
%A = arith.constant dense<[
@@ -37,7 +37,7 @@ func.func @main() {
%B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
%C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
- %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
// CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index 243f9e5..007189a 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -29,7 +29,7 @@ func.func @main() {
%c128 = arith.constant 128 : i32
func.call @setArmSVLBits(%c128) : (i32) -> ()
- %c0 = arith.constant 0 : i32
+ %c0 = arith.constant 0.0 : f32
%c7 = arith.constant 7 : index
%A = arith.constant dense<[
@@ -50,7 +50,7 @@ func.func @main() {
%B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
%C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
- %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
// CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 127ab70..c90476e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -24,17 +24,14 @@ func.func @main() {
%d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32>
%d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32>
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
-
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
%c1x1 = arith.constant dense<0.0> : tensor<1x1xf32>
@@ -48,72 +45,82 @@ func.func @main() {
%d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32>
%d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32>
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size
func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size
+
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size
func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
-
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.matmul
- // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+ // CHECK-NEXT: linalg.matmul
+ // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
%c64x57 = arith.constant dense<0.0> : tensor<16x29xf32>
%c3x4 = arith.constant dense<0.0> : tensor<3x4xf32>
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>)
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>)
-
// CHECK: ERROR: Runtime op verification failed
- // CHECK: linalg.generic
- // CHECK: unexpected negative result on dimension #0 of input/output operand #0
+ // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
%c0x = arith.constant dense<1.0> : tensor<0xf32>
%d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
%c0x5 = arith.constant dense<0.0> : tensor<0x5xf32>
%d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32>
// CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
return
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 8fa32d7..bbda8d4e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -27,8 +27,8 @@ func.func @main() {
%A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
%B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
- %c0_i32 = arith.constant 0 : i32
- %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %c0_f32 = arith.constant 0.0 : f32
+ %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
%res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 8487567..09cfee1 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?],
return
}
+func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
+ %dim_0: index,
+ %dim_1: index,
+ %dim_2: index,
+ %offset: index) {
+ %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+ memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
func.func @main() {
%0 = arith.constant 0 : index
@@ -127,5 +138,9 @@ func.func @main() {
func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %offset = arith.constant 10 : index
+ func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset)
+ : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> ()
return
}
diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index a77fa31..745eea3 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index,
return
}
+func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) {
+ tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
+ return
+}
+
func.func @main() {
%0 = arith.constant 0 : index
@@ -115,5 +120,9 @@ func.func @main() {
%dim_2 = arith.constant 1 : index
func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %offset = arith.constant 10 : index
+ func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> ()
+
return
}
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index a374d9a..e3fee91 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
}
func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
- %cst = arith.constant 0.0 : f64
+ %cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<10x15xf32>
// expected-remark @below {{fill}}
- %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
%real_lhs = linalg.mul
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir
new file mode 100644
index 0000000..c4608ac
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @subview attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @subview(%src: memref<256xf32>, %dst: memref<256xf32>) kernel {
+ %src_subview = memref.subview %src[5] [251] [1] : memref<256xf32> to memref<251xf32, strided<[1], offset: 5>>
+ %dst_subview = memref.subview %dst[10] [246] [1] : memref<256xf32> to memref<246xf32, strided<[1], offset: 10>>
+ %lane_id = gpu.lane_id
+ %mask = arith.constant 1 : i1
+ %loaded = xegpu.load %src_subview[%lane_id], %mask : memref<251xf32, strided<[1], offset: 5>>, index, i1 -> f32
+ xegpu.store %loaded, %dst_subview[%lane_id], %mask : f32, memref<246xf32, strided<[1], offset: 10>>, index, i1
+ gpu.return
+ }
+ }
+ func.func @test(%src: memref<256xf32>, %dst: memref<256xf32>) -> memref<256xf32> {
+ %memref_src = gpu.alloc () : memref<256xf32>
+ gpu.memcpy %memref_src, %src : memref<256xf32>, memref<256xf32>
+ %memref_dst = gpu.alloc () : memref<256xf32>
+ gpu.memcpy %memref_dst, %dst : memref<256xf32>, memref<256xf32>
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ gpu.launch_func @kernel::@subview blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%memref_src : memref<256xf32>, %memref_dst : memref<256xf32>)
+ gpu.wait // Wait for the kernel to finish.
+ gpu.memcpy %dst, %memref_dst : memref<256xf32>, memref<256xf32>
+ gpu.dealloc %memref_src : memref<256xf32>
+ gpu.dealloc %memref_dst : memref<256xf32>
+ return %dst : memref<256xf32>
+ }
+ func.func @main() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c256 = arith.constant 256 : index
+ %memref_src = memref.alloc() : memref<256xf32>
+ %memref_dst = memref.alloc() : memref<256xf32>
+ // Initialize source memref
+ scf.for %i = %c0 to %c256 step %c1 {
+ %val = arith.index_cast %i : index to i32
+ %val_float = arith.sitofp %val : i32 to f32
+ memref.store %val_float, %memref_src[%i] : memref<256xf32>
+ }
+ // Initialize destination memref to zero
+ scf.for %i = %c0 to %c256 step %c1 {
+ %zero = arith.constant 0.0 : f32
+ memref.store %zero, %memref_dst[%i] : memref<256xf32>
+ }
+ // Call test function
+ %gpu_result = call @test(%memref_src, %memref_dst) : (memref<256xf32>, memref<256xf32>) -> memref<256xf32>
+ %gpu_result_casted = memref.cast %gpu_result : memref<256xf32> to memref<*xf32>
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ call @printMemrefF32(%gpu_result_casted) : (memref<*xf32>) -> ()
+ // Deallocate memrefs
+ memref.dealloc %memref_src : memref<256xf32>
+ memref.dealloc %memref_dst : memref<256xf32>
+ return
+ }
+ func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
index edf8775..5ed2148 100644
--- a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
@@ -3,7 +3,7 @@
// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
// RUN: | mlir-runner \
-// RUN: --shared-libs=%mlir_sycl_runtime \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --shared-libs=%mlir_c_runner_utils \
// RUN: --entry-point-result=void \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir
new file mode 100644
index 0000000..c3dd35b
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @main() {
+ %a = memref.alloc() : memref<8x4xf64>
+ %b = memref.alloc() : memref<4x8xf64>
+ %c = memref.alloc() : memref<8x8xf64>
+ %d = memref.alloc() : memref<8x8xf64>
+
+ %f1 = arith.constant 1.0e+00 : f64
+ %fcst = arith.constant 3.14e+00 : f64
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+
+  // Initialize the input matrices with ones.
+ scf.for %arg0 = %c0 to %c8 step %c1 {
+ scf.for %arg1 = %c0 to %c4 step %c1 {
+ memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+ memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+ }
+ }
+ // Initialize the accumulator matrix with a constant.
+ scf.for %arg0 = %c0 to %c8 step %c1 {
+ scf.for %arg1 = %c0 to %c8 step %c1 {
+ memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+ }
+ }
+
+ %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+ %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+ %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+ %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+ gpu.host_register %2 : memref<*xf64>
+ gpu.host_register %20 : memref<*xf64>
+ gpu.host_register %33 : memref<*xf64>
+ gpu.host_register %34 : memref<*xf64>
+
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+ %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+ %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+ %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+ %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+ gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+ gpu.terminator
+ }
+ // Print the memref after computation.
+ call @printMemrefF64(%34) : (memref<*xf64>) -> ()
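+  // Each element of D is a length-4 dot product of ones accumulated onto
+  // the 3.14 initial value: 4 * 1.0 + 3.14 = 7.14.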
+ // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14]
+ return
+}
+
+func.func private @printMemrefF64(memref<*xf64>)
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
index 5585d98..d0001f6 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
index cd90ce3..fcff5f4 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
index fec2567..4718ac9 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
index d5633b0..5e3a7e7e 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
index db297b0..f1a48ae 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
index 65cbc79..f0a46ce 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
index a0c955e..ddbabd4 100644
--- a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
+++ b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir
index f041df8..5c56e2d 100644
--- a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir
+++ b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
index 71a21cf..83cf70c 100644
--- a/mlir/test/Integration/GPU/CUDA/assert.mlir
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir
index 5acadd6..3e45b5a 100644
--- a/mlir/test/Integration/GPU/CUDA/async.mlir
+++ b/mlir/test/Integration/GPU/CUDA/async.mlir
@@ -8,8 +8,11 @@
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_async_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
-// RUN: --entry-point-result=void -O0 \
-// RUN: | FileCheck %s
+// RUN: --entry-point-result=void -O0
+// RUN:
+// This test is currently too flaky and needs investigation; FileCheck is skipped for now.
+// See: https://github.com/llvm/llvm-project/issues/170833
+// DISABLED: | FileCheck %s
func.func @main() {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir
index 34dde6e..77a4fa0 100644
--- a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir
+++ b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8'" -debug-only=serialize-to-binary \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8' allow-pattern-rollback=0" -debug-only=serialize-to-binary \
// RUN: 2>&1 | FileCheck %s
func.func @host_function(%arg0 : f32, %arg1 : memref<?xf32>) {
diff --git a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir
index ed01416..51f6e36 100644
--- a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir
+++ b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir
@@ -2,7 +2,7 @@
// increment a global atomic counter and wait for the counter to reach 2.
//
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | env CUDA_MODULE_LOADING=EAGER mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir
index 27ec1ec..efffcaa 100644
--- a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir
+++ b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=serialize-to-isa \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=serialize-to-isa \
// RUN: 2>&1 | FileCheck %s
// CHECK-LABEL: Generated by LLVM NVPTX Back-End
diff --git a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir
index d32f5ef..f810678 100644
--- a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir
+++ b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=dump-sass \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=dump-sass \
// RUN: 2>&1 | FileCheck %s
// CHECK: MOV
diff --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
index 07f3218..fe3c2b1 100644
--- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
+++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
index b2ac90a..f8f1aa8 100644
--- a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
+++ b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir
index fd664f2..ef11676 100644
--- a/mlir/test/Integration/GPU/CUDA/printf.mlir
+++ b/mlir/test/Integration/GPU/CUDA/printf.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
index a6207d6..a4be5223 100644
--- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Integration/GPU/CUDA/two-modules.mlir b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
index c3cee2f..3490003 100644
--- a/mlir/test/Integration/GPU/CUDA/two-modules.mlir
+++ b/mlir/test/Integration/GPU/CUDA/two-modules.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index b98e8b0..c634444 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
}
return
}
+
+// CHECK-LABEL: func @multiple_loop_ivs
+func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) {
+ %ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index,
+ smin = 1 : index, smax = 32 : index } : index
+ %c0_i32 = arith.constant 0 : i32
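+  // Each scf.forall IV ranges over [0, bound), so %arg1 is inferred as
+  // [0, 31] (from the bound's umax of 32) and %arg2 as [0, 63].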
+ // CHECK: scf.forall
+ scf.forall (%arg1, %arg2) in (%ub1, 64) {
+ // CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index}
+ %1 = test.reflect_bounds %arg1 : index
+ // CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index}
+ %2 = test.reflect_bounds %arg2 : index
+ memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32>
+ }
+ return
+}
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406..624e099 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
+// See @test_unreifiable_result_shape below for why `error-on-pattern-iteration-limit` is set to false.
func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
@@ -27,12 +28,14 @@ func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// -----
-func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+// Test result shape reification for an operation that implements only the
+// `reifyResultShapes` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_result_shapes(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
- %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
%2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
@@ -41,7 +44,7 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3
%5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
-// CHECK-LABEL: func @result_shape_per_dim(
+// CHECK-LABEL: func @reify_shaped_type_using_reify_result_shapes(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -51,3 +54,127 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+// Test result shape reification for an operation that implements only the
+// `reifyShapeOfResult` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_shape_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @reify_shaped_type_using_reify_shape_of_result(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+// Test result shape reification for an operation that implements only the
+// `reifyDimOfResult` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @reify_shaped_type_using_reify_dim_of_result(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+// This test also indicates a problem with the approach of just using `reifyShapes`
+// without being specific about the {result, dim} pair that needs to be resolved.
+// The `reifyShapes` implementation introduces `dim` operations that are
+// effectively dead, but doing so creates an infinite loop in pattern application
+// (which eventually bails out on hitting the iteration limit). This is a pitfall
+// of this legacy mechanism.
+
+func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shapes(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
+// -----
+
+func.func @test_unreifiable_result_shape(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shape"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
+
+// -----
+
+func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_dim_of_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_dim_of_result_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
new file mode 100644
index 0000000..d7b0528
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %c20 = arith.constant 20 : index
+
+ %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+ %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+ // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
+ %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_producer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %c20 = arith.constant 20 : index
+
+ %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+ %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
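+  // The two producer slices are inserted at different offsets ([%c0, %c0]
+  // vs. [%c10, %c20]), so no single consumer tile lines up with both.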
+
+ // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
+ %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_producer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @fusable_with_consumer_extract_slice(%arg0: tensor<100x200xf32>, %arg1: tensor<100x200xf32>, %dest: tensor<100x200xf32>) -> tensor<10x20xf32> {
+ // expected-remark @+1 {{can be fused with consumer tensor.extract_slice op}}
+ %add = linalg.add ins(%arg0, %arg1 : tensor<100x200xf32>, tensor<100x200xf32>)
+ outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+ %c0 = arith.constant 0 : index
+ %slice = tensor.extract_slice %add[%c0, %c0] [10, 20] [1, 1] : tensor<100x200xf32> to tensor<10x20xf32>
+
+ return %slice : tensor<10x20xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+ transform.test.query_consumer_fusability %add : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
new file mode 100644
index 0000000..62dd7fa
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -0,0 +1,1156 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+
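+// Each test below tiles a trailing consumer and fuses it into the existing
+// loop via `transform.test.fuse_consumer_using_slice`; the fused consumer
+// shows up as an extra iter_arg/shared_out in the CHECK lines.
+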
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+ }
+ %in_operand_2 = tensor.empty() : tensor<64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64xf32>
+ %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+ return %2 : tensor<64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.for"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %0 = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+ func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ }
+ }
+ %in_operand_2 = tensor.empty() : tensor<64x64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64x64xf32>
+ %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ return %2 : tensor<64x64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+ : (!transform.any_op)
+ -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+ }
+ %in_operand_2 = tensor.empty() : tensor<64xf32>
+ %out_operand_3 = tensor.empty() : tensor<64xf32>
+ %out_operand_4 = tensor.empty() : tensor<64xf32>
+ %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.subf %out_0, %13 : f32
+ %15 = arith.addf %out_1, %in : f32
+ linalg.yield %14, %15 : f32, f32
+ } -> (tensor<64xf32>, tensor<64xf32>)
+ return %2#1 : tensor<64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.for"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %0 = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0)
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<64x64xf32>
+ %2 = tensor.empty() : tensor<64x64xf32>
+ %3 = tensor.empty() : tensor<64x64xf32>
+ %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+ %6 = arith.mulf %in, %in_0 : f32
+ %7 = arith.subf %out, %6 : f32
+ %8 = arith.addf %out_1, %in : f32
+ linalg.yield %7, %8 : f32, f32
+ } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+ %5 = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+ return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
+ : (!transform.any_op)
+ -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
+// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+ return %unpack : tensor<2048xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
+// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
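+    // 2047 is not a multiple of the 32-element tiles, so the fused unpack
+    // below needs an affine.min to clamp the size of its destination slice.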
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<2047xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+ return %unpack : tensor<2047xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
+// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
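+    // The 32-row producer tile covers exactly two 16-element inner tiles, so
+    // the pack consumer tiles perfectly (offsets below use floordiv 16).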
+ %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<4x32x16xf32>
+ %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+ return %pack : tensor<4x32x16xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_perfect_tiling_pack_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+
+// -----
+
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+ %0 = tensor.empty() : tensor<1x4x16x1xf32>
+ %1 = tensor.empty() : tensor<4x4xf32>
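+  // The forall from 0 to 4 with step 16 runs exactly once, which is what
+  // allows fusing the pack consumer here despite its padding_value.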
+ %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+ %3 = affine.min #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+ }
+ }
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+ return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+// CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
+
+func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
+ %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
+ return %pack : tensor<2x64x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+
+// -----
+
+// It is valid to fuse the pack op in a perfect tiling scenario when the
+// dimension is dynamic and padding is not needed.
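+// Each 16-wide loop tile then corresponds to exactly one packed tile, at
+// offset %iv floordiv 16 along the packed dimension.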
+
+func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> {
+ %c1 = arith.constant 1 : index
+ %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32>
+ %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32>
+ }
+ }
+ %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32>
+ return %pack : tensor<64x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+
+// -----
+
+// It is valid to fuse the pack op with padding semantics if it is a perfect
+// tiling case.
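+// Here ceil(64 / 3) = 22 and 32 / 16 = 2 give the 22x2x3x16 destination.
+// Because the d0 step (15) is a multiple of the inner tile size (3), every
+// iteration writes whole packed tiles: offset d0 floordiv 3 and size
+// ceildiv(tile_size, 3); the trailing partial tile is completed with the
+// padding value.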
+
+func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> {
+ %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2)
+ %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+ %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+ %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<22x2x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32>
+ return %pack : tensor<22x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
+// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
+// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[PACK_INIT]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// Imperfect tiling is not supported in pack op consumer fusion.
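+// Here the loop tiles the 30-element dimension with tile size 5, while the
+// pack uses inner_tiles = [6]; a 5-element loop tile straddles packed tiles,
+// so the consumer cannot be fused.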
+
+#map = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0)>
+func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
+ %0 = tensor.empty() : tensor<30xf32>
+ %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
+ %3 = affine.apply #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+ %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %in, %in : f32
+ linalg.yield %5 : f32
+ } -> tensor<5xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
+ }
+ }
+ %2 = tensor.empty() : tensor<5x6xf32>
+ %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
+ return %pack : tensor<5x6xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
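+// Both the linalg.mul and the linalg.exp consume the loop result;
+// num_consumer_to_fuse = 2 fuses them one after the other, adding one
+// iter_arg per fused consumer.
+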
+module {
+ func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.for"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 2
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
+// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
+// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+// CHECK: }
+// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
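+// The tensor.insert_slice user of the loop result is not a tileable DPS
+// consumer; with num_consumer_to_fuse = 1 only the linalg.mul is fused and
+// the insert_slice stays outside the loop.
+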
+module {
+ func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %dest1 = tensor.empty() : tensor<258x258xf32>
+ %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
+ %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @no_fuse_only_dps_consumer(
+// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
+// CHECK: linalg.add
+// CHECK: linalg.mul
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
+// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
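+// The consumer's indexing maps are projected permutations, including a
+// broadcast of the 1-D %arg2; tiling d0 hands the fused generic a full slice
+// of the broadcast operand and a 64x256x24 slice of its init.
+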
+#map = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+ func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %inserted_slice : tensor<256x256xf32>
+ }
+ %2 = tensor.empty() : tensor<256x256x24xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %4 = arith.addf %in, %in_0 : f32
+ linalg.yield %4 : f32
+ } -> tensor<256x256x24xf32>
+ return %3 : tensor<256x256x24xf32>
+ }
+}
+
+// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK: linalg.yield %[[VAL_23]] : f32
+// CHECK: } -> tensor<64x256x24xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_19]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.for"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
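+// Both results of the multi-result producer feed the consumer; passing both
+// parallel_insert_slice handles lets the fusion take consistent tiles of the
+// two results and add a single new shared_out for the fused consumer.
+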
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ %1 = arith.addf %b0, %b2 : f32
+ linalg.yield %0, %1 : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0) : tensor<?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that fusion also succeeds when the consumer reads the results of two
+// separate producers, as long as the operand tiles given to the transform
+// are consistent.
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0) : tensor<?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
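+// The two slices have different ranks (a [%iv0, %iv1] tile and an [%iv0]
+// tile); they are still consistent because the consumer broadcasts the 1-D
+// operand along d1.
+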
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+ %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+ shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+ %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
+// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+// CHECK: return %[[RESULT]]#2
+
+// -----
+
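+// Check that when the given operand tiles are inconsistent, fusion fails.
+// The consumer indexes %loop#1 through the transposed map
+// affine_map<(d0, d1) -> (d1, d0)>, so the [%iv0, %iv1] tiles of the two
+// slices do not agree.
+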
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+ %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+ shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+ %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+ %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.mulf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ %generic1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
+ } -> tensor<?x?xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ }
+ }
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 7888462..0137e2a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics --mlir-print-local-scope %s | FileCheck %s
#map = affine_map<(d0) -> (d0)>
module {
- func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ func.func @fuse_tilable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -28,14 +28,14 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.for"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ %add = transform.structured.match ops{["linalg.add"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %yield in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK: func.func @fuse_tilable_consumer_scf_for(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
@@ -60,8 +60,61 @@ module attributes {transform.with_named_sequence} {
// -----
+#map = affine_map<(d0) -> (d0)>
module {
- func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ func.func @fuse_tilable_consumer_nested_scf_for(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+ %lb0 : index, %ub0 : index, %step0 : index,
+ %lb1 : index, %ub1 : index, %step1 : index) -> tensor<?x?xf32> {
+ %0 = scf.for %arg3 = %lb0 to %ub0 step %step0 iter_args(%init0 = %arg0) -> tensor<?x?xf32> {
+ %1 = scf.for %arg4 = %lb1 to %ub1 step %step1 iter_args(%init1 = %init0) -> tensor<?x?xf32> {
+ %extracted_slice = tensor.extract_slice %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %2 = tensor.insert_slice %extracted_slice into %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+ scf.yield %2 : tensor<?x?xf32>
+ }
+ scf.yield %1 : tensor<?x?xf32>
+ }
+ %2 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %2 : tensor<?x?xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loops = transform.structured.match ops{["scf.for"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop0, %loop1 = transform.split_handle %loops
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %add = transform.structured.match ops{["linalg.add"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %new_loop0, %new_loop1 = transform.test.fuse_consumer %add into (%loop0, %loop1)
+ : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func @fuse_tilable_consumer_nested_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[OUTER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT00:[a-zA-Z0-9_]+]] = %[[ARG0]], %[[INIT01:[a-zA-Z0-9_]+]] = %[[ARG2]])
+// CHECK: %[[INNER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT10:[a-zA-Z0-9_]+]] = %[[INIT00]], %[[INIT11:[a-zA-Z0-9_]+]] = %[[INIT01]])
+// CHECK-DAG: %[[OPERAND1:.+]] = tensor.extract_slice %[[INIT10]]
+// CHECK-DAG: %[[OLD_INSERT_SLICE:.+]] = tensor.insert_slice %[[OPERAND1]] into %[[INIT10]]
+// CHECK-DAG: %[[OPERAND2:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-DAG: %[[INIT:.+]] = tensor.extract_slice %[[INIT11]]
+// CHECK: %[[ADD:.+]] = linalg.add
+// CHECK-SAME: ins(%[[OPERAND1]], %[[OPERAND2]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[ADD]] into %[[INIT11]]
+// CHECK: scf.yield %[[OLD_INSERT_SLICE]], %[[INSERT_SLICE]]
+// CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1
+// CHECK: return %[[OUTER_RESULT]]#1
+
+// -----
+
+module {
+ func.func @fuse_tilable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -83,19 +136,16 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %add = transform.structured.match ops{["linalg.add"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
- : (!transform.any_op)
- -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK: func.func @fuse_tilable_consumer_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
@@ -124,7 +174,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0) -> (d0)>
module {
- func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -155,16 +205,18 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
+ %producer, %consumer = transform.split_handle %generics
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%loop = transform.structured.match ops{["scf.for"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %yield in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
+// CHECK: func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
@@ -193,7 +245,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+ func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -224,19 +276,16 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
- : (!transform.any_op)
- -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+ %a, %new_loops = transform.test.fuse_consumer %generic into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
+// CHECK: func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
@@ -293,17 +342,15 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -315,8 +362,8 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
-// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
-// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2048)>(%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
@@ -356,17 +403,15 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -378,8 +423,8 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
-// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
-// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2047)>(%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
@@ -419,16 +464,15 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %consumer = transform.structured.match ops{["linalg.pack"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_perfect_tiling_pack_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
@@ -440,7 +484,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
-// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV1]])
// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -471,13 +515,12 @@ func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> ten
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
// CHECK: func.func @fuse_pack_consumer_if_single_iteration(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
@@ -485,7 +528,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
-// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]])
// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
// CHECK: %[[ELEM:.*]] = linalg.exp
@@ -517,13 +560,12 @@ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -535,7 +577,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]])
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
@@ -566,13 +608,12 @@ func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -584,7 +625,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]])
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
@@ -616,16 +657,12 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -633,7 +670,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
-// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%[[I]])
// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
@@ -641,9 +678,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
-// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
-// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%[[I]])
+// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply affine_map<(d0) -> (d0 ceildiv 3)>(%[[SIZE]])
+// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[J]])
// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
@@ -674,20 +711,21 @@ func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x
linalg.yield %5 : f32
} -> tensor<5xf32>
scf.forall.in_parallel {
- // expected-error @below {{failed to fuse consumer of slice}}
+
tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
}
}
%2 = tensor.empty() : tensor<5x6xf32>
+ // expected-error @below {{failed to fuse consumer of slice}}
%pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
return %pack : tensor<5x6xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -717,11 +755,15 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ %mulop = transform.structured.match ops{["linalg.mul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.for"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %mulop into (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
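+      // The second consumer must be fused into the loop produced by the first
+      // fusion, hence the chain through %new_loop rather than %loop.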
+ %expop = transform.structured.match ops{["linalg.exp"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %fused_consumer_2, %new_loop_2 = transform.test.fuse_consumer %expop into (%new_loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -741,64 +783,20 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
-// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
-// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
-// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
-// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
-// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
-// CHECK: }
-// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
-
-// -----
-
-module {
- func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %c256 = arith.constant 256 : index
- %cst = arith.constant 0.000000e+00 : f32
- %dest0 = tensor.empty() : tensor<256x256xf32>
- %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
- %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
- %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
- %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
- %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
- %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
- scf.yield %insert_slice : tensor<256x256xf32>
- }
- %dest1 = tensor.empty() : tensor<258x258xf32>
- %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
- %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
- return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
- }
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-// CHECK: func.func @no_fuse_only_dps_consumer(
-// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
-// CHECK: linalg.add
-// CHECK: linalg.mul
-// CHECK: scf.yield
+// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
+// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_MUL]], %[[INSERT_EXP]] :
// CHECK: }
-// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
-// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+// CHECK: return %[[LOOP_RESULT]]#1, %[[LOOP_RESULT]]#2 :
// -----
@@ -829,40 +827,41 @@ module {
}
}
-// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
-// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
-// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
-// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
-// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
-// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
-// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
-// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
-// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
-// CHECK: linalg.yield %[[VAL_23]] : f32
-// CHECK: } -> tensor<64x256x24xf32>
-// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
-// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
-// CHECK: }
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ %consumer = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%loop = transform.structured.match ops{["scf.for"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1
+ %a, %b = transform.test.fuse_consumer %consumer into (%loop)
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
+// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: %[[VAL_19:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK: linalg.yield %[[VAL_23]] : f32
+// CHECK: } -> tensor<64x256x24xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK: }
// -----
@@ -878,12 +877,12 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
- ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.mulf %b0, %b1 : f32
- %1 = arith.addf %b0, %b2 : f32
- linalg.yield %0, %1 : f32, f32
+ %1 = arith.addf %b0, %b2 : f32
+ linalg.yield %0, %1 : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -901,6 +900,19 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
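+    // Matching yields both generics in IR order; the trailing handle is the
+    // out-of-loop consumer to fuse.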
+ %producer, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
// CHECK-LABEL: func @multi_slice_fusion1(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
@@ -916,23 +928,9 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
- : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
// -----
-// Check that when the given operand tiles are inconsistent, tiling fails.
-
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -944,20 +942,20 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
- ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
- linalg.yield %0 : f32
+ linalg.yield %0 : f32
} -> tensor<?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
- ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
- %0 = arith.addf %b0, %b1 : f32
- linalg.yield %0: f32
+ %0 = arith.addf %b0, %b1 : f32
+ linalg.yield %0: f32
} -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -975,6 +973,19 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %producer1, %producer2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
// CHECK-LABEL: func @multi_slice_fusion2(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
@@ -991,19 +1002,6 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
- : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-
// -----
func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
@@ -1060,11 +1058,11 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
- : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -1124,7 +1122,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
linalg.yield %0: f32
} -> tensor<?x?xf32>
scf.forall.in_parallel {
- // expected-error @below {{failed to fuse consumer of slice}}
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
@@ -1132,6 +1129,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
}
}
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ // expected-error @below {{failed to fuse consumer of slice}}
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
@@ -1146,11 +1144,11 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
- : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
diff --git a/mlir/test/Pass/invalid-unsupported-operation.mlir b/mlir/test/Pass/invalid-unsupported-operation.mlir
new file mode 100644
index 0000000..1ee4584
--- /dev/null
+++ b/mlir/test/Pass/invalid-unsupported-operation.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-print-liveness -split-input-file -verify-diagnostics
+
+// Unnamed modules do not implement SymbolOpInterface.
+// expected-error-re @+1 {{trying to schedule pass '{{.*}}TestLivenessPass' on an unsupported operation}}
+module {}
+
+// -----
+
+// Named modules implement SymbolOpInterface.
+module @named_module {}
diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir
index 948a133..bff2b1c 100644
--- a/mlir/test/Pass/pipeline-invalid.mlir
+++ b/mlir/test/Pass/pipeline-invalid.mlir
@@ -15,5 +15,5 @@ arith.constant 0
// -----
-// expected-error@below {{trying to schedule a pass on an unsupported operation}}
+// expected-error-re@below {{trying to schedule pass '{{.*}}TestFunctionPass' on an unsupported operation}}
module {}
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index 294e6af6..f397a4a 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -105,6 +105,25 @@ func.func @apply() -> !emitc.ptr<i32> {
return %1 : !emitc.ptr<i32>
}
+
+// CHECK-LABEL: void address_of() {
+func.func @address_of() {
+ // CHECK-NEXT: int32_t [[V1:[^ ]*]];
+ %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
+ // CHECK-NEXT: int32_t* [[V2:[^ ]*]] = &[[V1]];
+ %1 = emitc.address_of %0 : !emitc.lvalue<i32>
+ return
+}
+
+// CHECK-LABEL: void dereference
+// CHECK-SAME: (int32_t* [[ARG0:[^ ]*]]) {
+func.func @dereference(%arg0: !emitc.ptr<i32>) {
+ // CHECK-NEXT: int32_t [[V1:[^ ]*]] = *[[ARG0]];
+ %2 = emitc.dereference %arg0 : !emitc.ptr<i32>
+ emitc.load %2 : !emitc.lvalue<i32>
+ return
+}
+
// CHECK: void array_type(int32_t v1[3], float v2[10][20])
func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
return
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 9f1c816..2de94d0 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -314,14 +314,14 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
return %v_load : i32
}
-// CPP-DEFAULT: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
+// CPP-DEFAULT: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
// CPP-DEFAULT-NEXT: return *([[VAL_2]] - [[VAL_1]]);
// CPP-DEFAULT-NEXT: }
-// CPP-DECLTOP: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
+// CPP-DECLTOP: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
// CPP-DECLTOP-NEXT: return *([[VAL_2]] - [[VAL_1]]);
// CPP-DECLTOP-NEXT: }
-emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 {
+emitc.func @expression_with_dereference_apply(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 {
%c = emitc.expression %arg1, %arg2 : (i32, !emitc.ptr<i32>) -> i32 {
%e = emitc.sub %arg2, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
%d = emitc.apply "*"(%e) : (!emitc.ptr<i32>) -> i32
@@ -330,6 +330,28 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i
return %c : i32
}
+// CPP-DEFAULT: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42;
+// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_4]] = 42;
+// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DECLTOP-NEXT: }
+
+func.func @expression_with_address_taken_apply(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
+ %a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
+ %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 {
+ %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+ %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
+ %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
+ emitc.yield %f : i1
+ }
+ return %c : i1
+}
+
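+// The function below computes the same expression through emitc.address_of
+// instead of emitc.apply "&" and must emit identical C.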
// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42;
// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
@@ -344,7 +366,7 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i
func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
%a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
%c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 {
- %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+ %d = emitc.address_of %a : !emitc.lvalue<i32>
%e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
%f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
emitc.yield %f : i1
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info-records.ll b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll
new file mode 100644
index 0000000..077871e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll
@@ -0,0 +1,87 @@
+; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -convert-debug-rec-to-intrinsics -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
+; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
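+; Both RUN lines are checked against the same prefixes: importing the debug
+; records directly and importing them after conversion back to intrinsics
+; should produce identical MLIR.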
+
+; CHECK: #[[LOCAL_VAR0:.*]] = #llvm.di_local_variable<scope = #di_lexical_block>
+; CHECK: #[[LOCAL_VAR1:.*]] = #llvm.di_local_variable<scope = #di_lexical_block_file, name = "arg"
+; CHECK: #[[LOCAL_VAR2:.*]] = #llvm.di_local_variable<scope = #di_lexical_block, name = "alloc"
+
+; CHECK: @callee()
+define void @callee() {
+ ret void
+}
+
+define void @func_with_empty_named_info() {
+ call void @callee()
+ ret void
+}
+
+define void @func_no_debug() {
+ ret void
+}
+
+; CHECK: llvm.func @func_with_debug(%[[ARG0:.*]]: i64
+define void @func_with_debug(i64 %0) !dbg !3 {
+
+ ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR0]] = %[[ARG0]] : i64
+ ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR1]] #llvm.di_expression<[DW_OP_LLVM_fragment(0, 1)]> = %[[ARG0]] : i64
+ ; CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
+ ; CHECK: %[[ADDR:.*]] = llvm.alloca %[[CST]] x i64
+ ; CHECK: llvm.intr.dbg.declare #[[LOCAL_VAR2]] #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_convert(4, DW_ATE_signed)]> = %[[ADDR]] : !llvm.ptr
+ %2 = alloca i64, align 8, !dbg !19
+ #dbg_value(i64 %0, !20, !DIExpression(DW_OP_LLVM_fragment, 0, 1), !22)
+ #dbg_declare(ptr %2, !23, !DIExpression(DW_OP_deref, DW_OP_LLVM_convert, 4, DW_ATE_signed), !25)
+ #dbg_value(i64 %0, !26, !DIExpression(), !27)
+ call void @func_no_debug(), !dbg !28
+ %3 = add i64 %0, %0, !dbg !32
+ ret void, !dbg !37
+}
+
+define void @empty_types() !dbg !38 {
+ ret void, !dbg !44
+}
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "MLIR", isOptimized: true, runtimeVersion: 0, splitDebugFilename: "test.dwo", emissionKind: FullDebug, nameTableKind: None)
+!1 = !DIFile(filename: "foo.mlir", directory: "/test/")
+!2 = !{i32 2, !"Debug Info Version", i32 3}
+!3 = distinct !DISubprogram(name: "func_with_debug", linkageName: "func_with_debug", scope: !4, file: !1, line: 3, type: !6, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
+!4 = !DINamespace(name: "nested", scope: !5)
+!5 = !DINamespace(name: "toplevel", scope: null, exportSymbols: true)
+!6 = !DISubroutineType(cc: DW_CC_normal, types: !7)
+!7 = !{null, !8, !9, !11, !12, !13, !16}
+!8 = !DIBasicType(name: "si64")
+!9 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, extraData: !10)
+!10 = !DIBasicType(name: "si32", size: 32, encoding: DW_ATE_signed)
+!11 = !DIDerivedType(tag: DW_TAG_pointer_type, name: "named", baseType: !10)
+!12 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, dwarfAddressSpace: 3)
+!13 = distinct !DICompositeType(tag: DW_TAG_structure_type, name: "composite", file: !1, line: 42, size: 64, align: 32, elements: !14)
+!14 = !{!15}
+!15 = !DISubrange(count: 4)
+!16 = !DICompositeType(tag: DW_TAG_array_type, name: "array", file: !1, baseType: !8, flags: DIFlagVector, elements: !17)
+!17 = !{!18}
+!18 = !DISubrange(lowerBound: 0, upperBound: 4, stride: 1)
+!19 = !DILocation(line: 100, column: 12, scope: !3)
+!20 = !DILocalVariable(name: "arg", arg: 1, scope: !21, file: !1, line: 6, type: !8, align: 32)
+!21 = distinct !DILexicalBlockFile(scope: !3, file: !1, discriminator: 0)
+!22 = !DILocation(line: 103, column: 3, scope: !3)
+!23 = !DILocalVariable(name: "alloc", scope: !24)
+!24 = distinct !DILexicalBlock(scope: !3)
+!25 = !DILocation(line: 106, column: 3, scope: !3)
+!26 = !DILocalVariable(scope: !24)
+!27 = !DILocation(line: 109, column: 3, scope: !3)
+!28 = !DILocation(line: 1, column: 2, scope: !3)
+!32 = !DILocation(line: 2, column: 4, scope: !33, inlinedAt: !36)
+!33 = distinct !DISubprogram(name: "callee", scope: !13, file: !1, type: !34, spFlags: DISPFlagDefinition, unit: !0)
+!34 = !DISubroutineType(types: !35)
+!35 = !{!8, !8}
+!36 = !DILocation(line: 28, column: 5, scope: !3)
+!37 = !DILocation(line: 135, column: 3, scope: !3)
+!38 = distinct !DISubprogram(name: "empty_types", scope: !39, file: !1, type: !40, spFlags: DISPFlagDefinition, unit: !0, annotations: !42)
+!39 = !DIModule(scope: !1, name: "module", configMacros: "bar", includePath: "/", apinotes: "/", file: !1, line: 42, isDecl: true)
+!40 = !DISubroutineType(cc: DW_CC_normal, types: !41)
+!41 = !{}
+!42 = !{!43}
+!43 = !{!"foo", !"bar"}
+!44 = !DILocation(line: 140, column: 3, scope: !38)
diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
index 83c0438..023b012 100644
--- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
@@ -22,14 +22,14 @@ define dso_local void @dsolocal_func() {
; // -----
; CHECK-LABEL: @func_readnone
-; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>}
+; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>}
; CHECK: llvm.return
define void @func_readnone() readnone {
ret void
}
; CHECK-LABEL: @func_readnone_indirect
-; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>}
+; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>}
declare void @func_readnone_indirect() #0
attributes #0 = { readnone }
@@ -169,7 +169,7 @@ define void @entry_count() !prof !1 {
; // -----
; CHECK-LABEL: @func_memory
-; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>}
+; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>}
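+; Locations not named in the source attribute (errno and the target-specific
+; slots) default to the effect specified for "other".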
; CHECK: llvm.return
define void @func_memory() memory(readwrite, argmem: none) {
ret void
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index d48be66..32f730b 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -1,16 +1,14 @@
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s
-; Check that debug intrinsics with an unsupported argument are dropped.
-
-declare void @llvm.dbg.value(metadata, metadata, metadata)
+; Check that debug records with an unsupported argument are dropped.
; CHECK: import-failure.ll
-; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !DIArgList(i64 %{{.*}}, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value))
+; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!DIArgList(i64 %{{.*}}, i64 undef), !{{.*}}, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !{{.*}})
; CHECK: import-failure.ll
-; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression())
+; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!{{.*}}, !{{.*}}, !DIExpression(), !{{.*}})
define void @unsupported_argument(i64 %arg1) {
- tail call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5
- tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5
+ #dbg_value(!DIArgList(i64 %arg1, i64 undef), !3, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !5)
+ #dbg_value(!6, !3, !DIExpression(), !5)
ret void
}
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index be245e3..7f9c511 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -703,13 +703,13 @@ declare void @f()
; CHECK-LABEL: @call_memory_effects
define void @call_memory_effects() {
-; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>}
+; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>}
call void @f() memory(none)
-; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read>}
+; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>}
call void @f() memory(none, argmem: write, inaccessiblemem: read)
-; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write>}
+; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write, errnoMem = write, targetMem0 = write, targetMem1 = write>}
call void @f() memory(write, argmem: none)
-; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read>}
+; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>}
call void @f() memory(readwrite, inaccessiblemem: read)
; CHECK: llvm.call @f()
; CHECK-NOT: #llvm.memory_effects
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index d2bb809..2381d7a 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -1128,6 +1128,34 @@ define void @experimental_constrained_fpext(float %s, <4 x float> %v) {
ret void
}
+; CHECK-LABEL: llvm.func @ucmp
+define i2 @ucmp(i32 %a, i32 %b) {
+ ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2
+ %r = call i2 @llvm.ucmp.i2.i32(i32 %a, i32 %b)
+ ret i2 %r
+}
+
+; CHECK-LABEL: llvm.func @vector_ucmp
+define <4 x i32> @vector_ucmp(<4 x i32> %a, <4 x i32> %b) {
+ ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32>
+ %r = call <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret <4 x i32> %r
+}
+
+; CHECK-LABEL: llvm.func @scmp
+define i2 @scmp(i32 %a, i32 %b) {
+ ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2
+ %r = call i2 @llvm.scmp.i2.i32(i32 %a, i32 %b)
+ ret i2 %r
+}
+
+; CHECK-LABEL: llvm.func @vector_scmp
+define <4 x i32> @vector_scmp(<4 x i32> %a, <4 x i32> %b) {
+ ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32>
+ %r = call <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret <4 x i32> %r
+}
+
declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
@@ -1382,3 +1410,7 @@ declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x doubl
declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata)
declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata)
+declare i2 @llvm.ucmp.i2.i32(i32, i32)
+declare <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>)
+declare i2 @llvm.scmp.i2.i32(i32, i32)
+declare <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>)
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index c623df0..3280625 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -16,6 +16,22 @@ bb2:
; // -----
+; CHECK-LABEL: @cond_br_expected
+define i64 @cond_br_expected(i1 %arg1, i64 %arg2) {
+entry:
+ ; CHECK: llvm.cond_br
+ ; CHECK-SAME: weights([1, 2000])
+ br i1 %arg1, label %bb1, label %bb2, !prof !0
+bb1:
+ ret i64 %arg2
+bb2:
+ ret i64 %arg2
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 2000}
+
+; // -----
+
; CHECK-LABEL: @simple_switch(
define i32 @simple_switch(i32 %arg1) {
; CHECK: llvm.switch
@@ -36,6 +52,26 @@ bbd:
; // -----
+; CHECK-LABEL: @simple_switch_expected(
+define i32 @simple_switch_expected(i32 %arg1) {
+ ; CHECK: llvm.switch
+ ; CHECK: {branch_weights = array<i32: 1, 1, 2000>}
+ switch i32 %arg1, label %bbd [
+ i32 0, label %bb1
+ i32 9, label %bb2
+ ], !prof !0
+bb1:
+ ret i32 %arg1
+bb2:
+ ret i32 %arg1
+bbd:
+ ret i32 %arg1
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000}
+
+; // -----
+
; Verify that a single weight attached to a call is not translated.
; The MLIR WeightedBranchOpInterface does not support this case.
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
new file mode 100644
index 0000000..95d12f3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
@@ -0,0 +1,99 @@
+// Tests single-team by-ref GPU reductions.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+ omp.private {type = private} @_QFfooEi_private_i32 : i32
+ omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ omp.yield(%2 : !llvm.ptr)
+ } init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ omp.yield(%arg1 : !llvm.ptr)
+ } combiner {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
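+    // Combine through private copies: each 24-byte descriptor is memcpy'd to
+    // a stack slot, then the f32 payloads are added through the descriptors'
+    // data pointers (field 0).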
+ %6 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %7 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr
+ %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+ %12 = llvm.load %9 : !llvm.ptr -> f32
+ %13 = llvm.load %11 : !llvm.ptr -> f32
+ %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32
+ llvm.store %14, %9 : f32, !llvm.ptr
+ omp.yield(%arg0 : !llvm.ptr)
+ } data_ptr_ptr {
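+    // Yields the address of the field holding the data pointer, i.e. the f32
+    // payload sits behind one extra level of indirection.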
+ ^bb0(%arg0: !llvm.ptr):
+ %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ omp.yield(%0 : !llvm.ptr)
+ }
+
+ llvm.func @foo_() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""}
+ %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, implicit, descriptor, to) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"}
+ omp.target map_entries(%10 -> %arg0 : !llvm.ptr) {
+ %13 = llvm.mlir.constant(1000 : i32) : i32
+ %14 = llvm.mlir.constant(1 : i32) : i32
+ omp.parallel {
+ omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg4 : !llvm.ptr) {
+ omp.loop_nest (%arg5) : i32 = (%14) to (%13) inclusive step (%14) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
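+// The shuffle-and-reduce helper wraps the shuffled f32 in a freshly allocated
+// descriptor so the remote lane's value can be combined by-ref like the local
+// one.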
+// CHECK: define {{.*}} @_omp_reduction_shuffle_and_reduce_func({{.*}}) {{.*}} {
+// CHECK: %[[REMOTE_RED_LIST:.omp.reduction.remote_reduce_list]] = alloca [1 x ptr], align 8, addrspace(5)
+// CHECK: %[[RED_ELEM:.omp.reduction.element]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5)
+// CHECK: %[[RED_ELEM_1:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr
+
+// CHECK: %[[SHUFFLE_ELEM:.*]] = alloca float, align 4, addrspace(5)
+// CHECK: %[[REMOTE_RED_LIST_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[REMOTE_RED_LIST]] to ptr
+
+// CHECK: %[[REMOTE_RED_LIST_ELEM0:.*]] = getelementptr inbounds [1 x ptr], ptr %[[REMOTE_RED_LIST_ASCAST]], i64 0, i64 0
+
+// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr
+// CHECK: %[[SHUFFLE_RES:.*]] = call i32 @__kmpc_shuffle_int32({{.*}})
+// CHECK: store i32 %[[SHUFFLE_RES]], ptr %[[SHUFFLE_ELEM_ASCAST]], align 4
+
+// CHECK: %[[RED_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr
+// CHECK: %[[RED_ALLOC_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_ASCAST]], i32 0, i32 0
+// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr
+// CHECK: store ptr %[[SHUFFLE_ELEM_ASCAST]], ptr %[[RED_ALLOC_PTR]], align 8
+// CHECK: store ptr %[[RED_ELEM_1]], ptr %[[REMOTE_RED_LIST_ELEM0]], align 8
+// CHECK: }
+
+// CHECK: define {{.*}} @_omp_reduction_inter_warp_copy_func({{.*}}) {{.*}} {
+// CHECK: %[[WARP_MASTER_CMP:.*]] = icmp eq i32 %nvptx_lane_id, 0
+// CHECK: br i1 %[[WARP_MASTER_CMP]], label %[[WARP_MASTER_BB:.*]], label %{{.*}}
+
+// CHECK: [[WARP_MASTER_BB]]:
+// CHECK: %[[WARP_RESULT_PTR:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[WARP_RESULT:.*]] = load ptr, ptr %[[WARP_RESULT_PTR]], align 8
+// CHECK: %[[ALLOC_MEM_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[WARP_RESULT]], i32 0, i32 0
+// CHECK: %[[ALLOC_MEM:.*]] = load ptr, ptr %[[ALLOC_MEM_PTR]], align 8
+// CHECK: %[[WARP_TRANSFER_SLOT:.*]] = getelementptr inbounds [32 x i32], ptr addrspace(3) @__openmp_nvptx_data_transfer_temporary_storage, i64 0, i32 %nvptx_warp_id
+// CHECK: %[[WARP_RED_RES:.*]] = load i32, ptr %[[ALLOC_MEM]], align 4
+// CHECK: store volatile i32 %[[WARP_RED_RES]], ptr addrspace(3) %[[WARP_TRANSFER_SLOT]], align 4
+// CHECK: }
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
new file mode 100644
index 0000000..1c73a49
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
@@ -0,0 +1,121 @@
+// Tests cross-teams by-ref GPU reductions.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+ omp.private {type = private} @_QFfooEi_private_i32 : i32
+ omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ omp.yield(%2 : !llvm.ptr)
+ } init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ omp.yield(%arg1 : !llvm.ptr)
+ } combiner {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %6 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %7 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr
+ %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+ %12 = llvm.load %9 : !llvm.ptr -> f32
+ %13 = llvm.load %11 : !llvm.ptr -> f32
+ %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32
+ llvm.store %14, %9 : f32, !llvm.ptr
+ omp.yield(%arg0 : !llvm.ptr)
+ } data_ptr_ptr {
+ ^bb0(%arg0: !llvm.ptr):
+ %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ omp.yield(%0 : !llvm.ptr)
+ }
+
+ llvm.func @foo_() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""}
+ %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"}
+ omp.target map_entries(%10 -> %arg0 : !llvm.ptr) {
+ %14 = llvm.mlir.constant(1000000 : i32) : i32
+ %15 = llvm.mlir.constant(1 : i32) : i32
+ omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) {
+ omp.parallel {
+ omp.distribute {
+ omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) {
+ omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} {
+// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8
+// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0
+// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[ALLOC_PTR]], align 4
+// Verify that the actual value managed by the descriptor is stored in the globalized
+// locals array, rather than a pointer to the descriptor or a pointer to the value.
+// CHECK: store float %[[ALLOC_VAL]], ptr %[[GLOB_ELEM_PTR]], align 4
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_reduce_func({{.*}}) {{.*}} {
+// Allocate a descriptor to manage the element retrieved from the globalized local array.
+// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5)
+// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr
+
+// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0
+// Store the pointer to the globalized local element into the locally allocated descriptor.
+// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_global_to_list_copy_func({{.*}}) {{.*}} {
+// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8
+// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0
+// Similar to _omp_reduction_list_to_global_copy_func(...) but in the reverse direction,
+// i.e. the globalized local array is copied from rather than copied to.
+// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[GLOB_ELEM_PTR]], align 4
+// CHECK: store float %[[ALLOC_VAL]], ptr %[[ALLOC_PTR]], align 4
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_global_to_list_reduce_func({{.*}}) {{.*}} {
+// Allocate a descriptor to manage the element retrieved from the globalized local array.
+// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5)
+// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr
+
+// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0
+// Store the pointer to the globalized local element into the locally allocated descriptor.
+// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8
+// CHECK: }
diff --git a/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir
new file mode 100644
index 0000000..b54bfe4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
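+// The TBAA root below has no identity string, so translation must emit a
+// distinct, self-referential metadata node for it (the ROOT check below).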
+#tbaa_root_0 = #llvm.tbaa_root<>
+#tbaa_type_desc_1 = #llvm.tbaa_type_desc<id = "omnipotent char", members = {<#tbaa_root_0, 0>}>
+#tbaa_type_desc_2 = #llvm.tbaa_type_desc<id = "long long", members = {<#tbaa_type_desc_1, 0>}>
+#tbaa_tag_3 = #llvm.tbaa_tag<access_type = #tbaa_type_desc_2, base_type = #tbaa_type_desc_2, offset = 0>
+
+// CHECK: define void @tbaa_anonymous_root(ptr %{{.*}}) {
+// CHECK: %{{.*}} = load i64, ptr %{{.*}}, align 4, !tbaa ![[TAG:[0-9]+]]
+// CHECK: ret void
+// CHECK: }
+// CHECK: !llvm.module.flags = !{![[FLAGS:[0-9]+]]}
+// CHECK: ![[FLAGS]] = !{i32 2, !"Debug Info Version", i32 3}
+// CHECK: ![[TAG]] = !{![[TYPE:[0-9]+]], ![[TYPE]], i64 0}
+// CHECK: ![[TYPE]] = !{!"long long", ![[BASE:[0-9]+]], i64 0}
+// CHECK: ![[BASE]] = !{!"omnipotent char", ![[ROOT:[0-9]+]], i64 0}
+// CHECK: ![[ROOT]] = distinct !{![[ROOT]]}
+llvm.func @tbaa_anonymous_root(%arg0: !llvm.ptr) {
+ %0 = llvm.load %arg0 {tbaa = [#tbaa_tag_3]} : !llvm.ptr -> i64
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 60bd24a..403c73f 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -1276,6 +1276,36 @@
llvm.return
}
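
+// ucmp/scmp perform a three-way compare, returning -1, 0, or 1 for
+// less-than, equal, and greater-than, respectively.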
+// CHECK-LABEL: @ucmp
+llvm.func @ucmp(%a: i32, %b: i32) -> i2 {
+ // CHECK: call i2 @llvm.ucmp.i2.i32
+ %r = llvm.intr.ucmp(%a, %b) : (i32, i32) -> i2
+ llvm.return %r : i2
+}
+
+// CHECK-LABEL: @vector_ucmp
+llvm.func @vector_ucmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> {
+ // CHECK: call <4 x i32> @llvm.ucmp.v4i32.v4i32
+ %0 = llvm.intr.ucmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32>
+ llvm.return %0 : vector<4 x i32>
+}
+
+// CHECK-LABEL: @scmp
+llvm.func @scmp(%a: i32, %b: i32) -> i2 {
+ // CHECK: call i2 @llvm.scmp.i2.i32
+ %r = llvm.intr.scmp(%a, %b) : (i32, i32) -> i2
+ llvm.return %r : i2
+}
+
+// CHECK-LABEL: @vector_scmp
+llvm.func @vector_scmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> {
+ // CHECK: call <4 x i32> @llvm.scmp.v4i32.v4i32
+ %0 = llvm.intr.scmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32>
+ llvm.return %0 : vector<4 x i32>
+}
+
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -1308,7 +1336,7 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) {
// CHECK-DAG: declare float @llvm.cos.f32(float)
// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
// CHECK-DAG: declare { float, float } @llvm.sincos.f32(float)
-// CHECK-DAG: declare { <8 x float>, <8 x float> } @llvm.sincos.v8f32(<8 x float>) #0
+// CHECK-DAG: declare { <8 x float>, <8 x float> } @llvm.sincos.v8f32(<8 x float>)
// CHECK-DAG: declare float @llvm.copysign.f32(float, float)
// CHECK-DAG: declare float @llvm.rint.f32(float)
// CHECK-DAG: declare double @llvm.rint.f64(double)
@@ -1464,3 +1492,8 @@
// CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata)
// CHECK-DAG: declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata)
// CHECK-DAG: declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata)
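+// For the i2 results, the nominal [-1, 2) range wraps in i2 to range(i2 -1, -2).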
+// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.ucmp.i2.i32(i32, i32)
+// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>)
+// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.scmp.i2.i32(i32, i32)
+// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>)
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index cc243c8..819a514 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -78,6 +78,9 @@ llvm.mlir.global internal @f8E8M0FNU_global_as_i8(1.0 : f8E8M0FNU) : i8
// CHECK: @bf16_global_as_i16 = internal global i16 16320
llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16
+// CHECK: @bool_global_as_i8 = internal global i8 1
+llvm.mlir.global internal @bool_global_as_i8(true) : i8
+
// CHECK: @explicit_undef = global i32 undef
llvm.mlir.global external @explicit_undef() : i32 {
%0 = llvm.mlir.undef : i32
@@ -2371,17 +2374,17 @@ llvm.func @readonly_function(%arg0: !llvm.ptr {llvm.readonly})
// CHECK: declare void @arg_mem_none_func() #[[ATTR:[0-9]+]]
llvm.func @arg_mem_none_func() attributes {
- memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>}
+ memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>}
-// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none) }
+// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none, target_mem0: none, target_mem1: none) }
// -----
// CHECK: declare void @readwrite_func() #[[ATTR:[0-9]+]]
llvm.func @readwrite_func() attributes {
- memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite>}
+ memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>}
-// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none) }
+// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none, target_mem0: none, target_mem1: none) }
// -----
@@ -2723,10 +2726,10 @@ llvm.func @fd()
// CHECK: call void @fc() #[[ATTRS_2:[0-9]+]]
// CHECK: call void @fd() #[[ATTRS_3:[0-9]+]]
llvm.func @mem_effects_call() {
- llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} : () -> ()
- llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write>} : () -> ()
- llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write>} : () -> ()
- llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} : () -> ()
+ llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
+ llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
+ llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
+ llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> ()
llvm.return
}
@@ -2734,11 +2737,11 @@ llvm.func @mem_effects_call() {
// CHECK: #[[ATTRS_0]]
// CHECK-SAME: memory(none)
// CHECK: #[[ATTRS_1]]
-// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none)
+// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none)
// CHECK: #[[ATTRS_2]]
-// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none)
+// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none)
// CHECK: #[[ATTRS_3]]
-// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none)
+// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none, target_mem0: none, target_mem1: none)
// -----
diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
new file mode 100644
index 0000000..a18633e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --check-prefix=LLVM
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
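+// The bare form lowers to an aligned CTA-wide barrier on id 0; an explicit
+// id selects the .all variant, id plus thread count selects .count, and the
+// reduction forms lower to the barrier0.{and,or,popc} intrinsics.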
+// LLVM-LABEL: @llvm_nvvm_barrier(
+// LLVM-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
+llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
+ // CHECK: nvvm.barrier
+ nvvm.barrier
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
+ // CHECK: nvvm.barrier id = %{{.*}}
+ nvvm.barrier id = %barID
+ // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
+ // CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}}
+ nvvm.barrier id = %barID number_of_threads = %numberOfThreads
+ // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
+ // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %{{.*}} -> i32
+ %0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
+ // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
+ // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %{{.*}} -> i32
+ %1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
+ // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
+ // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %{{.*}} -> i32
+ %2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000..a4bece8
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
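+// The intrinsic name is assembled from the rounding mode (.rn/.rz/.rs) plus
+// optional .relu and .satfinite suffixes; the .rs forms additionally take an
+// i32 of random bits.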
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
new file mode 100644
index 0000000..03abcdd
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// Test that valid target architectures are accepted.
+
+// Valid case on sm_100a
+gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
+ func.func @convert_rs() {
+ %f1 = llvm.mlir.constant(1.0 : f32) : f32
+ %f2 = llvm.mlir.constant(2.0 : f32) : f32
+ %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ return
+ }
+}
+
+// Valid case on sm_103a
+gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
+ func.func @convert_rs() {
+ %f1 = llvm.mlir.constant(1.0 : f32) : f32
+ %f2 = llvm.mlir.constant(2.0 : f32) : f32
+ %rbits = llvm.mlir.constant(0 : i32) : i32
+ %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ return
+ }
+}
+
+// -----
+
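+// All of the narrowing f32x4 .rs conversions below saturate: the generated
+// intrinsics always carry the .satfinite suffix.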
+// Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs)
+
+// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs
+llvm.func @convert_f32x4_to_f8x4_e4m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_relu
+llvm.func @convert_f32x4_to_f8x4_e4m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test F32x4 -> F8x4 (E5M2) with stochastic rounding (.rs)
+
+// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs
+llvm.func @convert_f32x4_to_f8x4_e5m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E5M2)
+ llvm.return %res : vector<4xi8>
+}
+
+// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_relu
+llvm.func @convert_f32x4_to_f8x4_e5m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E5M2)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test F32x4 -> F6x4 (E2M3) with stochastic rounding (.rs)
+
+// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs
+llvm.func @convert_f32x4_to_f6x4_e2m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_relu
+llvm.func @convert_f32x4_to_f6x4_e2m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E2M3FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test F32x4 -> F6x4 (E3M2) with stochastic rounding (.rs)
+
+// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs
+llvm.func @convert_f32x4_to_f6x4_e3m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_relu
+llvm.func @convert_f32x4_to_f6x4_e3m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
+ // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E3M2FN)
+ llvm.return %res : vector<4xi8>
+}
+
+// -----
+
+// Test F32x4 -> F4x4 (E2M1) with stochastic rounding (.rs)
+
+// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs
+llvm.func @convert_f32x4_to_f4x4_e2m1_rs(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
+// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_relu
+llvm.func @convert_f32x4_to_f4x4_e2m1_rs_relu(%src : vector<4xf32>, %rbits : i32) -> i16 {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}})
+ %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {relu = true} : vector<4xf32> -> i16 (f4E2M1FN)
+ llvm.return %res : i16
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir
new file mode 100644
index 0000000..22578b5
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s
+
+llvm.func @fence_sync_restrict() {
+ // expected-error @below {{only acquire and release semantics are supported}}
+ nvvm.fence.sync_restrict {order = #nvvm.mem_order<weak>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_sync_restrict() {
+ // expected-error @below {{only acquire and release semantics are supported}}
+ nvvm.fence.sync_restrict {order = #nvvm.mem_order<mmio>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy() {
+ // expected-error @below {{tensormap proxy is not a supported proxy kind}}
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<tensormap>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy() {
+ // expected-error @below {{generic proxy not a supported proxy kind}}
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<generic>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy() {
+ // expected-error @below {{async_shared fence requires space attribute}}
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy() {
+ // expected-error @below {{only async_shared fence can have space attribute}}
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>, space = #nvvm.shared_space<cta>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy_release() {
+ // expected-error @below {{uni-directional proxies only support generic for from_proxy attribute}}
+ nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<alias> to_proxy = #nvvm.proxy_kind<tensormap>
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy_release() {
+ // expected-error @below {{uni-directional proxies only support tensormap for to_proxy attribute}}
+ nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<generic> to_proxy = #nvvm.proxy_kind<async>
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy_sync_restrict() {
+ // expected-error @below {{only acquire and release semantics are supported}}
+ nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<mmio>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy_sync_restrict() {
+ // expected-error @below {{only async is supported for to_proxy attribute}}
+ nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<alias>,
+ fromProxy = #nvvm.proxy_kind<generic>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @fence_proxy_sync_restrict() {
+ // expected-error @below {{only generic is support for from_proxy attribute}}
+ nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<async>,
+ fromProxy = #nvvm.proxy_kind<tensormap>}
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/fence.mlir b/mlir/test/Target/LLVMIR/nvvm/fence.mlir
new file mode 100644
index 0000000..0ab4cb7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/fence.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster
+llvm.func @llvm_nvvm_fence_sc_cluster() {
+ // CHECK: nvvm.fence.sc.cluster
+ nvvm.fence.sc.cluster
+ llvm.return
+}
+
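+// Note the space asymmetry below: the acquire fence covers shared::cluster
+// while the release fence covers shared::cta, both at cluster scope.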
+// CHECK-LABEL: @nvvm_fence_sync_restrict
+llvm.func @nvvm_fence_sync_restrict() {
+ // CHECK: call void @llvm.nvvm.fence.acquire.sync_restrict.space.cluster.scope.cluster()
+ nvvm.fence.sync_restrict {order = #nvvm.mem_order<acquire>}
+ // CHECK: call void @llvm.nvvm.fence.release.sync_restrict.space.cta.scope.cluster()
+ nvvm.fence.sync_restrict {order = #nvvm.mem_order<release>}
+ llvm.return
+}
+
+// CHECK-LABEL: @fence_mbarrier_init
+llvm.func @fence_mbarrier_init() {
+ // CHECK: call void @llvm.nvvm.fence.mbarrier_init.release.cluster()
+ nvvm.fence.mbarrier.init
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_fence_proxy
+llvm.func @nvvm_fence_proxy() {
+ // CHECK: call void @llvm.nvvm.fence.proxy.alias()
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>}
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.async()
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<async>}
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.async.global()
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.global>}
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cta()
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cluster()
+ nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_fence_proxy_sync_restrict
+llvm.func @nvvm_fence_proxy_sync_restrict() {
+ // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.acquire.sync_restrict.space.cluster.scope.cluster()
+ nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>}
+ // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.release.sync_restrict.space.cta.scope.cluster()
+ nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<release>}
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release
+llvm.func @nvvm_fence_proxy_tensormap_generic_release() {
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta()
+ nvvm.fence.proxy.release #nvvm.mem_scope<cta>
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster()
+ nvvm.fence.proxy.release #nvvm.mem_scope<cluster>
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu()
+ nvvm.fence.proxy.release #nvvm.mem_scope<gpu>
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys()
+ nvvm.fence.proxy.release #nvvm.mem_scope<sys>
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire
+llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) {
+ %c128 = llvm.mlir.constant(128) : i32
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128)
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128)
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128)
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128
+
+ // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128)
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
new file mode 100644
index 0000000..37756c8
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) {
+ // expected-error @below {{random_bits is required for RS rounding mode.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+ // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}}
+ %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) {
+ // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) {
+ // expected-error @below {{random_bits is required for RS rounding mode.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) {
+ // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}}
+ %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir
new file mode 100644
index 0000000..4b3cafe
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
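+// Generic barrier pointers are addrspacecast to shared (addrspace 3) before
+// the call; the scope defaults to cta, `relaxed = true` selects the .relaxed
+// variants, and the shared_cluster (addrspace 7) forms return no token.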
+llvm.func @mbarrier_arrive_drop_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_generic(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1)
+ // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
+ %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64
+ %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64
+
+ %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64
+ %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64
+ %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+ %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64
+ %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64
+
+ %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_drop_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32
+
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir
new file mode 100644
index 0000000..b5389bd
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_arrive_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_generic(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1)
+ // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
+ %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64
+ %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64
+
+ %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64
+ %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64
+ %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+ %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64
+ %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64
+
+ %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32
+
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
new file mode 100644
index 0000000..6e7e163
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
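+// As in the mbarrier arrive_drop tests: generic pointers are cast to shared,
+// cta is the default scope, and shared_cluster forms return no token.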
+llvm.func @mbarrier_arrive_generic(%barrier: !llvm.ptr, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_generic(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %3, i32 1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1)
+ // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1)
+ // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1)
+ // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64
+ %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr -> i64
+ %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64
+ %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64
+
+ %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr -> i64
+ %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr -> i64
+ %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64
+ %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 1)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64
+ %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64
+ %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64
+ %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64
+
+ %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<3> -> i64
+ %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64
+ %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64
+ %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.arrive %barrier : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>
+
+ nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir
new file mode 100644
index 0000000..c345c5d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_arrive_drop_generic(%barrier: !llvm.ptr, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_generic(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %3, i32 1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1)
+ // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1)
+ // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1)
+ // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr -> i64
+ %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr -> i64
+ %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64
+ %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64
+
+ %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr -> i64
+ %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr -> i64
+ %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64
+ %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 1)
+ // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1)
+ // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64
+ %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<3> -> i64
+ %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64
+ %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64
+
+ %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<3> -> i64
+ %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64
+ %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64
+ %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64
+ llvm.return
+}
+
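+// In the shared_cluster space (addrspace 7) the op returns no value, so the
+// lowering uses the void '.space.cluster' intrinsic variants.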
+llvm.func @mbarrier_arrive_drop_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>
+
+ nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>
+ nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_drop_nocomplete(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete(ptr %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete(ptr %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_drop_nocomplete_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete.shared(ptr addrspace(3) %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir
new file mode 100644
index 0000000..99289fa
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
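+// complete_tx encodes the scope attribute (cta by default, or cluster) and
+// the pointer address space into the intrinsic name.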
+llvm.func @mbarrier_complete_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) {
+ // CHECK-LABEL: define void @mbarrier_complete_tx_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<3>, i32
+ nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32
+ nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32
+
+ llvm.return
+}
+
+llvm.func @mbarrier_complete_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) {
+ // CHECK-LABEL: define void @mbarrier_complete_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<7>, i32
+ nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32
+
+ llvm.return
+}
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir
new file mode 100644
index 0000000..dad7237
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
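+// expect_tx mirrors complete_tx: the scope attribute and the pointer address
+// space select the llvm.nvvm.mbarrier.expect.tx.* intrinsic.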
+llvm.func @mbarrier_expect_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) {
+ // CHECK-LABEL: define void @mbarrier_expect_tx_shared(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<3>, i32
+ nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32
+ nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32
+
+ llvm.return
+}
+
+llvm.func @mbarrier_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) {
+ // CHECK-LABEL: define void @mbarrier_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<7>, i32
+ nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32
+ nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32
+
+ llvm.return
+}
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
new file mode 100644
index 0000000..9c1d1cc
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
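+// init, inval, and cp.async.mbarrier.arrive use the '.shared' intrinsic
+// variants for addrspace(3) barriers and the generic variants otherwise.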
+llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
+ // CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
+ nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
+ nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
+ nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
+ llvm.return
+}
+
+llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.inval %barrier : !llvm.ptr
+ llvm.return
+}
+
+llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
new file mode 100644
index 0000000..4a7776d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
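+// Negative tests: mbarrier ops reject scopes other than CTA/Cluster, results
+// in the shared_cluster (addrspace 7) space, and misuse of the predicate form.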
+// -----
+
+llvm.func @mbarrier_arrive_ret_check(%barrier: !llvm.ptr<7>) {
+ // expected-error @below {{mbarrier in shared_cluster space cannot return any value}}
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<7> -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arrive_invalid_scope(%barrier: !llvm.ptr<7>) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %0 = nvvm.mbarrier.arrive %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arrive_drop_ret_check(%barrier: !llvm.ptr<7>) {
+ // expected-error @below {{mbarrier in shared_cluster space cannot return any value}}
+ %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7> -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arrive_drop_invalid_scope(%barrier: !llvm.ptr<7>) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %0 = nvvm.mbarrier.arrive_drop %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_expect_tx_scope(%barrier: !llvm.ptr<7>, %tx_count: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arr_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arr_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) {
+ // expected-error @below {{mbarrier in shared_cluster space cannot return any value}}
+ %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @init_mbarrier_arrive_expect_tx_asm_ret(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+ // expected-error @below {{return-value is not supported when using predicate}}
+ %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @init_mbarrier_arrive_expect_tx_asm_relaxed(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+ // expected-error @below {{mbarrier with relaxed semantics is not supported when using predicate}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {relaxed = true} : !llvm.ptr<3>, i32, i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @init_mbarrier_arrive_expect_tx_asm_cta(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+ // expected-error @below {{mbarrier scope must be CTA when using predicate}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @init_mbarrier_arrive_expect_tx_asm_cluster(%barrier : !llvm.ptr<7>, %txcount : i32, %pred : i1) {
+ // expected-error @below {{mbarrier in shared_cluster space is not supported when using predicate}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<7>, i32, i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arr_drop_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_arr_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) {
+ // expected-error @below {{mbarrier in shared_cluster space cannot return any value}}
+ %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_try_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_try_wait_with_timelimit(%barrier: !llvm.ptr<3>, %phase: i32, %ticks: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32, i32 -> i1
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
new file mode 100644
index 0000000..21ab72e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
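+// An i64 state operand selects the plain 'test.wait' intrinsics; an i32 phase
+// operand selects the '.parity' variants.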
+llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir
new file mode 100644
index 0000000..18aaf0e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
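+// try_wait follows the same state/parity split as test.wait; the optional
+// tick count selects the time-limited '.tl' intrinsic variants.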
+llvm.func @mbarrier_try_wait_state(%barrier: !llvm.ptr, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_state(ptr %0, i64 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr, i64 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_state_with_timelimit(%barrier: !llvm.ptr, %state : i64, %ticks : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_state_with_timelimit(ptr %0, i64 %1, i32 %2) {
+ // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %4, i64 %1, i32 %2)
+ // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i64 %1, i32 %2)
+ // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i64 %1, i32 %2)
+ // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i64 %1, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr, i64, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr, i64, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1
+
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_state_with_timelimit(%barrier: !llvm.ptr<3>, %state : i64, %ticks : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_shared_state_with_timelimit(ptr addrspace(3) %0, i64 %1, i32 %2) {
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+ // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr<3>, i64, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr<3>, i64, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_phase(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_phase_with_timelimit(%barrier: !llvm.ptr, %phase : i32, %ticks : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_phase_with_timelimit(ptr %0, i32 %1, i32 %2) {
+ // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %4, i32 %1, i32 %2)
+ // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i32 %1, i32 %2)
+ // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i32 %1, i32 %2)
+ // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i32 %1, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr, i32, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_phase_with_timelimit(%barrier: !llvm.ptr<3>, %phase : i32, %ticks : i32) {
+ // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase_with_timelimit(ptr addrspace(3) %0, i32 %1, i32 %2) {
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+ // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+ %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1
+
+ %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr<3>, i32, i32 -> i1
+ %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/membar.mlir b/mlir/test/Target/LLVMIR/nvvm/membar.mlir
new file mode 100644
index 0000000..1b794f6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/membar.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @memorybarrier()
+llvm.func @memorybarrier() {
+ // CHECK: call void @llvm.nvvm.membar.cta()
+ nvvm.memory.barrier #nvvm.mem_scope<cta>
+ // CHECK: call void @llvm.nvvm.fence.sc.cluster()
+ nvvm.memory.barrier #nvvm.mem_scope<cluster>
+ // CHECK: call void @llvm.nvvm.membar.gl()
+ nvvm.memory.barrier #nvvm.mem_scope<gpu>
+ // CHECK: call void @llvm.nvvm.membar.sys()
+ nvvm.memory.barrier #nvvm.mem_scope<sys>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
new file mode 100644
index 0000000..1d6c23c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
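+// The three-input modes (default, f4e, b4e) require the 'hi' operand, while
+// the two-input modes (rc8, ecl, ecr, rc16) must omit it.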
+llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 {
+ // expected-error @below {{mode 'default' requires 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
+ // expected-error @below {{mode 'f4e' requires 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
+ // expected-error @below {{mode 'b4e' requires 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'rc8' does not accept 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'ecl' does not accept 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'ecr' does not accept 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'rc16' does not accept 'hi' operand.}}
+ %r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32
+ llvm.return %r : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
new file mode 100644
index 0000000..d2baae7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
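+// Each permute mode lowers to its own llvm.nvvm.prmt.* intrinsic; the
+// default mode lowers to the base llvm.nvvm.prmt.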
+// CHECK-LABEL: @test_prmt_default
+llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_f4e
+llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_b4e
+llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc8
+llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecl
+llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecr
+llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc16
+llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_mixed
+llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
+
+ // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+ %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32
+
+ // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32
+
+ llvm.return %r3 : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir
new file mode 100644
index 0000000..a8a7430
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/redux-sync-invalid.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
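+// redux.sync type checks: 'add'/'and' accept only i32, 'fmin' only f32, and
+// the abs/nan modifiers are valid only for f32.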
+// -----
+
+llvm.func @redux_sync_i32_with_abs(%value: i32, %offset: i32) {
+ // expected-error@+1 {{abs attribute is supported only for f32 type}}
+ %res = nvvm.redux.sync add %value, %offset {abs = true}: i32 -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @redux_sync_i32_with_nan(%value: i32, %offset: i32) {
+ // expected-error@+1 {{nan attribute is supported only for f32 type}}
+ %res = nvvm.redux.sync add %value, %offset {nan = true}: i32 -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @redux_sync_f32_with_invalid_kind_add(%value: f32, %offset: i32) {
+ // expected-error@+1 {{'add' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}}
+ %res = nvvm.redux.sync add %value, %offset: f32 -> f32
+ llvm.return
+}
+
+// -----
+
+llvm.func @redux_sync_f32_with_invalid_kind_and(%value: f32, %offset: i32) {
+ // expected-error@+1 {{'and' redux kind unsupported with 'f32' type. Only supported type is 'i32'.}}
+ %res = nvvm.redux.sync and %value, %offset: f32 -> f32
+ llvm.return
+}
+
+// -----
+
+llvm.func @redux_sync_i32_with_invalid_kind_fmin(%value: i32, %offset: i32) {
+ // expected-error@+1 {{'fmin' redux kind unsupported with 'i32' type. Only supported type is 'f32'.}}
+ %res = nvvm.redux.sync fmin %value, %offset: i32 -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @redux_sync_non_matching_types(%value: i32, %offset: i32) {
+ // expected-error@+1 {{failed to verify that all of {res, val} have same type}}
+ %res = nvvm.redux.sync add %value, %offset: i32 -> f32
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
new file mode 100644
index 0000000..f2ccfe7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
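+// A struct return from shfl.sync requires the 'return_value_and_is_valid'
+// attribute, and the returned (first) element must match the value type.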
+// -----
+
+func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+ // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}}
+ %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
+}
+
+// -----
+
+func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+ // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead}}
+ %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32
+}
+
+// -----
+
+func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+ // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead}}
+ %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
new file mode 100644
index 0000000..1b93f20c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
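+// The offset operand of tcgen05.ld is valid only for the 16x32bx2 shape.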
+// -----
+
+llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
+ // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
+ %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
new file mode 100644
index 0000000..db4574b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
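+// The two trailing i32 immediates encode the ctaGroup (1 or 2) and the
+// collectorOp (0 when unset, lastuse = 1, fill = 2, use = 3); the
+// blockScale<block32> attribute selects the '.block32' intrinsic suffix.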
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
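+// The mxf4nvf4 kind requires an explicit blockScale attribute (block16 or
+// block32); the chosen value is reflected in the intrinsic-name suffix.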
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir
new file mode 100644
index 0000000..a15c3fb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
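+// Same coverage as the shared-memory variant, but with matrix A held in
+// tensor memory (!llvm.ptr<6>), which selects the tcgen05.mma.tensor.*
+// intrinsics instead of tcgen05.mma.shared.*.
+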
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
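+// For the mxf4 kind, omitting blockScale lowers to the unsuffixed
+// block_scale intrinsic, while an explicit block32 selects the .block32
+// variant; block16 is rejected (see tcgen05-mma-invalid.mlir).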
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
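+// mxf4nvf4 always carries a blockScale attribute, so every lowering below
+// targets a .block16 or .block32 suffixed intrinsic.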
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir
new file mode 100644
index 0000000..f46b35a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s
+
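+// Negative tests: each split-input case triggers one verifier diagnostic
+// for nvvm.tcgen05.mma and its .sp and .block_scale variants.
+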
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) {
+ // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}}
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev8
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) {
+  // expected-error @below {{Disable Output Lane of length 4 is incompatible with CtaGroupAttr}}
+  nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev4
+    {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_shared_ashift
+llvm.func @nvvm_tcgen05_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+ // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}}
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ashift
+llvm.func @nvvm_tcgen05_mma_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+ // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}}
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default
+llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
+ // expected-error @below {{mxf4nvf4 requires block scale attribute}}
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_default
+llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
+ // expected-error @below {{mxf4 kind does not support block16 attribute}}
+ nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
+    {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+ // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}}
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev8
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+  // expected-error @below {{Disable Output Lane of length 4 is incompatible with CtaGroupAttr}}
+  nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev4
+    {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_sp_mma_shared_ashift
+llvm.func @nvvm_tcgen05_sp_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+ // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}}
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_ashift
+llvm.func @nvvm_tcgen05_mma_sp_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+ // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}}
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default
+llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+ // expected-error @below {{mxf4nvf4 requires block scale attribute}}
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_default
+llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+ // expected-error @below {{mxf4 kind does not support block16 attribute}}
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
+    {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir
new file mode 100644
index 0000000..286df36
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir
@@ -0,0 +1,442 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
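+// The three trailing i32 operands encode the MMA kind (f16=0, tf32=1,
+// f8f6f4=2, i8=3), the CTA group (1 or 2), and the collector op (0 when
+// omitted, lastuse=1, fill=2, use=3).
+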
+// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1
+llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2
+llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ llvm.return
+}
+
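+// For the scale_d forms, the immediate scale value becomes an extra i64
+// operand ahead of the kind/group/collector immediates; only the f16 and
+// tf32 kinds are exercised with it.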
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ llvm.return
+}
+
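+// For the disable_output_lane forms, the CTA group selects both the
+// intrinsic suffix (cg1 or cg2) and the mask width: vector<4 x i32> for
+// cta_1 and vector<8 x i32> for cta_2.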
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>)
+
+ llvm.return
+}
+
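+// When both scale and mask are present, the i64 scale operand precedes the
+// mask vector in the lowered call.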
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir
new file mode 100644
index 0000000..5c7eabe
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
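+// The blockScale attribute selects the intrinsic variant: the default maps
+// to the plain .block_scale intrinsic and block32 to its .block32 form. The
+// trailing i32 immediates encode the CTA group (1 or 2) and the collector
+// op (default = 0, lastuse = 1, fill = 2, use = 3).
+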
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
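+// The mxf4 kind lowers the same way; it rejects block16 (covered by the
+// invalid tests), so only the default and block32 variants appear here.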
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem: !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem: !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
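+  // For the mxf4nvf4 kind, every op below carries an explicit blockScale
+  // attribute: block16 and block32 select the .block16 / .block32 intrinsic
+  // suffixes, whereas the mxf4 and mxf8f6f4 tests above lower to the
+  // unsuffixed block_scale intrinsic when the attribute is omitted.
+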
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem: !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir
new file mode 100644
index 0000000..3200411
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
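+// Same coverage as the shared-A block-scale tests, but with the A operand in
+// tensor memory (!llvm.ptr<6>) instead of behind a shared-memory descriptor
+// (i64), so the ops lower to the llvm.nvvm.tcgen05.mma.sp.tensor.* intrinsics.
+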
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
+ {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir
new file mode 100644
index 0000000..96044cf
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir
@@ -0,0 +1,442 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
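+// Here the intrinsic name carries no kind suffix, so the trailing three i32
+// immediates encode the MMA kind (f16 -> 0, tf32 -> 1, f8f6f4 -> 2, i8 -> 3),
+// the ctaGroup, and the collectorOp, matching the attribute combinations
+// exercised below.
+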
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem: !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
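+// Same kind / collector-op matrix as above, with ctaGroup = cta_2: the CTA
+// group immediate in each intrinsic call becomes i32 2.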
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
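+// With a `scale = ...` operand, lowering selects the `.scale_d` intrinsic
+// variant; the constant scale is folded into an i64 immediate (i64 0 in the
+// CHECK lines below). Only the f16 and tf32 kinds are exercised with scale_d.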
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ %scale_d_imm = llvm.mlir.constant(0:i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ %scale_d_imm = llvm.mlir.constant(0:i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
+
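+// With a `mask = ...` operand, lowering selects the `.disable_output_lane.cgN`
+// intrinsic variant: the CTA group moves into the intrinsic name (.cg1 takes a
+// vector<4 x i32> mask, .cg2 a vector<8 x i32> mask) and its i32 immediate is
+// dropped from the argument list.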
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ llvm.return
+}
+
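+// The scale and mask operands compose: lowering selects the combined
+// `.scale_d.disable_output_lane.cgN` variant, with the i64 scale immediate
+// followed by the mask vector.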
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ %scale_d_imm = llvm.mlir.constant(0:i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ %scale_d_imm = llvm.mlir.constant(0:i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir
new file mode 100644
index 0000000..709beb0
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir
@@ -0,0 +1,634 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
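+// Same coverage as the shared-memory tests, but with the A operand in tensor
+// memory (!llvm.ptr<6>), so lowering selects the `.tensor` intrinsic variants
+// instead of `.shared`.
+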
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
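 + // The aShift unit attribute selects the `.ashift` intrinsic variant; in this
 + // test it only appears with the default and lastuse collector ops.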
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
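+// The cta_2 tests below mirror the cta_1 tests above. Judging from the CHECK
+// lines, the trailing immediates encode (kind, ctaGroup, collectorOp): the
+// kinds f16/tf32/f8f6f4/i8 map to 0/1/2/3, cta_1/cta_2 to 1/2, and the
+// collector ops discard (the default)/lastuse/fill/use to 0/1/2/3, with the
+// aShift attribute selecting the ".ashift" intrinsic variant.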
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
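+// The scale_d tests pass an immediate scale via "scale = %scale_d_imm", which
+// selects the ".scale_d" intrinsic variant and inserts an i64 scale argument
+// ahead of the (kind, ctaGroup, collectorOp) immediates. Note that only the
+// f16 and tf32 kinds are exercised together with scale_d here.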
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+  %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
+
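+// As above, but with ctaGroup = cta_2, so the ctaGroup immediate is 2.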
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+  %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
+
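+// The disable_output_lane tests pass a lane mask via "mask =
+// %disableOutputLane". The CTA group is encoded in the intrinsic name
+// (".cg1" vs ".cg2") rather than as an immediate, and the mask width follows
+// it: vector<4 x i32> for cta_1 and vector<8 x i32> for cta_2. Only the
+// (kind, collectorOp) immediates trail the mask.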
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>)
+
+ llvm.return
+}
+
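+// The cta_2 mask tests use the ".cg2" intrinsic variants and widen the mask
+// to vector<8 x i32>.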
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>)
+
+ llvm.return
+}
+
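+// Combining scale and mask selects the ".scale_d.disable_output_lane.cgN"
+// variants, with the i64 scale immediate preceding the mask vector.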
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+  %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir
new file mode 100644
index 0000000..798e311
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir
@@ -0,0 +1,633 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
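+// Each lowered intrinsic call ends with immediates selecting the MMA kind
+// (f16 = 0, tf32 = 1, f8f6f4 = 2, i8 = 3), the CTA group (1 or 2), and the
+// collector op (discard = 0, lastuse = 1, fill = 2, use = 3). The `aShift`
+// unit attribute picks the `.ashift` intrinsic variants, and for the
+// disable_output_lane forms the CTA group is folded into a `.cg1`/`.cg2`
+// name suffix instead of an immediate.
+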
+// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1
+llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2
+llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=discard */ i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=fill */ i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=use */ i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ llvm.return
+}
+
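+// An immediate scale-d operand is emitted as a literal `i64 0` argument ahead
+// of the kind/cta_group/collector immediates.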
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ llvm.return
+}
+
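+// Without a scale-d operand, the disable-output-lane lowering passes the mask
+// vector directly ahead of the kind and collector immediates.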
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem: !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>)
+
+ llvm.return
+}
+
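+// With ctaGroup = cta_2, the disable-output-lane mask widens to <8 x i32> and
+// the lowering selects the .cg2 intrinsic variants; the two trailing i32
+// immediates encode the kind and collectorOp attributes.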
+// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>)
+
+ llvm.return
+}
+
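+// The scale_d immediate form adds an i64 scale operand and lowers to the
+// .scale_d.disable_output_lane intrinsics; only the f16 and tf32 kinds are
+// exercised here.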
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>)
+
+ llvm.return
+}
+
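+// Same scale_d coverage for cta_group 2, with the mask widened to <8 x i32>.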
+// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2
+llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) {
+
+ %scale_d_imm = llvm.mlir.constant(0 : i64) : i64
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3)
+ nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane
+ {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir
new file mode 100644
index 0000000..5f1aeb0
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
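+// Tests lowering of the weight-stationary nvvm.tcgen05.mma.ws op with A and B
+// supplied as shared-memory descriptors; the three trailing i32 immediates
+// encode the kind, collectorBBuffer, and collectorOp attributes.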
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws
+llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask
+llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir
new file mode 100644
index 0000000..e390e35
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
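+// Tests the sparse variant nvvm.tcgen05.mma.ws.sp: %spmetadata is a tensor
+// memory pointer carrying the sparsity metadata, and an optional i64
+// zero-column mask selects the .zero_col_mask intrinsics.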
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp
+llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask
+llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir
new file mode 100644
index 0000000..f7ce548
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
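+// Same sparse coverage with the A operand read from tensor memory (%a_tmem),
+// lowering to the .tensor forms of the ws.sp intrinsics.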
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp
+llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask
+llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) {
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir
new file mode 100644
index 0000000..cecbb3f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws
+llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1) {
+
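+  // The three trailing i32 immediates on the intrinsic encode, in order: the
+  // MMA kind (f16 = 0, tf32 = 1, f8f6f4 = 2, i8 = 3), the collectorBBuffer
+  // attribute (0 when omitted, 1 for b1), and the collectorOp attribute
+  // (0 when omitted, 1 for lastuse).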
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1)
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask
+llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) {
+
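+  // As above, but the zero-column mask adds an i64 operand and selects the
+  // .zero_col_mask intrinsic variant.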
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f16>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<tf32>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<f8f6f4>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1)
+ nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask
+ {kind = #nvvm.tcgen05_mma_kind<i8>,
+ collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>,
+ collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64)
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 0daf245..240fab5 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
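+  // Without l2_cache_hint, the intrinsic receives a zero hint (i64 0) and an
+  // i1 false flag; with the hint, the operand is forwarded and the flag is true.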
+ nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+ nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+ llvm.return
+}
+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
new file mode 100644
index 0000000..d762ff3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
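+// The !llvm.ptr<3> destination selects the shared::cta variant; a multicast
+// mask only makes sense when copying to shared::cluster memory, so the
+// combination below is rejected.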
+llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
+ // expected-error @below {{Multicast is not supported with shared::cta mode.}}
+ nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 42aa221..d5868ee 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -578,14 +578,6 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// -----
-llvm.func @nanosleep() {
- // expected-error@+1 {{integer constant out of range for attribute}}
- nvvm.nanosleep 100000000000000
- llvm.return
-}
-
-// -----
-
llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) {
// expected-error@+1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}}
%res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 9115de6..c4a6909 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
llvm.return %1 : f32
}
-// CHECK-LABEL: @llvm_nvvm_barrier0
-llvm.func @llvm_nvvm_barrier0() {
- // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
- nvvm.barrier0
- llvm.return
-}
-
-// CHECK-LABEL: @llvm_nvvm_barrier(
-// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
-llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
- // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
- nvvm.barrier
- // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
- nvvm.barrier id = %barID
- // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
- nvvm.barrier id = %barID number_of_threads = %numberOfThreads
- llvm.return
-}
-
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
llvm.func @llvm_nvvm_cluster_arrive() {
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()
@@ -531,19 +512,6 @@ llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
llvm.return
}
-// CHECK-LABEL: @cp_async_mbarrier_arrive
-llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %{{.*}})
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %{{.*}})
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
- llvm.return
-}
-
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
@@ -731,42 +699,6 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant},
llvm.return
}
-
-// -----
-// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release
-llvm.func @nvvm_fence_proxy_tensormap_generic_release() {
- %c128 = llvm.mlir.constant(128) : i32
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta()
- nvvm.fence.proxy.release #nvvm.mem_scope<cta>
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster()
- nvvm.fence.proxy.release #nvvm.mem_scope<cluster>
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu()
- nvvm.fence.proxy.release #nvvm.mem_scope<gpu>
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys()
- nvvm.fence.proxy.release #nvvm.mem_scope<sys>
- llvm.return
-}
-
-// -----
-// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire
-llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) {
- %c128 = llvm.mlir.constant(128) : i32
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128)
- nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128)
- nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128)
- nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128
-
- // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128)
- nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128
- llvm.return
-}
// -----
// CHECK-LABEL: @nvvm_exit
@@ -983,8 +915,8 @@ llvm.func @nvvm_pmevent() {
// -----
// CHECK-LABEL: @nanosleep
-llvm.func @nanosleep() {
- // CHECK: call void @llvm.nvvm.nanosleep(i32 4000)
- nvvm.nanosleep 4000
+llvm.func @nanosleep(%duration: i32) {
+ // CHECK: call void @llvm.nvvm.nanosleep(i32 %{{.*}})
+ nvvm.nanosleep %duration
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir
index f6860e5..d9be6d1 100644
--- a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir
@@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
-// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8
-// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
+// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
-// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
-// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
+// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
+// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
// CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8
// store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
-// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
+// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4
@@ -93,17 +93,17 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
-// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
+// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8
-// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
+// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
-// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
-// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
+// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
+// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
// CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8
// store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
-// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
+// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir
new file mode 100644
index 0000000..fa330b6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This tests that references to a `declare target to` variable inside target op
+// regions are replaced with the generated `declare target to` global variable
+// when lowering to LLVM IR for the device. Unfortunately, as the host file is
+// not passed as a module attribute, we miss out on the metadata and entry info.
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+ // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer
+ llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> {
+ %0 = llvm.mlir.zero : !llvm.array<11 x f32>
+ llvm.return %0 : !llvm.array<11 x f32>
+ }
+
+ // CHECK-LABEL: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %{{.*}}) {{.*}} {
+ // CHECK-DAG: omp.target:
+ // CHECK-DAG: store float 1.000000e+00, ptr @_QMtest_0Ezii, align 4
+ // CHECK-DAG: br label %omp.region.cont
+ llvm.func @_QQmain() {
+ %0 = llvm.mlir.constant(1 : index) : i64
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.mlir.constant(11 : index) : i64
+ %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr
+ %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true}
+ %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr
+ omp.target map_entries(%5 -> %arg0 : !llvm.ptr) {
+ %6 = llvm.mlir.constant(1.0 : f32) : f32
+ %7 = llvm.mlir.constant(0 : i64) : i64
+ %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %6, %8 : f32, !llvm.ptr
+ omp.terminator
+ }
+ llvm.return
+ }
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir
new file mode 100644
index 0000000..4202421
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
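+// This tests the host-side lowering of a `declare target to` global: besides
+// the usual map descriptors, a `__tgt_offload_entry` is emitted so the runtime
+// can associate the host global with its device counterpart.
+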
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer
+ // CHECK-DAG: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 48]
+ // CHECK-DAG: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3]
+ // CHECK-DAG: @.offloading.entry._QMtest_0Ezii = weak constant %struct.__tgt_offload_entry {{.*}} ptr @_QMtest_0Ezii, {{.*}}, i64 44,{{.*}}
+ llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> {
+ %0 = llvm.mlir.zero : !llvm.array<11 x f32>
+ llvm.return %0 : !llvm.array<11 x f32>
+ }
+
+ // CHECK-DAG: %[[BASEPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+ // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[BASEPTR]], align 8
+ // CHECK-DAG: %[[OFFLOADPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
+ // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[OFFLOADPTR]], align 8
+ llvm.func @_QQmain() {
+ %0 = llvm.mlir.constant(1 : index) : i64
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.mlir.constant(11 : index) : i64
+ %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr
+ %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true}
+ %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr
+ omp.target map_entries(%5 -> %arg0 : !llvm.ptr) {
+ %6 = llvm.mlir.constant(1.0 : f32) : f32
+ %7 = llvm.mlir.constant(0 : i64) : i64
+ %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %6, %8 : f32, !llvm.ptr
+ omp.terminator
+ }
+ llvm.return
+ }
+  // CHECK-DAG: !{{.*}} = !{i32 {{.*}}, !"_QMtest_0Ezii", i32 {{.*}}, i32 {{.*}}}
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index e6ea3aa..e289d5d 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
// CHECK: br label %[[VAL_40]]
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
// CHECK: ret void
+
+// -----
+
+module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
+ %map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
+ map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
+// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
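+// Map type 288 (0x120) combines libomptarget's LITERAL (0x100) and
+// TARGET_PARAM (0x20) flags, so the device pointer is passed by value rather
+// than mapped.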
+// CHECK-LABEL: define void @_QPomp_target_is_device_ptr
diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir
index 87ff0ba..fac61e05 100644
--- a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir
@@ -7,7 +7,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 :
llvm.func @bar() {}
llvm.func @baz() {}
- omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr alloc {
+ omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr attributes {byref_element_type = !llvm.array<5 x f32>} alloc {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr<5>
%2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
@@ -23,7 +23,12 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 :
^bb3: // pred: ^bb1
llvm.call @baz() : () -> ()
omp.yield(%arg0 : !llvm.ptr)
+ } data_ptr_ptr {
+ ^bb0(%arg0: !llvm.ptr):
+ %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ omp.yield(%0 : !llvm.ptr)
}
+
llvm.func @foo_() {
%c1 = llvm.mlir.constant(1 : i64) : i64
%10 = llvm.alloca %c1 x !llvm.array<5 x f32> {bindc_name = "x"} : (i64) -> !llvm.ptr<5>
@@ -51,8 +56,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 :
}
}
-// CHECK: call void @__kmpc_parallel_51({{.*}}, i32 1, i32 -1, i32 -1,
-// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1)
+// CHECK: call void @__kmpc_parallel_60({{.*}}, i32 1, i32 -1, i32 -1,
+// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1, i32 0)
// CHECK: define internal void @[[PAR_OUTLINED]]{{.*}} {
// CHECK: .omp.reduction.then:
@@ -67,9 +72,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 :
// CHECK: br label %[[CONT_BB:.*]]
// CHECK: [[CONT_BB]]:
-// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %final.rhs, %{{.*}} ]
-// CHECK-NEXT: store ptr %[[RED_RHS]], ptr %{{.*}}, align 8
-// CHECK-NEXT: br label %.omp.reduction.done
+// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %{{.*}}, %{{.*}} ]
// CHECK: }
// CHECK: define internal void @"{{.*}}$reduction$reduction_func"(ptr noundef %0, ptr noundef %1) #0 {
diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir
index b8b7c78..8950db3 100644
--- a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir
@@ -109,19 +109,19 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: icmp eq i32 %[[MASTER]], 1
// CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]]
// CHECK: [[THEN]]:
-// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double
// CHECK-NEXT: %[[FINAL_LHS0:[A-Za-z0-9_.]*]] = load double
+// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double
// CHECK-NEXT: %[[FINAL_RESULT0:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS0]], %[[FINAL_RHS0]]
// CHECK-NEXT: store double %[[FINAL_RESULT0]]
-// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double
// CHECK-NEXT: %[[FINAL_LHS1:[A-Za-z0-9_.]*]] = load double
+// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double
// CHECK-NEXT: %[[FINAL_RESULT1:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS1]], %[[FINAL_RHS1]]
// CHECK-NEXT: store double %[[FINAL_RESULT1]]
-// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float
// CHECK-NEXT: %[[FINAL_LHS2:[A-Za-z0-9_.]*]] = load float
+// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float
// CHECK-NEXT: %[[FINAL_RESULT2:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS2]], %[[FINAL_RHS2]]
// CHECK-NEXT: store float %[[FINAL_RESULT2]]
-// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float
// CHECK-NEXT: %[[FINAL_LHS3:[A-Za-z0-9_.]*]] = load float
+// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float
// CHECK-NEXT: %[[FINAL_RESULT3:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS3]], %[[FINAL_RHS3]]
// CHECK-NEXT: store float %[[FINAL_RESULT3]]
diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir
index 19333c4..a96756f46 100644
--- a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir
@@ -25,34 +25,33 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
// CHECK: %struct.[[TSK_WTH_PRVTS:.*]] = type { %struct.kmp_task_ompbuilder_t, %struct.[[PRVTS:.*]] }
// CHECK: %struct.kmp_task_ompbuilder_t = type { ptr, ptr, i32, ptr, ptr }
-// CHECK: %struct.[[PRVTS]] = type { [5 x ptr], [5 x ptr], [5 x i64] }
+// CHECK: %struct.[[PRVTS]] = type { [6 x ptr], [6 x ptr], [6 x i64] }
// CHECK: define void @launch_(ptr captures(none) %0)
// CHECK: %[[STRUCTARG:.*]] = alloca { ptr, ptr }, align 8
-// CHECK: %[[BASEPTRS:.*]] = alloca [5 x ptr], align 8
-// CHECK: %[[PTRS:.*]] = alloca [5 x ptr], align 8
-// CHECK: %[[MAPPERS:.*]] = alloca [5 x ptr], align 8
-// CHECK: %[[SIZES:.*]] = alloca [5 x i64], align 4
+// CHECK: %[[BASEPTRS:.*]] = alloca [6 x ptr], align 8
+// CHECK: %[[PTRS:.*]] = alloca [6 x ptr], align 8
+// CHECK: %[[MAPPERS:.*]] = alloca [6 x ptr], align 8
+// CHECK: %[[SIZES:.*]] = alloca [6 x i64], align 4
-
-// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0
-// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0
-// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[PTRS]], i32 0, i32 0
-// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [5 x i64], ptr %[[SIZES]], i32 0, i32 0
+// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0
+// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0
+// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[PTRS]], i32 0, i32 0
+// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [6 x i64], ptr %[[SIZES]], i32 0, i32 0
// CHECK: %[[GL_THRD_NUM:.*]] = call i32 @__kmpc_global_thread_num
-// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 160, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1)
+// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 184, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1)
// CHECK: %[[TSK_PTR:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 0
// CHECK: %[[SHAREDS:.*]] = getelementptr inbounds nuw %struct.kmp_task_ompbuilder_t, ptr %[[TSK_PTR]], i32 0, i32 0
// CHECK: %[[SHAREDS_PTR:.*]] = load ptr, ptr %[[SHAREDS]], align 8
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[SHAREDS_PTR]], ptr align 1 %[[STRUCTARG]], i64 16, i1 false)
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 1
// CHECK: %[[VAL_51:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 0
-// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 40, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 48, i1 false)
// CHECK: %[[VAL_53:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 1
-// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 40, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 48, i1 false)
// CHECK: %[[VAL_54:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 2
-// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 40, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 48, i1 false)
// CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_omp_task(ptr @4, i32 %[[GL_THRD_NUM]], ptr %[[TASK_DESC]])
// CHECK: define internal void @[[WORKER:.*]](i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}) {
diff --git a/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir
new file mode 100644
index 0000000..1e9369f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
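+// This tests mapping a derived-type member together with its parent when the
+// two maps overlap: the parent mapping is split around the explicitly mapped
+// member, so three runtime sizes are computed below (the whole struct, the
+// span before the member, and the span after it).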
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<"_QFTdtype", (f32, i32)> {bindc_name = "dtypev"} : (i64) -> !llvm.ptr
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)>
+ %3 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtypev%value2"}
+ %4 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)>) map_clauses(to) capture(ByRef) members(%3 : [1] : !llvm.ptr) -> !llvm.ptr {name = "dtypev"}
+ omp.target map_entries(%4 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// CHECK: @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 0, i64 0, i64 0, i64 4]
+// CHECK: @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710657, i64 281474976710657, i64 281474976710659]
+
+// CHECK: %[[ALLOCA:.*]] = alloca %_QFTdtype, i64 1, align 8
+// CHECK: %[[ELEMENT_ACC:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 0, i32 1
+
+// CHECK: %[[SIZE1_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1
+// CHECK: %[[SIZE1_CALC_2:.*]] = ptrtoint ptr %[[SIZE1_CALC_1]] to i64
+// CHECK: %[[SIZE1_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64
+// CHECK: %[[SIZE1_CALC_4:.*]] = sub i64 %[[SIZE1_CALC_2]], %[[SIZE1_CALC_3]]
+// CHECK: %[[SIZE1_CALC_5:.*]] = sdiv exact i64 %[[SIZE1_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+
+// CHECK: %[[SIZE2_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1
+// CHECK: %[[SIZE2_CALC_2:.*]] = ptrtoint ptr %[[ELEMENT_ACC]] to i64
+// CHECK: %[[SIZE2_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64
+// CHECK: %[[SIZE2_CALC_4:.*]] = sub i64 %[[SIZE2_CALC_2]], %[[SIZE2_CALC_3]]
+// CHECK: %[[SIZE2_CALC_5:.*]] = sdiv exact i64 %[[SIZE2_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+
+// CHECK: %[[SIZE3_CALC_1:.*]] = getelementptr i32, ptr %[[ELEMENT_ACC]], i32 1
+// CHECK: %[[SIZE3_CALC_2:.*]] = ptrtoint ptr %[[SIZE2_CALC_1]] to i64
+// CHECK: %[[SIZE3_CALC_3:.*]] = ptrtoint ptr %[[SIZE3_CALC_1]] to i64
+// CHECK: %[[SIZE3_CALC_4:.*]] = sub i64 %[[SIZE3_CALC_2]], %[[SIZE3_CALC_3]]
+// CHECK: %[[SIZE3_CALC_5:.*]] = sdiv exact i64 %[[SIZE3_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+
+// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8
+// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0
+// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8
+// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
+// CHECK: store i64 %[[SIZE1_CALC_5]], ptr %[[SIZES]], align 8
+
+// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
+// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8
+// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1
+// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8
+// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 1
+// CHECK: store i64 %[[SIZE2_CALC_5]], ptr %[[SIZES]], align 8
+
+// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
+// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8
+// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2
+// CHECK: store ptr %[[SIZE3_CALC_1]], ptr %[[PTRS]], align 8
+// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 2
+// CHECK: store i64 %[[SIZE3_CALC_5]], ptr %[[SIZES]], align 8
+
+// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
+// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8
+// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3
+// CHECK: store ptr %[[ELEMENT_ACC]], ptr %[[PTRS]], align 8
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 60c6fa4..cdb8dbb 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -70,31 +70,31 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
// CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
// CHECK: call void @__kmpc_target_deinit()
// CHECK: define internal void @[[FUNC1]](
// CHECK-SAME: ptr noalias noundef {{.*}}, ptr noalias noundef {{.*}}, ptr {{.*}}) #{{[0-9]+}} {
// Test if num_threads OpenMP clause for target region is correctly lowered
-// and passed as a param to kmpc_parallel_51 function
+// and passed as a param to the kmpc_parallel_60 function
// CHECK: define weak_odr protected amdgpu_kernel void [[FUNC_NUM_THREADS0:@.*]](
// CHECK-NOT: call void @__kmpc_push_num_threads(
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
+// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
+// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0)
-// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
+// One of the arguments of the kmpc_parallel_60 function is responsible for handling the if clause
// of omp parallel construct for target region. If this argument is nonzero,
-// then kmpc_parallel_51 launches multiple threads for parallel region.
+// then kmpc_parallel_60 launches multiple threads for the parallel region.
//
// This test checks if MLIR expression:
// %7 = llvm.icmp "ne" %5, %6 : i32
// omp.parallel if(%7)
// is correctly lowered to LLVM IR code and the if condition variable
-// is passed as a param to kmpc_parallel_51 function
+// is passed as a param to the kmpc_parallel_60 function
// CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}(
// CHECK-SAME: ptr {{.*}}, ptr {{.*}}, ptr %[[IFCOND_ARG2:.*]]) #{{[0-9]+}} {
@@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: %[[IFCOND_TMP2:.*]] = load i32, ptr %[[IFCOND_TMP1]], align 4
// CHECK: %[[IFCOND_TMP3:.*]] = icmp ne i32 %[[IFCOND_TMP2]], 0
// CHECK: %[[IFCOND_TMP4:.*]] = sext i1 %[[IFCOND_TMP3]] to i32
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
+// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
+// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1, i32 0)
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
index 5d2861a..917eaa0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
@@ -26,10 +26,10 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
}
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast
+// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast
// CHECK-SAME: (ptr addrspace(1) @[[GLOB:[0-9]+]] to ptr),
// CHECK-SAME: i32 %[[THREAD_NUM:.*]], i32 1, i32 -1, i32 -1,
-// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1)
+// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1, i32 0)
// CHECK: define internal void @[[PARALLEL_FUNC]]
// CHECK-SAME: (ptr noalias noundef %[[TID_ADDR:.*]], ptr noalias noundef %[[ZERO_ADDR:.*]],
diff --git a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
index 9640f03..711b50a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
@@ -59,9 +59,9 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
// CHECK: @[[FULL_ARR_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef
// CHECK: @[[ARR_SECT_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef
-// CHECK: @.offload_sizes = private unnamed_addr constant [12 x i64] [i64 0, i64 48, i64 8, i64 0, i64 0, i64 48, i64 8, i64 0, i64 0, i64 24, i64 8, i64 0]
-// CHECK: @.offload_maptypes = private unnamed_addr constant [12 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1407374883553283, i64 1407374883553283, i64 1407374883553299, i64 32, i64 2533274790395907, i64 2533274790395907, i64 2533274790395923]
-// CHECK: @.offload_mapnames = private constant [12 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}]
+// CHECK: @.offload_sizes = private unnamed_addr constant [15 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0]
+// CHECK: @.offload_maptypes = private unnamed_addr constant [15 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1688849860263939, i64 1688849860263939, i64 1688849860263939, i64 1688849860263955, i64 32, i64 3096224743817219, i64 3096224743817219, i64 3096224743817219, i64 3096224743817235]
+// CHECK: @.offload_mapnames = private constant [15 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}]
// CHECK: define void @main()
// CHECK: %[[SCALAR_ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
@@ -85,74 +85,97 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
// CHECK: %[[ARR_SECT_PTR:.*]] = getelementptr inbounds i32, ptr %[[LARR_SECT]], i64 %[[ARR_SECT_OFFSET2]]
// CHECK: %[[SCALAR_PTR_LOAD:.*]] = load ptr, ptr %[[SCALAR_BASE]], align 8
// CHECK: %[[FULL_ARR_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
-// CHECK: %[[FULL_ARR_SIZE_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null
-// CHECK: %[[FULL_ARR_SIZE_SEL:.*]] = select i1 %[[FULL_ARR_SIZE_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]]
+// CHECK: %[[FULL_ARR_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+// CHECK: %[[NULL_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null
+// CHECK: %[[IS_NULL:.*]] = select i1 %[[NULL_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]]
// CHECK: %[[ARR_SECT_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
-// CHECK: %[[ARR_SECT_SIZE_CMP:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null
-// CHECK: %[[ARR_SECT_SIZE_SEL:.*]] = select i1 %[[ARR_SECT_SIZE_CMP]], i64 0, i64 %[[ARR_SECT_SIZE]]
+// CHECK: %[[ARR_SECT_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+// CHECK: %[[NULL_CMP2:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null
+// CHECK: %[[IS_NULL2:.*]] = select i1 %[[NULL_CMP2]], i64 0, i64 %[[ARR_SECT_SIZE]]
// CHECK: %[[SCALAR_DESC_SZ4:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1
// CHECK: %[[SCALAR_DESC_SZ3:.*]] = ptrtoint ptr %[[SCALAR_DESC_SZ4]] to i64
// CHECK: %[[SCALAR_DESC_SZ2:.*]] = ptrtoint ptr %[[SCALAR_ALLOCA]] to i64
// CHECK: %[[SCALAR_DESC_SZ1:.*]] = sub i64 %[[SCALAR_DESC_SZ3]], %[[SCALAR_DESC_SZ2]]
// CHECK: %[[SCALAR_DESC_SZ:.*]] = sdiv exact i64 %[[SCALAR_DESC_SZ1]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
-
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+// CHECK: %[[SCALAR_BASE_2:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1
+// CHECK: %[[SCALAR_BASE_OFF:.*]] = getelementptr ptr, ptr %[[SCALAR_BASE]], i32 1
+// CHECK: %[[SCALAR_BASE_OFF_SZ1:.*]] = ptrtoint ptr %[[SCALAR_BASE_2]] to i64
+// CHECK: %[[SCALAR_BASE_OFF_SZ2:.*]] = ptrtoint ptr %[[SCALAR_BASE_OFF]] to i64
+// CHECK: %[[SCALAR_BASE_OFF_SZ3:.*]] = sub i64 %[[SCALAR_BASE_OFF_SZ1]], %[[SCALAR_BASE_OFF_SZ2]]
+// CHECK: %[[SCALAR_BASE_OFF_SZ4:.*]] = sdiv exact i64 %[[SCALAR_BASE_OFF_SZ3]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 0
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 0
// CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 0
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 0
// CHECK: store i64 %[[FULL_ARR_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 1
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 1
// CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
+// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 2
+// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @full_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 2
+// CHECK: store i64 %[[FULL_ARR_SZ]], ptr %[[OFFLOADSIZES]], align 8
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 2
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 3
// CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 3
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 4
// CHECK: store ptr %[[FULL_ARR_PTR]], ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 3
-// CHECK: store i64 %[[FULL_ARR_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8
-
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 4
+// CHECK: store i64 %[[IS_NULL]], ptr %[[OFFLOADSIZES]], align 8
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 5
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 4
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 5
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 4
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 5
// CHECK: store i64 %[[ARR_SECT_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 5
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 5
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 6
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 7
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 6
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 7
+// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @sect_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 7
+// CHECK: store i64 %[[ARR_SECT_SZ]], ptr %[[OFFLOADSIZES]], align 8
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 8
+// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 8
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 7
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 7
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 9
// CHECK: store ptr %[[ARR_SECT_PTR]], ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 7
-// CHECK: store i64 %[[ARR_SECT_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8
-
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 8
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 9
+// CHECK: store i64 %[[IS_NULL2]], ptr %[[OFFLOADSIZES]], align 8
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 10
// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 8
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 10
// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 8
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 10
// CHECK: store i64 %[[SCALAR_DESC_SZ]], ptr %[[OFFLOADSIZES]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 11
// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 9
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 11
// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 10
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 12
+// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 12
+// CHECK: store ptr %[[SCALAR_BASE_OFF]], ptr %[[OFFLOADPTRS]], align 8
+// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 12
+// CHECK: store i64 %[[SCALAR_BASE_OFF_SZ4]], ptr %[[OFFLOADSIZES]], align 8
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 13
// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 10
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 13
// CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADPTRS]], align 8
-// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 11
+// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 14
// CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADBASEPTRS]], align 8
-// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 11
+// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 14
// CHECK: store ptr %[[SCALAR_PTR_LOAD]], ptr %[[OFFLOADPTRS]], align 8
diff --git a/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir
new file mode 100644
index 0000000..a232bd7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
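+// Note: the SPIR-V target is expected to use the spir_func calling convention
+// on OpenMP device runtime calls; the checks below confirm this for both the
+// init and deinit entry points.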
+module attributes {omp.is_target_device = true, omp.is_gpu = true, omp.target_triples = ["spirv64-intel"], llvm.target_triple = "spirv64-intel"} {
+// CHECK: call spir_func i32 @__kmpc_target_init
+// CHECK: call spir_func void @__kmpc_target_deinit
+ llvm.func @target_if_variable(%x : i1) {
+ omp.target if(%x) {
+ omp.terminator
+ }
+ llvm.return
+ }
+ }
diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir
index 9aba72d..b7cb102 100644
--- a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir
@@ -59,8 +59,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: call void @__kmpc_barrier
// CHECK: [[THEN]]:
-// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32
// CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32
+// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32
// CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]]
// CHECK-NEXT: store i32 %[[FINAL_RESULT]]
diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir
index dc22fe1..36eb280 100644
--- a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir
@@ -62,8 +62,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: icmp eq i32 %[[MASTER]], 1
// CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]]
// CHECK: [[THEN]]:
-// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32
// CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32
+// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32
// CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]]
// CHECK-NEXT: store i32 %[[FINAL_RESULT]]
diff --git a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
index c4b2456..6585549 100644
--- a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
@@ -29,22 +29,24 @@ llvm.func @test() {
// CHECK: %[[VAL_14:.*]] = icmp eq i32 %[[VAL_13]], 0
// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_11]]
-// CHECK: %[[VAL_17:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
-// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_17]])
-// CHECK: br label %[[VAL_19:.*]]
+// CHECK: br label %[[FINI:.*]]
+// CHECK: .fini:
+// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[CNCL_BARRIER:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[TID]])
+// CHECK: br label %[[EXIT_STUB:.*]]
// CHECK: omp.par.region1.split: ; preds = %[[VAL_11]]
// CHECK: %[[VAL_20:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_21:.*]] = call i32 @__kmpc_cancel_barrier(ptr @3, i32 %[[VAL_20]])
// CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_21]], 0
// CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_24:.*]]
// CHECK: omp.par.region1.split.cncl: ; preds = %[[VAL_15]]
-// CHECK: br label %[[VAL_19]]
+// CHECK: br label %[[FINI]]
// CHECK: omp.par.region1.split.cont: ; preds = %[[VAL_15]]
// CHECK: br label %[[VAL_25:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_23]]
// CHECK: br label %[[VAL_26:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]]
-// CHECK: br label %[[VAL_19]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_26]], %[[VAL_24]], %[[VAL_16]]
+// CHECK: br label %[[FINI]]
+// CHECK: omp.par.exit.exitStub:
// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index 2124170..a6911f8 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -24,16 +24,18 @@ llvm.func @cancel_parallel() {
// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]]
+// CHECK: br label %[[VAL_20:.*]]
+// CHECK: .fini:
// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
-// CHECK: br label %[[VAL_20:.*]]
+// CHECK: br label %[[EXIT_STUB:.*]]
// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]]
// CHECK: br label %[[VAL_21:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_16]]
// CHECK: br label %[[VAL_22:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]]
// CHECK: br label %[[VAL_20]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]]
+// CHECK: omp.par.exit.exitStub:
// CHECK: ret void
llvm.func @cancel_parallel_if(%arg0 : i1) {
@@ -58,27 +60,36 @@ llvm.func @cancel_parallel_if(%arg0 : i1) {
// CHECK: omp.par.region: ; preds = %[[VAL_17]]
// CHECK: br label %[[VAL_20:.*]]
// CHECK: omp.par.region1: ; preds = %[[VAL_19]]
-// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]]
+// CHECK: br i1 %[[VAL_16]], label %[[SPLIT:.*]], label %[[VAL_22:.*]]
// CHECK: 3: ; preds = %[[VAL_20]]
-// CHECK: br label %[[VAL_23:.*]]
-// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]]
+// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[NOT_CANCELLED:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 1)
+// CHECK: %[[COND:.*]] = icmp eq i32 %[[NOT_CANCELLED]], 0
+// CHECK: br i1 %[[COND]], label %[[VAL_23:.*]], label %[[CNCL:.*]]
+// CHECK: .cncl:
+// CHECK: br label %[[FINI:.*]]
+// CHECK: .fini:
+// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
+// CHECK: br label %[[EXIT_STUB:.*]]
+// CHECK: .split:
+// CHECK: br label %[[SEVEN:.*]]
+// CHECK: 7:
// CHECK: br label %[[VAL_25:.*]]
-// CHECK: omp.region.cont: ; preds = %[[VAL_23]]
+// CHECK: omp.region.cont:
// CHECK: br label %[[VAL_26:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]]
// CHECK: br label %[[VAL_27:.*]]
-// CHECK: 5: ; preds = %[[VAL_20]]
+// CHECK: 8: ; preds = %[[VAL_20]]
// CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1)
// CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0
-// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]]
-// CHECK: .cncl: ; preds = %[[VAL_21]]
-// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
-// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
-// CHECK: br label %[[VAL_27]]
-// CHECK: .split: ; preds = %[[VAL_21]]
-// CHECK: br label %[[VAL_23]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]]
+// CHECK: br i1 %[[VAL_30]], label %[[SPLIT5:.*]], label %[[VAL_31:.*]]
+// CHECK: .cncl{{.*}}:
+// CHECK: br label %[[FINI]]
+// CHECK: .split{{.*}}:
+// CHECK: br label %[[SEVEN]]
+// CHECK: omp.par.exit.exitStub:
// CHECK: ret void
llvm.func @cancel_sections_if(%cond : i1) {
@@ -132,11 +143,16 @@ llvm.func @cancel_sections_if(%cond : i1) {
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 3)
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]]
-// CHECK: .split: ; preds = %[[VAL_27]]
+// CHECK: .split{{.*}}: ; preds = %[[VAL_27]]
// CHECK: br label %[[VAL_34:.*]]
// CHECK: 12: ; preds = %[[VAL_25]]
+// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 3)
+// CHECK: %[[COND:.*]] = icmp eq i32 %[[CANCEL_POINT]], 0
+// CHECK: br i1 %[[COND]], label %[[SPLIT:.*]], label %[[CNCL:.*]]
+// CHECK: .split{{.*}}:
// CHECK: br label %[[VAL_34]]
-// CHECK: 13: ; preds = %[[VAL_28]], %[[VAL_32]]
+// CHECK: 15:
// CHECK: br label %[[VAL_35:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_34]]
// CHECK: br label %[[VAL_23]]
@@ -145,17 +161,17 @@ llvm.func @cancel_sections_if(%cond : i1) {
// CHECK: omp_section_loop.inc: ; preds = %[[VAL_23]]
// CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1
// CHECK: br label %[[VAL_12]]
-// CHECK: omp_section_loop.exit: ; preds = %[[VAL_33]], %[[VAL_16]]
+// CHECK: omp_section_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]])
// CHECK: %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]])
// CHECK: br label %[[VAL_37:.*]]
// CHECK: omp_section_loop.after: ; preds = %[[VAL_19]]
-// CHECK: br label %[[VAL_38:.*]]
-// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]]
// CHECK: ret void
-// CHECK: .cncl: ; preds = %[[VAL_27]]
-// CHECK: br label %[[VAL_19]]
+// CHECK: .cncl:
+// CHECK: br label %[[OMP_SECTION_LOOP_EXIT:.*]]
+// CHECK: .cncl{{.*}}:
+// CHECK: br label %[[OMP_SECTION_LOOP_EXIT]]
llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
omp.wsloop {
@@ -221,18 +237,23 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
-// CHECK: .split: ; preds = %[[VAL_44]]
+// CHECK: .split{{.*}}:
// CHECK: br label %[[VAL_51:.*]]
-// CHECK: 28: ; preds = %[[VAL_42]]
+// CHECK: 28:
+// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 2)
+// CHECK: %[[COND:.*]] = icmp eq i32 %[[CANCEL_POINT]], 0
+// CHECK: br i1 %[[COND]], label %[[SPLIT3:.*]], label %[[CNCL4:.*]]
+// CHECK: .split{{.*}}:
// CHECK: br label %[[VAL_51]]
-// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]]
+// CHECK: 31:
// CHECK: br label %[[VAL_52:.*]]
// CHECK: omp.region.cont1: ; preds = %[[VAL_51]]
// CHECK: br label %[[VAL_32]]
// CHECK: omp_loop.inc: ; preds = %[[VAL_52]]
// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
// CHECK: br label %[[VAL_31]]
-// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]]
+// CHECK: omp_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
@@ -241,8 +262,12 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
// CHECK: br label %[[VAL_55:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_54]]
// CHECK: ret void
-// CHECK: .cncl: ; preds = %[[VAL_44]]
-// CHECK: br label %[[VAL_38]]
+// CHECK: .cncl{{.*}}:
+// CHECK: br label %[[FINI:.*]]
+// CHECK: .fini:
+// CHECK: br label %[[OMP_LOOP_EXIT:.*]]
+// CHECK: .cncl{{.*}}:
+// CHECK: br label %[[FINI]]
omp.private {type = firstprivate} @i32_priv : i32 copy {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
index 5e0d3f9..93fa2064 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
@@ -24,16 +24,18 @@ llvm.func @cancellation_point_parallel() {
// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]]
+// CHECK: br label %[[FINI:.*]]
+// CHECK: .fini:
// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
-// CHECK: br label %[[VAL_20:.*]]
+// CHECK: br label %[[EXIT_STUB:.*]]
// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]]
// CHECK: br label %[[VAL_21:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_16]]
// CHECK: br label %[[VAL_22:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]]
-// CHECK: br label %[[VAL_20]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]]
+// CHECK: br label %[[FINI]]
+// CHECK: omp.par.exit.exitStub:
// CHECK: ret void
llvm.func @cancellation_point_sections() {
@@ -94,14 +96,12 @@ llvm.func @cancellation_point_sections() {
// CHECK: omp_section_loop.inc: ; preds = %[[VAL_46]]
// CHECK: %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1
// CHECK: br label %[[VAL_35]]
-// CHECK: omp_section_loop.exit: ; preds = %[[VAL_53]], %[[VAL_39]]
+// CHECK: omp_section_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]])
// CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]])
// CHECK: br label %[[VAL_56:.*]]
// CHECK: omp_section_loop.after: ; preds = %[[VAL_42]]
-// CHECK: br label %[[VAL_57:.*]]
-// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_56]]
// CHECK: ret void
// CHECK: omp.section.region.cncl: ; preds = %[[VAL_48]]
// CHECK: br label %[[VAL_42]]
@@ -175,7 +175,7 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
// CHECK: omp_loop.inc: ; preds = %[[VAL_106]]
// CHECK: %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1
// CHECK: br label %[[VAL_89]]
-// CHECK: omp_loop.exit: ; preds = %[[VAL_105]], %[[VAL_93]]
+// CHECK: omp_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]])
// CHECK: %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]])
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
new file mode 100644
index 0000000..a0dd556
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
@@ -0,0 +1,34 @@
+// Test that dist_schedule is translated to the correct runtime schedule type, with the chunk size passed through where one is given.
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
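+// The schedule-type constants checked below are assumed to follow the OpenMP
+// runtime's sched_type enum (kmp.h): 91 = kmp_distribute_static_chunked and
+// 92 = kmp_distribute_static.
+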
+llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+ // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+ // We want to make sure that the next call is not another init builder.
+ // CHECK-NOT: call void @[[RUNTIME_FUNC]]
+ %1 = llvm.mlir.constant(1024: i32) : i32
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
+// When a chunk size is present, make sure the correct llvm.loop.parallel_accesses metadata is added.
+// CHECK: !2 = !{!"llvm.loop.parallel_accesses", !3}
+// CHECK-NEXT: !3 = distinct !{}
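+// (llvm.loop.parallel_accesses asserts that the accesses in group !3 have no
+// loop-carried dependences, keeping the chunked loop vectorizable.)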
+
+// -----
+
+llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+ // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+ // We want to make sure that the next call is not another init builder.
+ // CHECK-NOT: call void @[[RUNTIME_FUNC]]
+ omp.distribute dist_schedule_static {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ }
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
new file mode 100644
index 0000000..dad32b4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
@@ -0,0 +1,205 @@
+// Test that dist_schedule is translated to the correct schedule type and chunk size when composed with workshare loops.
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
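+// Schedule-type constants are assumed to follow the runtime's sched_type enum
+// (kmp.h): 33 = kmp_sch_static_chunked and 34 = kmp_sch_static for the wsloop;
+// 91 and 92 are the corresponding distribute_static variants.
+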
+llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32, %dcs: i32) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ %scs = llvm.mlir.constant(64 : i32) : i32
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+ omp.wsloop schedule(static = %scs : i32) {
+ omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 %3)
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %dcs = llvm.mlir.constant(1024 : i64) : i64
+ %scs = llvm.mlir.constant(64 : i64) : i64
+ %n64 = llvm.zext %n : i32 to i64
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) {
+ omp.wsloop schedule(static = %scs : i64) {
+ omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024)
+
+// -----
+
+llvm.func @distribute_wsloop_dist_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ %dcs = llvm.mlir.constant(1024 : i32) : i32
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+ omp.wsloop schedule(static) {
+ omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %dcs = llvm.mlir.constant(1024 : i64) : i64
+ %n64 = llvm.zext %n : i32 to i64
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) {
+ omp.wsloop schedule(static) {
+ omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024)
+
+// -----
+
+llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ %scs = llvm.mlir.constant(64 : i32) : i32
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static {
+ omp.wsloop schedule(static = %scs : i32) {
+ omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %scs = llvm.mlir.constant(64 : i64) : i64
+ %n64 = llvm.zext %n : i32 to i64
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static {
+ omp.wsloop schedule(static = %scs : i64) {
+ omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK: define internal void @distribute_wsloop_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0)
+
+// -----
+
+llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static {
+ omp.wsloop schedule(static) {
+ omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
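+// With no chunk size on either construct, the combined
+// __kmpc_dist_for_static_init entry point is used rather than two separate
+// static-init calls.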
+// CHECK: define internal void @distribute_wsloop_no_chunks..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_no_chunks_i64(%n: i32, %teams: i32, %threads: i32) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %n64 = llvm.zext %n : i32 to i64
+
+ omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+ omp.parallel {
+ omp.distribute dist_schedule_static {
+ omp.wsloop schedule(static) {
+ omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_no_chunks_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i64 1, i64 0)
+// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i64 1, i64 0) \ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 8bd33a3..1eb501c 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -328,6 +328,52 @@ llvm.func @test_omp_masked(%arg0: i32)-> () {
// -----
+llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+// CHECK-LABEL: @wsloop_linear
+
+// CHECK: %p.lastiter = alloca i32, align 4
+// CHECK: %p.lowerbound = alloca i32, align 4
+// CHECK: %p.upperbound = alloca i32, align 4
+// CHECK: %p.stride = alloca i32, align 4
+// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4
+// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4
+
+// CHECK: omp_loop.preheader:
+// CHECK: %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4
+// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4
+
+// CHECK: omp_loop.body:
+// CHECK: %[[LOOP_IV_CALC:.*]] = add i32 %omp_loop.iv, {{.*}}
+// CHECK: %[[LINEAR_VAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4
+// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV_CALC]], {{.*}}
+// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_VAR_LOAD]], %[[MUL]]
+// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4
+
+// CHECK: omp_loop.linear_finalization:
+// CHECK: %[[ITER:.*]] = load i32, ptr %p.lastiter, align 4
+// CHECK: %[[CMP:.*]] = icmp ne i32 %[[ITER]], 0
+// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit
+
+// CHECK: omp_loop.linear_lastiter_exit:
+// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
+// CHECK: store i32 %[[LOAD]], ptr {{.*}}, align 4
+// CHECK: br label %omp_loop.linear_exit
+
+// CHECK: omp_loop.linear_exit:
+// CHECK: %[[THREAD_ID:.*]] = call i32 @__kmpc_global_thread_num(ptr {{.*}})
+// CHECK: call void @__kmpc_barrier(ptr {{.*}}, i32 %[[THREAD_ID]])
+// CHECK: br label %omp_loop.after
+
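+// In short: the linear variable is copied in at the preheader, recomputed from
+// the IV on every iteration, copied out on the last iteration, and the threads
+// synchronize on a barrier before leaving the loop.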
+ omp.wsloop linear(%x = %step : !llvm.ptr) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ } {linear_var_types = [i32]}
+ llvm.return
+}
+
+// -----
+
// CHECK: %struct.ident_t = type
// CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00"
// CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}}
@@ -695,6 +741,34 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) {
// -----
+llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+
+// CHECK-LABEL: @simd_linear
+
+// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4
+// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4
+
+// CHECK: omp_loop.preheader:
+// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}, align 4
+// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4
+
+// CHECK: omp_loop.body:
+// CHECK: %[[LOOP_IV_CALC:.*]] = mul i32 %omp_loop.iv, {{.*}}
+// CHECK: %[[ADD:.*]] = add i32 %[[LOOP_IV_CALC]], {{.*}}
+// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4, !llvm.access.group !1
+// CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, {{.*}}
+// CHECK: %[[ADD:.*]] = add i32 %[[LOAD]], %[[MUL]]
+// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4, !llvm.access.group !1
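+
+// Unlike the wsloop case above, no copy-out or barrier is checked here; the
+// accesses are tagged with the loop's access group (!1) so the simd loop
+// remains vectorizable.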
+ omp.simd linear(%x = %step : !llvm.ptr) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ } {linear_var_types = [i32]}
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @simd_simple_multiple
llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.simd {
diff --git a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir
index faccfc6..99f37c7 100644
--- a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir
@@ -21,9 +21,11 @@ llvm.func @parallel_infinite_loop() -> () {
// CHECK: omp.region.cont: ; No predecessors!
// CHECK: br label %[[VAL_4:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_5:.*]]
-// CHECK: br label %[[VAL_6:.*]]
-// CHECK: omp.par.exit: ; preds = %[[VAL_4]]
+// CHECK: br label %[[FINI:.*]]
+// CHECK: [[OMP_PAR_EXIT:omp.par.exit]]: ; preds = %[[FINI]]
// CHECK: ret void
+// CHECK: [[FINI]]:
+// CHECK: br label %[[OMP_PAR_EXIT]]
// CHECK: }
// CHECK-LABEL: define internal void @parallel_infinite_loop..omp_par(
diff --git a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir
index 887d297..c79c369 100644
--- a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir
@@ -108,6 +108,8 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: !
// CHECK: reduce.finalize: ; preds = %[[VAL_49]], %[[VAL_43]]
// CHECK: br label %[[VAL_53:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_48]]
+// CHECK: br label %[[FINI:.*]]
+// CHECK: .fini:
// CHECK: %[[VAL_54:.*]] = load ptr, ptr %[[VAL_20]], align 8
// CHECK: %[[VAL_55:.*]] = load ptr, ptr %[[VAL_21]], align 8
// CHECK: br label %[[VAL_56:.*]]
@@ -115,5 +117,5 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: !
// CHECK: br label %[[VAL_38]]
// CHECK: omp.reduction.neutral1: ; preds = %[[VAL_25]]
// CHECK: br label %[[VAL_30]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_53]]
+// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]]
// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir
index b302b4b..13f52f0 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir
@@ -127,8 +127,6 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]])
// CHECK: br label %[[VAL_37:.*]]
// CHECK: omp_section_loop.after: ; preds = %[[VAL_35]]
-// CHECK: br label %[[VAL_38:.*]]
-// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]]
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_14]], i64 0, i64 0
// CHECK: store ptr %[[VAL_21]], ptr %[[VAL_39]], align 8
// CHECK: %[[VAL_40:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
@@ -137,9 +135,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute
// CHECK: i32 1, label %[[VAL_43:.*]]
// CHECK: i32 2, label %[[VAL_44:.*]]
// CHECK: ]
-// CHECK: reduce.switch.atomic: ; preds = %[[VAL_38]]
+// CHECK: reduce.switch.atomic: ; preds = %[[VAL_37]]
// CHECK: unreachable
-// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_38]]
+// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_37]]
// CHECK: %[[VAL_45:.*]] = load ptr, ptr %[[VAL_21]], align 8
// CHECK: br label %[[VAL_46:.*]]
// CHECK: omp.reduction.nonatomic.body: ; preds = %[[VAL_43]]
@@ -157,7 +155,7 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute
// CHECK: omp.reduction.nonatomic.body17: ; preds = %[[VAL_47]]
// CHECK: %[[VAL_50]] = sub i64 %[[VAL_49]], 1
// CHECK: br label %[[VAL_47]]
-// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_38]]
+// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_37]]
// CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]])
// CHECK: %[[VAL_56:.*]] = load ptr, ptr %[[VAL_21]], align 8
@@ -173,7 +171,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute
// CHECK: omp.region.cont: ; preds = %[[VAL_62]]
// CHECK: br label %[[VAL_64:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_63]]
-// CHECK: br label %[[VAL_65:.*]]
+// CHECK: br label %[[FINI:.fini.*]]
+// CHECK: [[FINI]]:
+// CHECK: br label %[[EXIT:.*]]
// CHECK: omp.reduction.cleanup21: ; preds = %[[VAL_57]]
// CHECK: br label %[[VAL_61]]
// CHECK: omp_section_loop.body: ; preds = %[[VAL_32]]
@@ -219,5 +219,5 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute
// CHECK: omp_section_loop.inc: ; preds = %[[VAL_69]]
// CHECK: %[[VAL_31]] = add nuw i32 %[[VAL_30]], 1
// CHECK: br label %[[VAL_28]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_64]]
+// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]]
// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir
index a714ca6..cb30d3b 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir
@@ -96,8 +96,10 @@ module {
// CHECK: reduce.finalize: ; preds = %[[VAL_34]], %[[VAL_28]]
// CHECK: br label %[[VAL_38:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_33]]
+// CHECK: br label %[[FINI:.*]]
+// CHECK: [[FINI]]:
// CHECK: br label %[[VAL_39:.*]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_38]]
+// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]]
// CHECK: ret void
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_41:.*]], i64 0, i64 0
// CHECK: %[[VAL_42:.*]] = load ptr, ptr %[[VAL_40]], align 8
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir
index 19da6f8..00f6c1b 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir
@@ -86,8 +86,6 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_40]])
// CHECK: br label %[[VAL_41:.*]]
// CHECK: omp_section_loop.after: ; preds = %[[VAL_39]]
-// CHECK: br label %[[VAL_42:.*]]
-// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_41]]
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_21]], i64 0, i64 0
// CHECK: store ptr %[[VAL_20]], ptr %[[VAL_43]], align 8
// CHECK: %[[VAL_44:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
@@ -96,23 +94,25 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in
// CHECK: i32 1, label %[[VAL_47:.*]]
// CHECK: i32 2, label %[[VAL_48:.*]]
// CHECK: ]
-// CHECK: reduce.switch.atomic: ; preds = %[[VAL_42]]
+// CHECK: reduce.switch.atomic: ; preds = %[[VAL_41]]
// CHECK: unreachable
-// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_42]]
+// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_41]]
// CHECK: %[[VAL_49:.*]] = load float, ptr %[[VAL_11]], align 4
// CHECK: %[[VAL_50:.*]] = load float, ptr %[[VAL_20]], align 4
// CHECK: %[[VAL_51:.*]] = fadd contract float %[[VAL_49]], %[[VAL_50]]
// CHECK: store float %[[VAL_51]], ptr %[[VAL_11]], align 4
// CHECK: call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_44]], ptr @.gomp_critical_user_.reduction.var)
// CHECK: br label %[[VAL_46]]
-// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_42]]
+// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_41]]
// CHECK: %[[VAL_52:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_52]])
// CHECK: br label %[[VAL_53:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_46]]
// CHECK: br label %[[VAL_54:.*]]
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_53]]
-// CHECK: br label %[[VAL_55:.*]]
+// CHECK: br label %[[FINI:.fini.*]]
+// CHECK: [[FINI]]:
+// CHECK: br label %[[EXIT:.*]]
// CHECK: omp_section_loop.body: ; preds = %[[VAL_36]]
// CHECK: %[[VAL_56:.*]] = add i32 %[[VAL_34]], %[[VAL_28]]
// CHECK: %[[VAL_57:.*]] = mul i32 %[[VAL_56]], 1
@@ -144,8 +144,10 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in
// CHECK: omp_section_loop.inc: ; preds = %[[VAL_59]]
// CHECK: %[[VAL_35]] = add nuw i32 %[[VAL_34]], 1
// CHECK: br label %[[VAL_32]]
-// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_54]]
+// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]]
// CHECK: ret void
+
+// CHECK-LABEL: define internal void @.omp.reduction.func
// CHECK: %[[VAL_70:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_71:.*]], i64 0, i64 0
// CHECK: %[[VAL_72:.*]] = load ptr, ptr %[[VAL_70]], align 8
// CHECK: %[[VAL_73:.*]] = load float, ptr %[[VAL_72]], align 4
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
index 504d91b..5c37817 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
-// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}})
// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
// DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
index 20202fc..dae80ba 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// DEVICE: call void @__kmpc_target_deinit()
// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
-// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}})
// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
// DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index af6d254..396c57a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr
// -----
-llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.distribute}}
- omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- }
- llvm.return
-}
-
-// -----
-
llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) {
// expected-error@below {{not yet implemented: Unhandled clause order in omp.distribute operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.distribute}}
@@ -116,19 +103,6 @@ llvm.func @sections_private(%x : !llvm.ptr) {
// -----
-llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
- // expected-error@below {{not yet implemented: Unhandled clause linear in omp.simd operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.simd}}
- omp.simd linear(%x = %step : !llvm.ptr) {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
@@ -238,17 +212,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
// -----
-llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
- // expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
- omp.target is_device_ptr(%x : !llvm.ptr) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}
@@ -448,19 +411,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
}
// -----
-
-llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
- // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
- omp.wsloop linear(%x = %step : !llvm.ptr) {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- }
- llvm.return
-}
-
-// -----
llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) {
// expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8a848221..2c748ad5 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -14,30 +14,36 @@ llvm.func @rocdl_special_regs() -> i32 {
%5 = rocdl.workgroup.id.y : i32
// CHECK: call i32 @llvm.amdgcn.workgroup.id.z()
%6 = rocdl.workgroup.id.z : i32
+ // CHECK: call i32 @llvm.amdgcn.cluster.id.x()
+ %7 = rocdl.cluster.id.x : i32
+ // CHECK: call i32 @llvm.amdgcn.cluster.id.y()
+ %8 = rocdl.cluster.id.y : i32
+ // CHECK: call i32 @llvm.amdgcn.cluster.id.z()
+ %9 = rocdl.cluster.id.z : i32
// CHECK: call i64 @__ockl_get_local_size(i32 0)
- %7 = rocdl.workgroup.dim.x : i64
+ %10 = rocdl.workgroup.dim.x : i64
// CHECK: call i64 @__ockl_get_local_size(i32 1)
- %8 = rocdl.workgroup.dim.y : i64
+ %11 = rocdl.workgroup.dim.y : i64
// CHECK: call i64 @__ockl_get_local_size(i32 2)
- %9 = rocdl.workgroup.dim.z : i64
+ %12 = rocdl.workgroup.dim.z : i64
// CHECK: call i64 @__ockl_get_num_groups(i32 0)
- %10 = rocdl.grid.dim.x : i64
+ %13 = rocdl.grid.dim.x : i64
// CHECK: call i64 @__ockl_get_num_groups(i32 1)
- %11 = rocdl.grid.dim.y : i64
+ %14 = rocdl.grid.dim.y : i64
// CHECK: call i64 @__ockl_get_num_groups(i32 2)
- %12 = rocdl.grid.dim.z : i64
+ %15 = rocdl.grid.dim.z : i64
// CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x()
- %13 = rocdl.workitem.id.x range <i32, 0, 64> : i32
+ %16 = rocdl.workitem.id.x range <i32, 0, 64> : i32
// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
- %14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
+ %17 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
// CHECK: call i32 @llvm.amdgcn.wavefrontsize()
- %15 = rocdl.wavefrontsize : i32
+ %18 = rocdl.wavefrontsize : i32
// CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize()
- %16 = rocdl.wavefrontsize range <i32, 32, 65> : i32
+ %19 = rocdl.wavefrontsize range <i32, 32, 65> : i32
llvm.return %1 : i32
}
@@ -55,6 +61,59 @@ llvm.func @kernel_func_workgroups()
llvm.return
}
+llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
+ // CHECK-LABEL: kernel_math_ops
+ // CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}})
+ %tanh0 = rocdl.tanh %a f32 -> f32
+ %tanh1 = rocdl.tanh %b f16 -> f16
+ %tanh2 = rocdl.tanh %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}})
+ %sin0 = rocdl.sin %a f32 -> f32
+ %sin1 = rocdl.sin %b f16 -> f16
+ %sin2 = rocdl.sin %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}})
+ %cos0 = rocdl.cos %a f32 -> f32
+ %cos1 = rocdl.cos %b f16 -> f16
+ %cos2 = rocdl.cos %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}})
+ %rcp0 = rocdl.rcp %a f32 -> f32
+ %rcp1 = rocdl.rcp %b f16 -> f16
+ %rcp2 = rocdl.rcp %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}})
+ %exp2_0 = rocdl.exp2 %a f32 -> f32
+ %exp2_1 = rocdl.exp2 %b f16 -> f16
+ %exp2_2 = rocdl.exp2 %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}})
+ %log0 = rocdl.log %a f32 -> f32
+ %log1 = rocdl.log %b f16 -> f16
+ %log2 = rocdl.log %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}})
+ %sqrt0 = rocdl.sqrt %a f32 -> f32
+ %sqrt1 = rocdl.sqrt %b f16 -> f16
+ %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+ llvm.return
+}
+
llvm.func @known_block_sizes()
attributes {rocdl.kernel,
rocdl.flat_work_group_size = "128,128",
@@ -248,6 +307,13 @@ llvm.func @rocdl.s.get.barrier.state() {
llvm.return
}
+llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: rocdl.s.get.named.barrier.state
+ // CHECK: %[[STATE:.+]] = call i32 @llvm.amdgcn.s.get.named.barrier.state(ptr addrspace(3) %[[PTR:.+]])
+ %0 = rocdl.s.get.named.barrier.state %ptr : i32
+ llvm.return
+}
+
llvm.func @rocdl.s.wait.dscnt() {
// CHECK-LABEL: rocdl.s.wait.dscnt
// CHECK-NEXT: call void @llvm.amdgcn.s.wait.dscnt(i16 0)
@@ -875,140 +941,182 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
%arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
%arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>, %arg13 : vector<64xf32>,
%arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> {
- %zero = llvm.mlir.constant(false) : i1
- %zero_i16 = llvm.mlir.constant(0 : i16) : i16
- // ---- Wave32 -----
+ // ---- Wave32 -----
// f16 -> f32
- // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}})
%r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
// bf16 -> f32
- // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x float> %{{.*}})
%r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
// f16 -> f16 (OPSEL = {0,1})
- // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
- %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+ // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <16 x half> %{{.*}} i1 false)
+ %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
// bf16 -> bf16 (OPSEL = {0,1})
- // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
- %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+ // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <16 x i16> %{{.*}} i1 false)
+ %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
// int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
- // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
- %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r5 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
// int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
- // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
- %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r6 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
// int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
- // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
- %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=true, signB=false for iu8
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r5a = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=false, signB=true for iu8
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r5b = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = true, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=true, signB=true, clamp=true for iu8
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true)
+ %r5c = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = true, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=true, signB=false for iu4
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r6a = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = true, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=false, signB=true, clamp=true for iu4
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true)
+ %r6b = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = true, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
+ // Test signA=true, signB=true for iu4 gfx12
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+ %r6c = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = true, signB = true, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
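+
+ // As the CHECK lines show, signA/signB/clamp map directly to the three i1
+ // immediates of the iu8/iu4 intrinsics; signA/signB mark the packed A/B
+ // operands as signed, and clamp requests saturating i32 accumulation.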
// f32 -> f32
- // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false)
+ %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
// f16 -> f32
- // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+ %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32>
// bf16 -> f32
- // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+ %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32>
// f16 -> f16
- // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+ // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x half> %{{.*}} i1 false, i1 false)
+ %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %arg1, %arg1, %arg9 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf16>) -> vector<32xf16>
// bf16 -> bf16
- // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16>
+ // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x bfloat> %{{.*}} i1 false, i1 false)
+ %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %arg16, %arg16, %arg17 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xbf16>) -> vector<32xbf16>
// bf16 -> bf16 / f32
- // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16>
+ // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+ %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xbf16>
// f8/bf8 -> f16/f32
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+ %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
- // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+ %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
// iu8 -> i32
- // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}})
- %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
+ %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
+
+ // Test signA=true, signB=true for iu8 gfx1250
+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
+ %r23a.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
+
+ // Test signA=true, signB=false, reuseA=true, reuseB=true for iu8 gfx1250
+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 true, i1 true)
+ %r23b.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = false, reuseA = true, reuseB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
+
+ // Test signA=true, signB=true with modC=1 for f32 gfx1250
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 true, <16 x float> %{{.*}} i1 true, <16 x float> %{{.*}} i16 1, <4 x float> %{{.*}} i1 false, i1 false)
+ %r1a.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = true, signB = true, modC = 1 : i16, reuseA = false, reuseB = false} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
+
+ // Test with modC=2 and signA=false, signB=true, reuseA=true for f16 gfx1250
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 true, <16 x half> %{{.*}} i16 2, <32 x float> %{{.*}} i1 true, i1 false)
+ %r2a.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = true, modC = 2 : i16, reuseA = true, reuseB = false} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32>
+
+ // Test with modC=3 and signA=true, signB=true, reuseB=true for bf16 gfx1250
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 true, <16 x bfloat> %{{.*}} i1 true, <16 x bfloat> %{{.*}} i16 3, <32 x float> %{{.*}} i1 false, i1 true)
+ %r3a.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = true, signB = true, modC = 3 : i16, reuseA = false, reuseB = true} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32>
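+
+ // modC is a modifier on the C operand; judging from the wmma.scale tests
+ // later in this file, the encoding is 0 = none, 1 = negate, 2 = abs,
+ // 3 = negate(abs).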
// ---- Wave64 -----
// f16 -> f32
- // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <4 x float> %{{.*}})
%r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
// bf16 -> f32
- // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <4 x float> %{{.*}})
%r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
// f16 -> f16 (OPSEL = {0,1})
- // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
- %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+ // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x half> %{{.*}} i1 false)
+ %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
// bf16 -> bf16 (OPSEL = {0,1})
- // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
- %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+ // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x i16> %{{.*}} i1 false)
+ %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
// int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
- // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
- %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+ %r12 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg5 {signA = false, signB = false, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
// int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
- // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
- %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+ %r13 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg5 {signA = false, signB = false, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<4xi32>) -> vector<4xi32>
llvm.return %r0 : vector<8xf32>
}
@@ -1028,6 +1136,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+ // CHECK-LABEL: rocdl.load.tr.ops
+ // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]])
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]])
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]])
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]])
+
+ rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
+ rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
+ rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
+ rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>
+
+ rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
+ rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
+ rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
+ rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
+ llvm.return
+}
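+
+// Note: the *.load.tr* family performs transposed loads of matrix operands;
+// the b128 op is deliberately exercised with i16, f16, and bf16 result
+// vectors to cover each 16-bit element type it accepts.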
+
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
//CHECK: call void @llvm.amdgcn.load.to.lds.p7
rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>
@@ -1040,6 +1181,32 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: rocdl.global.load.async.to.lds
+llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b8
+ rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b32
+ rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b64
+ rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b128
+ rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.cluster.load.async.to.lds
+llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b8
+ rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b32
+ rocdl.cluster.load.async.to.lds.b32 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b64
+ rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b128
+ rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
+ llvm.return
+}
+
// CHECK-LABEL: rocdl.tensor.load.to.lds
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
@@ -1174,6 +1341,113 @@ llvm.func @rocdl.raw.ptr.buffer.load.lds(%rsrc : !llvm.ptr<8>, %dstLds : !llvm.p
llvm.return
}
+llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>,
+ %arg3: vector<12xi32>, %arg5: vector<16xi32>,
+ %arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> {
+ // CHECK-LABEL: rocdl.wmma.scale
+
+ // Test with default attributes (all zeros/false)
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
+ %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+ {fmtA = 0 : i32, fmtB = 0 : i32, modC = 0 : i16,
+ scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+ scaleBType = 0 : i32, fmtScaleB = 0 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
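+ // fmtA/fmtB select the element format of the packed A/B operands; the cases
+ // below suggest the encoding 0 = FP8, 1 = BF8, 2 = FP6, 3 = BF6, 4 = FP4,
+ // with the 6- and 4-bit formats carried in the narrower <12 x i32> and
+ // <8 x i32> vectors.
+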
+ // Test with different matrix formats (FP8 x BF8)
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
+ %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+ {fmtA = 0 : i32, fmtB = 1 : i32, modC = 0 : i16,
+ scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+ scaleBType = 1 : i32, fmtScaleB = 1 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate)
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false)
+ %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
+ {fmtA = 0 : i32, fmtB = 2 : i32, modC = 1 : i16,
+ scaleAType = 2 : i32, fmtScaleA = 2 : i32,
+ scaleBType = 2 : i32, fmtScaleB = 2 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with BF8 x BF6 and modC = 2 (abs)
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
+ %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
+ {fmtA = 1 : i32, fmtB = 3 : i32, modC = 2 : i16,
+ scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+ scaleBType = 0 : i32, fmtScaleB = 0 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with FP8 x FP4 and modC = 3 (negate(abs))
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false)
+ %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0
+ {fmtA = 0 : i32, fmtB = 4 : i32, modC = 3 : i16,
+ scaleAType = 3 : i32, fmtScaleA = 3 : i32,
+ scaleBType = 3 : i32, fmtScaleB = 3 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with reuseA = true
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false)
+ %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+ {fmtA = 2 : i32, fmtB = 2 : i32, modC = 0 : i16,
+ scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+ scaleBType = 0 : i32, fmtScaleB = 0 : i32,
+ reuseA = true, reuseB = false} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with reuseB = true
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true)
+ %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+ {fmtA = 3 : i32, fmtB = 3 : i32, modC = 0 : i16,
+ scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+ scaleBType = 0 : i32, fmtScaleB = 0 : i32,
+ reuseA = false, reuseB = true} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test with both reuseA and reuseB = true
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true)
+ %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+ {fmtA = 4 : i32, fmtB = 4 : i32, modC = 1 : i16,
+ scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+ scaleBType = 1 : i32, fmtScaleB = 1 : i32,
+ reuseA = true, reuseB = true} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+
+ // Test scale16 variant with i64 scale exponents
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false)
+ %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8
+ {fmtA = 0 : i32, fmtB = 1 : i32, modC = 2 : i16,
+ scaleAType = 2 : i32, fmtScaleA = 2 : i32,
+ scaleBType = 2 : i32, fmtScaleB = 2 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+
+ // Test f4 variant (no matrix format parameters)
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
+ %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0
+ {modC = 0 : i16,
+ scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+ scaleBType = 1 : i32, fmtScaleB = 1 : i32,
+ reuseA = false, reuseB = false} :
+ (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+
+ // Test f4 scale16 variant with varied attributes
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true)
+ %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8
+ {modC = 3 : i16,
+ scaleAType = 2 : i32, fmtScaleA = 3 : i32,
+ scaleBType = 3 : i32, fmtScaleB = 2 : i32,
+ reuseA = true, reuseB = true} :
+ (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+
+ llvm.return %r00 : vector<4xf32>
+}
+
llvm.func @rocdl.raw.ptr.buffer.atomic.f32(%rsrc : !llvm.ptr<8>,
%offset : i32, %soffset : i32,
%vdata1 : f32) {
diff --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir
index 6b2d2ea..cee6301 100644
--- a/mlir/test/Target/LLVMIR/target-ext-type.mlir
+++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir
@@ -6,6 +6,12 @@ llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv
llvm.return %0 : !llvm.target<"spirv.DeviceEvent">
}
+// CHECK: @amdgcn_named_barrier = internal addrspace(3) global target("amdgcn.named.barrier", 0) poison
+llvm.mlir.global internal @amdgcn_named_barrier() {addr_space = 3 : i32} : !llvm.target<"amdgcn.named.barrier", 0> {
+ %0 = llvm.mlir.poison : !llvm.target<"amdgcn.named.barrier", 0>
+ llvm.return %0 : !llvm.target<"amdgcn.named.barrier", 0>
+}
+
// CHECK-LABEL: define target("spirv.Event") @func2() {
// CHECK-NEXT: %1 = alloca target("spirv.Event"), align 8
// CHECK-NEXT: %2 = load target("spirv.Event"), ptr %1, align 8
diff --git a/mlir/test/Target/SPIRV/consecutive-selection.spv b/mlir/test/Target/SPIRV/consecutive-selection.spvasm
index 3752058..3752058 100644
--- a/mlir/test/Target/SPIRV/consecutive-selection.spv
+++ b/mlir/test/Target/SPIRV/consecutive-selection.spvasm
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 712fd17..29b5d4f 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -78,6 +78,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ // CHECK: coherent
+ spirv.GlobalVariable @var {coherent} : !spirv.ptr<vector<2xf32>, Output>
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outSideGlobalVar1", linkage_type = <Import>>
spirv.GlobalVariable @var1 {
linkage_attributes=#spirv.linkage_attributes<
diff --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir
index cf519cb..6f19b35 100644
--- a/mlir/test/Target/SPIRV/group-ops.mlir
+++ b/mlir/test/Target/SPIRV/group-ops.mlir
@@ -1,11 +1,13 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s
// RUN: %if spirv-tools %{ rm -rf %t %}
// RUN: %if spirv-tools %{ mkdir %t %}
// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
// RUN: %if spirv-tools %{ spirv-val %t %}
-spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, SubgroupBallotKHR, Groups, SubgroupBufferBlockIOINTEL, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_shader_ballot, SPV_INTEL_subgroups, SPV_KHR_uniform_group_instructions]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.3,
+ [Shader, Linkage, SubgroupBallotKHR, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_shader_ballot, SPV_KHR_uniform_group_instructions]> {
// CHECK-LABEL: @subgroup_ballot
spirv.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> "None" {
// CHECK: %{{.*}} = spirv.KHR.SubgroupBallot %{{.*}}: vector<4xi32>
@@ -24,30 +26,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, Subgrou
%0 = spirv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
spirv.ReturnValue %0: f32
}
- // CHECK-LABEL: @subgroup_block_read_intel
- spirv.func @subgroup_block_read_intel(%ptr : !spirv.ptr<i32, StorageBuffer>) -> i32 "None" {
- // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> i32
- %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> i32
- spirv.ReturnValue %0: i32
- }
- // CHECK-LABEL: @subgroup_block_read_intel_vector
- spirv.func @subgroup_block_read_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
- // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32>
- %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32>
- spirv.ReturnValue %0: vector<3xi32>
- }
- // CHECK-LABEL: @subgroup_block_write_intel
- spirv.func @subgroup_block_write_intel(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
- // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32
- spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32
- spirv.Return
- }
- // CHECK-LABEL: @subgroup_block_write_intel_vector
- spirv.func @subgroup_block_write_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
- // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32>
- spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32>
- spirv.Return
- }
// CHECK-LABEL: @group_iadd
spirv.func @group_iadd(%value: i32) -> i32 "None" {
// CHECK: spirv.GroupIAdd <Workgroup> <Reduce> %{{.*}} : i32
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index 95b87b3..b9a4295 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
// Single loop
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
@@ -62,7 +67,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// Single loop with block arguments
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
spirv.GlobalVariable @GV1 bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @GV2 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<10 x f32, stride=4> [0])>, StorageBuffer>
// CHECK-LABEL: @loop_kernel
diff --git a/mlir/test/Target/SPIRV/mlir-translate.mlir b/mlir/test/Target/SPIRV/mlir-translate.mlir
index cbce351..b1966fe 100644
--- a/mlir/test/Target/SPIRV/mlir-translate.mlir
+++ b/mlir/test/Target/SPIRV/mlir-translate.mlir
@@ -1,7 +1,6 @@
// Check that `--spirv-save-validation-files-with-prefix` generates
// a correct number of files.
-// REQUIRES: shell
// RUN: rm -rf %t
// RUN: mkdir %t && mlir-translate --serialize-spirv --no-implicit-module \
// RUN: --split-input-file --spirv-save-validation-files-with-prefix=%t/foo %s \
diff --git a/mlir/test/Target/SPIRV/module.mlir b/mlir/test/Target/SPIRV/module.mlir
index 7e52e54..fb4d9bc 100644
--- a/mlir/test/Target/SPIRV/module.mlir
+++ b/mlir/test/Target/SPIRV/module.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s
-// REQUIRES: shell
// RUN: %if spirv-tools %{ rm -rf %t %}
// RUN: %if spirv-tools %{ mkdir %t %}
// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir
index ca635a4..92a3387 100644
--- a/mlir/test/Target/SPIRV/phi.mlir
+++ b/mlir/test/Target/SPIRV/phi.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
// Test branch with one block argument
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
@@ -295,15 +300,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%true = spirv.Constant true
%zero = spirv.Constant 0 : i32
%one = spirv.Constant 1 : i32
+ spirv.mlir.selection {
// CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]]
- spirv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
+ spirv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32)
- ^true1(%arg0: i32, %arg1: i32):
- spirv.Return
+ ^true1(%arg0: i32, %arg1: i32):
+ spirv.Return
// CHECK: [[false1]]:
- ^false1:
+ ^false1:
+ spirv.Return
+ ^merge:
+ spirv.mlir.merge
+ }
+
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
spirv.Return
}
+ spirv.EntryPoint "GLCompute" @main
}
// -----
@@ -314,15 +330,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%true = spirv.Constant true
%zero = spirv.Constant 0 : i32
%one = spirv.Constant 1 : i32
+ spirv.mlir.selection {
// CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
- spirv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
+ spirv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
// CHECK: [[true1]]:
- ^true1:
- spirv.Return
+ ^true1:
+ spirv.Return
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
- ^false1(%arg0: i32, %arg1: i32):
+ ^false1(%arg0: i32, %arg1: i32):
+ spirv.Return
+ ^merge:
+ spirv.mlir.merge
+ }
+
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
spirv.Return
}
+ spirv.EntryPoint "GLCompute" @main
}
// -----
@@ -333,13 +360,24 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%true = spirv.Constant true
%zero = spirv.Constant 0 : i32
%one = spirv.Constant 1 : i32
+ spirv.mlir.selection {
// CHECK: spirv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
- spirv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
+ spirv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
// CHECK: [[true1]](%{{.*}}: i32):
- ^true1(%arg0: i32):
- spirv.Return
+ ^true1(%arg0: i32):
+ spirv.Return
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
- ^false1(%arg1: i32, %arg2: i32):
+ ^false1(%arg1: i32, %arg2: i32):
+ spirv.Return
+ ^merge:
+ spirv.mlir.merge
+ }
+
spirv.Return
}
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
}
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 44625cc..d0ad118 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
// Selection with both then and else branches
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
@@ -136,19 +141,31 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK-NEXT: spirv.Load "Function" %[[VAR]]
%cond = spirv.Load "Function" %var : i1
+ spirv.mlir.selection {
// CHECK: spirv.BranchConditional %1, ^[[THEN1:.+]](%{{.+}} : i32), ^[[ELSE1:.+]](%{{.+}}, %{{.+}} : i32, i32)
- spirv.BranchConditional %cond, ^then1(%one: i32), ^else1(%zero, %zero: i32, i32)
+ spirv.BranchConditional %cond, ^then1(%one: i32), ^else1(%zero, %zero: i32, i32)
// CHECK-NEXT: ^[[THEN1]](%{{.+}}: i32):
// CHECK-NEXT: spirv.Return
- ^then1(%arg0: i32):
- spirv.Return
+ ^then1(%arg0: i32):
+ spirv.Return
// CHECK-NEXT: ^[[ELSE1]](%{{.+}}: i32, %{{.+}}: i32):
// CHECK-NEXT: spirv.Return
- ^else1(%arg1: i32, %arg2: i32):
+ ^else1(%arg1: i32, %arg2: i32):
+ spirv.Return
+ ^merge:
+ spirv.mlir.merge
+ }
+
spirv.Return
}
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
}
// -----
@@ -203,3 +220,129 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.EntryPoint "GLCompute" @main
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
}
+
+// -----
+
+// Selection with switch
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_switch
+ spirv.func @selection_switch(%selector: i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %two = spirv.Constant 2: i32
+ %three = spirv.Constant 3: i32
+ %four = spirv.Constant 4: i32
+// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr<i32, Function>
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+// CHECK: spirv.mlir.selection {
+ spirv.mlir.selection {
+// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]],
+// CHECK-NEXT: 0: ^[[CASE0:.+]],
+// CHECK-NEXT: 1: ^[[CASE1:.+]],
+// CHECK-NEXT: 2: ^[[CASE2:.+]]
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1,
+ 2: ^case2
+ ]
+// CHECK: ^[[DEFAULT]]
+ ^default:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %one : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE0]]
+ ^case0:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %two : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE1]]
+ ^case1:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %three : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE2]]
+ ^case2:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %four : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[MERGE]]
+ ^merge:
+// CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+// CHECK-NEXT: }
+ }
+// CHECK-NEXT: spirv.Return
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
+
+// -----
+
+// Selection with switch and block operands
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
+// CHECK-LABEL: @selection_switch_operands
+ spirv.func @selection_switch_operands(%selector : si32) "None" {
+ %cst1 = spirv.Constant 1.000000e+00 : f32
+ %vec0 = spirv.Undef : vector<3xf32>
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+ %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32>
+ spirv.Branch ^bb1
+ ^bb1:
+// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+ %vec4 = spirv.mlir.selection -> vector<3xf32> {
+// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>)
+ spirv.Switch %selector : si32, [
+ default: ^bb3(%vec1 : vector<3xf32>),
+ 0: ^bb1(%vec1 : vector<3xf32>),
+ 1: ^bb2(%vec1 : vector<3xf32>)
+ ]
+// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>)
+ ^bb1(%vecbb1: vector<3xf32>):
+ %cst3 = spirv.Constant 3.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+ %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+ spirv.Branch ^bb3(%vec2 : vector<3xf32>)
+// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>)
+ ^bb2(%vecbb2: vector<3xf32>):
+ %cst4 = spirv.Constant 4.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+ %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+ spirv.Branch ^bb3(%vec3 : vector<3xf32>)
+// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+ ^bb3(%vecbb3: vector<3xf32>):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
+ spirv.mlir.merge %vecbb3 : vector<3xf32>
+// CHECK-NEXT: }
+ }
+ %cst2 = spirv.Constant 2.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32>
+ %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32>
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spvasm
index 9642d0a..9642d0a 100644
--- a/mlir/test/Target/SPIRV/selection.spv
+++ b/mlir/test/Target/SPIRV/selection.spvasm
diff --git a/mlir/test/Target/SPIRV/selection_switch.spvasm b/mlir/test/Target/SPIRV/selection_switch.spvasm
new file mode 100644
index 0000000..81fecf3
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection_switch.spvasm
@@ -0,0 +1,69 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; This test is analogous to selection.spvasm but tests the switch op.
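+; The OpPhi nodes in the case and merge blocks are expected to deserialize
+; into block arguments on the corresponding spirv.Switch successor blocks.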
+
+; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" {
+; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32
+; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32>
+; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+; CHECK-NEXT: spirv.Branch ^[[BB:.+]]
+; CHECK-NEXT: ^[[BB]]:
+; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
+; CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}}: vector<3xf32>),
+; CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}}: vector<3xf32>),
+; CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}}: vector<3xf32>)
+; CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>):
+; CHECK: spirv.Branch ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+; CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>):
+; CHECK: spirv.Branch ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+; CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>):
+; CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
+; CHECK-NEXT: }
+; CHECK: spirv.Return
+; CHECK-NEXT: }
+; CHECK: }
+
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %switch "switch"
+ OpName %main "main"
+ %void = OpTypeVoid
+ %int = OpTypeInt 32 1
+ %1 = OpTypeFunction %void %int
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %v3float = OpTypeVector %float 3
+ %9 = OpUndef %v3float
+ %float_3 = OpConstant %float 3
+ %float_4 = OpConstant %float 4
+ %float_2 = OpConstant %float 2
+ %25 = OpTypeFunction %void
+ %switch = OpFunction %void None %1
+ %5 = OpFunctionParameter %int
+ %6 = OpLabel
+ OpBranch %12
+ %12 = OpLabel
+ %11 = OpCompositeInsert %v3float %float_1 %9 0
+ OpSelectionMerge %15 None
+ OpSwitch %5 %15 0 %13 1 %14
+ %13 = OpLabel
+ %16 = OpPhi %v3float %11 %12
+ %18 = OpCompositeInsert %v3float %float_3 %16 1
+ OpBranch %15
+ %14 = OpLabel
+ %19 = OpPhi %v3float %11 %12
+ %21 = OpCompositeInsert %v3float %float_4 %19 1
+ OpBranch %15
+ %15 = OpLabel
+ %22 = OpPhi %v3float %21 %14 %18 %13 %11 %12
+ %24 = OpCompositeInsert %v3float %float_2 %22 2
+ OpReturn
+ OpFunctionEnd
+ %main = OpFunction %void None %25
+ %27 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 4984ee7..c423500 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -1,6 +1,11 @@
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Addresses, Float64, Int64, Linkage], [SPV_KHR_storage_buffer_storage_class]> {
// CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @var0 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>
@@ -16,8 +21,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
- spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
@@ -34,14 +39,14 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
- // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
- spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer> [0]), Block>, StorageBuffer>
+ spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer> [0]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
- spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform> [0]), Block>, Uniform> [0]), Block>, Uniform>
+ spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform> [0]), Block>, Uniform> [0]), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
- spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform> [0], !spirv.ptr<!spirv.struct<bxx>, Uniform> [8]), Block>, Uniform> [0]), Block>, Uniform>
+ spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform> [0], !spirv.ptr<!spirv.struct<bxx>, Uniform> [8]), Block>, Uniform> [0]), Block>, Uniform>
// CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
diff --git a/mlir/test/Target/SPIRV/subgroup-block-intel.mlir b/mlir/test/Target/SPIRV/subgroup-block-intel.mlir
new file mode 100644
index 0000000..14060e6
--- /dev/null
+++ b/mlir/test/Target/SPIRV/subgroup-block-intel.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Physical64 GLSL450 requires #spirv.vce<v1.3, [Addresses, Shader, Linkage, SubgroupBufferBlockIOINTEL],
+ [SPV_KHR_storage_buffer_storage_class, SPV_INTEL_subgroups]> {
+ // CHECK-LABEL: @subgroup_block_read_intel
+ spirv.func @subgroup_block_read_intel(%ptr : !spirv.ptr<i32, StorageBuffer>) -> i32 "None" {
+ // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> i32
+ %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> i32
+ spirv.ReturnValue %0: i32
+ }
+ // CHECK-LABEL: @subgroup_block_read_intel_vector
+ spirv.func @subgroup_block_read_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
+ // CHECK: spirv.INTEL.SubgroupBlockRead %{{.*}} : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32>
+ %0 = spirv.INTEL.SubgroupBlockRead %ptr : !spirv.ptr<i32, StorageBuffer> -> vector<3xi32>
+ spirv.ReturnValue %0: vector<3xi32>
+ }
+ // CHECK-LABEL: @subgroup_block_write_intel
+ spirv.func @subgroup_block_write_intel(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
+ // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32
+ spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32
+ spirv.Return
+ }
+ // CHECK-LABEL: @subgroup_block_write_intel_vector
+ spirv.func @subgroup_block_write_intel_vector(%ptr : !spirv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
+ // CHECK: spirv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32>
+ spirv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32>
+ spirv.Return
+ }
+}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir
new file mode 100644
index 0000000..c99bde3
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{modify-public-functions})' %s | FileCheck %s
+
+// Test that `public` functions' return values are transformed into out
+// parameters when `buffer-results-to-out-params` is invoked with the
+// `modify-public-functions` option.
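+// All functions below use the default (public) visibility.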
+
+// CHECK-LABEL: func.func @basic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<f32>) {
+// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<f32>
+// CHECK: memref.copy %[[VAL_0]], %[[ARG0]] : memref<f32> to memref<f32>
+// CHECK: return
+// CHECK: }
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func.func @presence_of_existing_arguments(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
+// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<2xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[ARG1]] : memref<2xf32> to memref<2xf32>
+// CHECK: return
+// CHECK: }
+func.func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
+ %0 = "test.source"() : () -> (memref<2xf32>)
+ return %0 : memref<2xf32>
+}
+
+// CHECK-LABEL: func.func @multiple_results(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
+// CHECK: %[[VAL_0:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+// CHECK: memref.copy %[[VAL_0]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32>
+// CHECK: memref.copy %[[VAL_0]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32>
+// CHECK: return
+// CHECK: }
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index c1604e2..31a4f64d 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -880,6 +880,18 @@ func.func @no_speculate_divui(
return
}
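+// llvm.udiv has undefined behavior when the divisor is zero, so it cannot be
+// hoisted out of the loop unless the divisor is known to be non-zero.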
+func.func @no_speculate_udiv(
+// CHECK-LABEL: @no_speculate_udiv(
+ %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.udiv
+ %val = llvm.udiv %num, %denom : i32
+ }
+
+ return
+}
+
func.func @no_speculate_divsi(
// CHECK-LABEL: @no_speculate_divsi(
%num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
@@ -892,6 +904,18 @@ func.func @no_speculate_divsi(
return
}
+func.func @no_speculate_sdiv(
+// CHECK-LABEL: @no_speculate_sdiv(
+ %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.sdiv
+ %val = llvm.sdiv %num, %denom : i32
+ }
+
+ return
+}
+
func.func @no_speculate_ceildivui(
// CHECK-LABEL: @no_speculate_ceildivui(
%num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
@@ -928,6 +952,18 @@ func.func @no_speculate_divui_const(%num: i32, %lb: index, %ub: index, %step: in
return
}
+func.func @no_speculate_udiv_const(%num: i32, %lb: index, %ub: index, %step: index) {
+// CHECK-LABEL: @no_speculate_udiv_const(
+ %c0 = arith.constant 0 : i32
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.udiv
+ %val = llvm.udiv %num, %c0 : i32
+ }
+
+ return
+}
+
func.func @speculate_divui_const(
// CHECK-LABEL: @speculate_divui_const(
%num: i32, %lb: index, %ub: index, %step: index) {
@@ -941,6 +977,19 @@ func.func @speculate_divui_const(
return
}
+func.func @speculate_udiv_const(
+// CHECK-LABEL: @speculate_udiv_const(
+ %num: i32, %lb: index, %ub: index, %step: index) {
+ %c5 = llvm.mlir.constant(5 : i32) : i32
+// CHECK: llvm.udiv
+// CHECK: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %val = llvm.udiv %num, %c5 : i32
+ }
+
+ return
+}
+
func.func @no_speculate_ceildivui_const(%num: i32, %lb: index, %ub: index, %step: index) {
// CHECK-LABEL: @no_speculate_ceildivui_const(
%c0 = arith.constant 0 : i32
@@ -979,6 +1028,19 @@ func.func @no_speculate_divsi_const0(
return
}
+func.func @no_speculate_sdiv_const0(
+// CHECK-LABEL: @no_speculate_sdiv_const0(
+ %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
+ %c0 = arith.constant 0 : i32
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.sdiv
+ %val = llvm.sdiv %num, %c0 : i32
+ }
+
+ return
+}
+
func.func @no_speculate_divsi_const_minus1(
// CHECK-LABEL: @no_speculate_divsi_const_minus1(
%num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
@@ -992,6 +1054,19 @@ func.func @no_speculate_divsi_const_minus1(
return
}
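+// Signed division by -1 can overflow (INT_MIN / -1), so it is not
+// speculatable even though the divisor is non-zero.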
+func.func @no_speculate_sdiv_const_minus1(
+// CHECK-LABEL: @no_speculate_sdiv_const_minus1(
+ %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
+ %cm1 = arith.constant -1 : i32
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.sdiv
+ %val = llvm.sdiv %num, %cm1 : i32
+ }
+
+ return
+}
+
func.func @speculate_divsi_const(
// CHECK-LABEL: @speculate_divsi_const(
%num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
@@ -1005,6 +1080,19 @@ func.func @speculate_divsi_const(
return
}
+func.func @speculate_sdiv_const(
+// CHECK-LABEL: @speculate_sdiv_const(
+ %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
+ %c5 = arith.constant 5 : i32
+ scf.for %i = %lb to %ub step %step {
+// CHECK: llvm.sdiv
+// CHECK: scf.for
+ %val = llvm.sdiv %num, %c5 : i32
+ }
+
+ return
+}
+
func.func @no_speculate_ceildivsi_const0(
// CHECK-LABEL: @no_speculate_ceildivsi_const0(
%num: i32, %denom: i32, %lb: index, %ub: index, %step: index) {
@@ -1057,6 +1145,19 @@ func.func @no_speculate_divui_range(
return
}
+func.func @no_speculate_udiv_range(
+// CHECK-LABEL: @no_speculate_udiv_range(
+ %num: i8, %lb: index, %ub: index, %step: index) {
+ %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: llvm.udiv
+ %val = llvm.udiv %num, %denom : i8
+ }
+
+ return
+}
+
func.func @no_speculate_divsi_range(
// CHECK-LABEL: @no_speculate_divsi_range(
%num: i8, %lb: index, %ub: index, %step: index) {
@@ -1072,6 +1173,21 @@ func.func @no_speculate_divsi_range(
return
}
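+// %denom0 may be -1 (overflow for INT_MIN) and %denom1 may be 0, so neither
+// division below can be hoisted.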
+func.func @no_speculate_sdiv_range(
+// CHECK-LABEL: @no_speculate_sdiv_range(
+ %num: i8, %lb: index, %ub: index, %step: index) {
+ %denom0 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK-COUNT-2: llvm.sdiv
+ %val0 = llvm.sdiv %num, %denom0 : i8
+ %val1 = llvm.sdiv %num, %denom1 : i8
+ }
+
+ return
+}
+
func.func @no_speculate_ceildivui_range(
// CHECK-LABEL: @no_speculate_ceildivui_range(
%num: i8, %lb: index, %ub: index, %step: index) {
@@ -1113,6 +1229,19 @@ func.func @speculate_divui_range(
return
}
+func.func @speculate_udiv_range(
+// CHECK-LABEL: @speculate_udiv_range(
+ %num: i8, %lb: index, %ub: index, %step: index) {
+ %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+ scf.for %i = %lb to %ub step %step {
+// CHECK: llvm.udiv
+// CHECK: scf.for
+ %val = llvm.udiv %num, %denom : i8
+ }
+
+ return
+}
+
func.func @speculate_divsi_range(
// CHECK-LABEL: @speculate_divsi_range(
%num: i8, %lb: index, %ub: index, %step: index) {
@@ -1129,6 +1258,22 @@ func.func @speculate_divsi_range(
return
}
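+// Both divisor ranges exclude 0 and -1, so the divisions below are safe to
+// hoist out of the loop.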
+func.func @speculate_sdiv_range(
+// CHECK-LABEL: @speculate_sdiv_range(
+ %num: i8, %lb: index, %ub: index, %step: index) {
+ %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ scf.for %i = %lb to %ub step %step {
+// CHECK-COUNT-2: llvm.sdiv
+// CHECK: scf.for
+ %val0 = llvm.sdiv %num, %denom0 : i8
+ %val1 = llvm.sdiv %num, %denom1 : i8
+
+ }
+
+ return
+}
+
func.func @speculate_ceildivui_range(
// CHECK-LABEL: @speculate_ceildivui_range(
%num: i8, %lb: index, %ub: index, %step: index) {
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 75d8386..3119fd3 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -238,25 +238,26 @@ module attributes {transform.with_named_sequence} {
// -----
// Check simple move of value definitions before the insertion operation.
-func.func @simple_move_values() -> f32 {
- %0 = "before"() : () -> (f32)
- %1 = "moved_op_1"() : () -> (f32)
- %2 = "moved_op_2"() : () -> (f32)
- %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
- return %3 : f32
-}
-// CHECK-LABEL: func @simple_move_values()
-// CHECK: %[[MOVED1:.+]] = "moved_op_1"
-// CHECK: %[[MOVED2:.+]] = "moved_op_2"
+func.func @simple_move_values(%arg0 : index) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = "before"() : () -> (index)
+ %1 = arith.addi %arg0, %c0 {"moved_op_1"} : index
+ %2 = arith.subi %arg0, %c0 {"moved_op_2"} : index
+ %3 = "foo"(%1, %2) : (index, index) -> (index)
+ return %3 : index
+}
+// CHECK-LABEL: func @simple_move_values(
+// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1}
+// CHECK: %[[MOVED2:.+]] = arith.subi {{.*}} {moved_op_2}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
// CHECK: return %[[FOO]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
- %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+ %op1 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+ %op2 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -271,23 +272,26 @@ module attributes {transform.with_named_sequence} {
// -----
// Compute slice including the implicitly captured values.
-func.func @move_region_dependencies_values() -> f32 {
- %0 = "before"() : () -> (f32)
- %1 = "moved_op_1"() : () -> (f32)
- %2 = "moved_op_2"() ({
- %3 = "inner_op"(%1) : (f32) -> (f32)
- "yield"(%3) : (f32) -> ()
- }) : () -> (f32)
- return %2 : f32
+func.func @move_region_dependencies_values(%arg0 : index, %cond : i1) -> index {
+ %0 = "before"() : () -> (index)
+ %1 = arith.addi %arg0, %arg0 {moved_op_1} : index
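+ // %1 is implicitly captured by the scf.if region below.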
+ %2 = scf.if %cond -> index {
+ %3 = arith.muli %1, %1 {inner_op} : index
+ scf.yield %3 : index
+ } else {
+ scf.yield %1 : index
+ }
+ return %2 : index
}
-// CHECK-LABEL: func @move_region_dependencies_values()
-// CHECK: %[[MOVED1:.+]] = "moved_op_1"
-// CHECK: %[[MOVED2:.+]] = "moved_op_2"
+// CHECK-LABEL: func @move_region_dependencies_values(
+// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1}
+// CHECK: scf.if
+// CHECK: arith.muli %[[MOVED1]], %[[MOVED1]] {inner_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
- %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
+ %op1 = transform.structured.match ops{["scf.if"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -301,31 +305,31 @@ module attributes {transform.with_named_sequence} {
// -----
// Move operations in topological sort order
-func.func @move_values_in_topological_sort_order() -> f32 {
- %0 = "before"() : () -> (f32)
- %1 = "moved_op_1"() : () -> (f32)
- %2 = "moved_op_2"() : () -> (f32)
- %3 = "moved_op_3"(%1) : (f32) -> (f32)
- %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
- %5 = "moved_op_5"(%2) : (f32) -> (f32)
- %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
- return %6 : f32
-}
-// CHECK-LABEL: func @move_values_in_topological_sort_order()
-// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
-// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
-// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
-// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
-// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+func.func @move_values_in_topological_sort_order(%arg0 : index, %arg1 : index) -> index {
+ %0 = "before"() : () -> (index)
+ %1 = arith.addi %arg0, %arg0 {moved_op_1} : index
+ %2 = arith.addi %arg1, %arg1 {moved_op_2} : index
+ %3 = arith.muli %1, %1 {moved_op_3} : index
+ %4 = arith.andi %1, %3 {moved_op_4} : index
+ %5 = arith.subi %2, %2 {moved_op_5} : index
+ %6 = "foo"(%4, %5) : (index, index) -> (index)
+ return %6 : index
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order(
+// CHECK: %[[MOVED_1:.+]] = arith.addi {{.*}} {moved_op_1}
+// CHECK-DAG: %[[MOVED_2:.+]] = arith.muli %[[MOVED_1]], %[[MOVED_1]] {moved_op_3}
+// CHECK-DAG: %[[MOVED_3:.+]] = arith.andi %[[MOVED_1]], %[[MOVED_2]] {moved_op_4}
+// CHECK-DAG: %[[MOVED_4:.+]] = arith.addi {{.*}} {moved_op_2}
+// CHECK-DAG: %[[MOVED_5:.+]] = arith.subi %[[MOVED_4]], %[[MOVED_4]] {moved_op_5}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
- %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
+ %op1 = transform.structured.match ops{["arith.andi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
+ %op2 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -341,17 +345,17 @@ module attributes {transform.with_named_sequence} {
// Move only those value definitions that are not dominated by insertion point
-func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
- %0 = "unmoved_op"() : () -> (f32)
- %1 = "dummy_op"() : () -> (f32)
- %2 = "before"() : () -> (f32)
- %3 = "moved_op"() : () -> (f32)
- return %0, %1, %2, %3 : f32, f32, f32, f32
+func.func @move_only_required_defns(%arg0 : index) -> (index, index, index, index) {
+ %0 = "unmoved_op"() : () -> (index)
+ %1 = "dummy_op"() : () -> (index)
+ %2 = "before"() : () -> (index)
+ %3 = arith.addi %arg0, %arg0 {moved_op} : index
+ return %0, %1, %2, %3 : index, index, index, index
}
-// CHECK-LABEL: func @move_only_required_defns()
+// CHECK-LABEL: func @move_only_required_defns(
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
-// CHECK: %[[MOVED:.+]] = "moved_op"
+// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
@@ -362,7 +366,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+ %op4 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
@@ -374,19 +378,19 @@ module attributes {transform.with_named_sequence} {
// -----
-// Move only those value definitions that are not dominated by insertion point
+// Move only those value definitions that are not dominated by insertion point (duplicate test)
-func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
- %0 = "unmoved_op"() : () -> (f32)
- %1 = "dummy_op"() : () -> (f32)
- %2 = "before"() : () -> (f32)
- %3 = "moved_op"() : () -> (f32)
- return %0, %1, %2, %3 : f32, f32, f32, f32
+func.func @move_only_required_defns_2(%arg0 : index) -> (index, index, index, index) {
+ %0 = "unmoved_op"() : () -> (index)
+ %1 = "dummy_op"() : () -> (index)
+ %2 = "before"() : () -> (index)
+ %3 = arith.subi %arg0, %arg0 {moved_op} : index
+ return %0, %1, %2, %3 : index, index, index, index
}
-// CHECK-LABEL: func @move_only_required_defns()
+// CHECK-LABEL: func @move_only_required_defns_2(
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
-// CHECK: %[[MOVED:.+]] = "moved_op"
+// CHECK: %[[MOVED:.+]] = arith.subi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
@@ -397,7 +401,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+ %op4 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
@@ -410,23 +414,23 @@ module attributes {transform.with_named_sequence} {
// -----
// Check handling of block arguments
-func.func @move_only_required_defns() -> (f32, f32) {
- %0 = "unmoved_op"() : () -> (f32)
- cf.br ^bb0(%0 : f32)
- ^bb0(%arg0 : f32) :
- %1 = "before"() : () -> (f32)
- %2 = "moved_op"(%arg0) : (f32) -> (f32)
- return %1, %2 : f32, f32
-}
-// CHECK-LABEL: func @move_only_required_defns()
-// CHECK: %[[MOVED:.+]] = "moved_op"
+func.func @move_with_block_arguments() -> (index, index) {
+ %0 = "unmoved_op"() : () -> (index)
+ cf.br ^bb0(%0 : index)
+ ^bb0(%arg0 : index) :
+ %1 = "before"() : () -> (index)
+ %2 = arith.addi %arg0, %arg0 {moved_op} : index
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @move_with_block_arguments()
+// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+ %op2 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
@@ -438,20 +442,20 @@ module attributes {transform.with_named_sequence} {
// -----
// Do not move across basic blocks
-func.func @no_move_across_basic_blocks() -> (f32, f32) {
- %0 = "unmoved_op"() : () -> (f32)
- %1 = "before"() : () -> (f32)
- cf.br ^bb0(%0 : f32)
- ^bb0(%arg0 : f32) :
- %2 = "moved_op"(%arg0) : (f32) -> (f32)
- return %1, %2 : f32, f32
+func.func @no_move_across_basic_blocks() -> (index, index) {
+ %0 = "unmoved_op"() : () -> (index)
+ %1 = "before"() : () -> (index)
+ cf.br ^bb0(%0 : index)
+ ^bb0(%arg0 : index) :
+ %2 = arith.addi %arg0, %arg0 {moved_op} : index
+ return %1, %2 : index, index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+ %op2 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
@@ -463,24 +467,22 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @move_isolated_from_above() -> () {
- %1 = "before"() : () -> (f32)
- %2 = "moved0"() : () -> (f32)
- %3 = test.isolated_one_region_op %2 {} : f32 -> f32
- %4 = "moved1"(%3) : (f32) -> (f32)
+func.func @move_isolated_from_above(%arg0 : index) -> () {
+ %1 = "before"() : () -> (index)
+ %2 = arith.addi %arg0, %arg0 {moved0} : index
+ %3 = arith.muli %2, %2 {moved1} : index
return
}
-// CHECK-LABEL: func @move_isolated_from_above()
-// CHECK: %[[MOVED0:.+]] = "moved0"
-// CHECK: %[[ISOLATED:.+]] = test.isolated_one_region_op %[[MOVED0]]
-// CHECK: %[[MOVED1:.+]] = "moved1"(%[[ISOLATED]])
+// CHECK-LABEL: func @move_isolated_from_above(
+// CHECK: %[[MOVED0:.+]] = arith.addi {{.*}} {moved0}
+// CHECK: %[[MOVED1:.+]] = arith.muli %[[MOVED0]], %[[MOVED0]] {moved1}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
- %op2 = transform.structured.match ops{["moved1"]} in %arg0
+ %op2 = transform.structured.match ops{["arith.muli"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index e730450..7130667 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) {
// -----
+// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) {
+ // vector.print has a side effect, but the enclosing private function is
+ // never called, so the op is dead.
+ vector.print %arg0 : i32
+ return %arg0 : i32
+}
+
+// -----
+
// %arg0 is not live because it is never used. %arg1 is not live because its
// user `arith.addi` doesn't have any uses and the value that it is forwarded to
// (%non_live_0) also doesn't have any uses.
@@ -674,3 +685,32 @@ func.func @dead_value_loop_ivs_no_result(%lb: index, %ub: index, %step: index, %
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @op_block_have_dead_arg
+func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) {
+ scf.execute_region {
+ cf.cond_br %arg2, ^bb1(%arg0 : index), ^bb1(%arg1 : index)
+ ^bb1(%0: index):
+ scf.yield
+ }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func private @remove_dead_branch_op()
+// CHECK-NEXT: ub.unreachable
+// CHECK-NEXT: ^{{.*}}:
+// CHECK-NEXT: return
+// CHECK-NEXT: ^{{.*}}:
+// CHECK-NEXT: return
+func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %arg0 : i64
+^bb2:
+ return %arg1 : i64
+}
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 42cec68..8da9109 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -72,3 +72,21 @@ builtin.module {
}
}
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// expected-remark@+1 {{applyFullConversion failed}}
+builtin.module {
+func.func @test_preorder_legalization() {
+ // expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}}
+ "test.post_order_legalization"() ({
+ ^bb0(%arg0: i64):
+ // Not-explicitly-legal ops are not allowed to survive.
+ "test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
+ return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer-no-materializations.mlir b/mlir/test/Transforms/test-legalizer-no-materializations.mlir
new file mode 100644
index 0000000..82dd742
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-materializations.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
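+// With build-materializations=0 the conversion driver leaves
+// builtin.unrealized_conversion_cast ops in place instead of building
+// source/target materializations; attach-debug-materialization-kind=1 tags
+// each such cast with the kind of materialization it stands in for.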
+
+// CHECK-LABEL: func @dropped_input_in_use
+// CHECK-KIND-LABEL: func @dropped_input_in_use
+func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
+ // CHECK-NEXT: "work"(%[[cast]]) : (i16)
+ // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
+ // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
+ // expected-remark@+1 {{op 'work' is not legalizable}}
+ "work"(%arg) : (i16) -> ()
+}
+
+// -----
+
+// CHECK-KIND-LABEL: func @test_lookup_without_converter
+// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
+// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
+// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+ %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+ "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+ // Make sure that the second "replace_with_valid_consumer" lowering does not
+ // lookup the materialization that was created for the above op.
+ "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+ // expected-remark@+1 {{op 'func.return' is not legalizable}}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @remap_moved_region_args
+func.func @remap_moved_region_args() {
+ // CHECK-NEXT: return
+ // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
+ // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
+ // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
+ // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
+ // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
+ // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
+ "test.region"() ({
+ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
+ "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+ }) : () -> ()
+ // expected-remark@+1 {{op 'func.return' is not legalizable}}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @remap_cloned_region_args
+func.func @remap_cloned_region_args() {
+ // CHECK-NEXT: return
+ // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
+ // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
+ // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
+ // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
+ // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
+ // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
+ "test.region"() ({
+ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
+ "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+ }) {legalizer.should_clone} : () -> ()
+ // expected-remark@+1 {{op 'func.return' is not legalizable}}
+ return
+}
diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir
new file mode 100644
index 0000000..5f421a3
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @conditional_replacement(
+// CHECK-SAME: %[[arg0:.*]]: i43)
+// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42
+// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42
+// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42
+// Uses were replaced for dummy_user_1.
+// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> ()
+// Uses were also replaced for dummy_user_2, but not by value_replace. The uses
+// were replaced due to the block signature conversion.
+// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> ()
+// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> ()
+func.func @conditional_replacement(%arg0: i42) {
+ %repl = "test.legal_op"() : () -> (i42)
+ // expected-remark @+1 {{is not legalizable}}
+ "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> ()
+ // expected-remark @+1 {{is not legalizable}}
+ "test.dummy_user_2"(%arg0) {} : (i42) -> ()
+ // Perform a conditional 1:N replacement.
+ "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e1178..4bcca6b 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
"test.return"(%0) : (i32) -> ()
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_failed_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
+// CHECK: "test.return"(%[[r]]) : (i32) -> ()
+// CHECK: }) : () -> ()
+// expected-remark @+1 {{applyPartialConversion failed}}
+module {
+func.func @test_failed_preorder_legalization() {
+ // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
+ "test.post_order_legalization"() ({
+ %0 = "test.illegal_op_g"() : () -> (i32)
+ "test.return"(%0) : (i32) -> ()
+ }) : () -> ()
+ return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4..88a71cc 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,7 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
@@ -146,36 +145,6 @@ func.func @no_remap_nested() {
// -----
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) : () -> ()
- // expected-remark@+1 {{op 'func.return' is not legalizable}}
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) {legalizer.should_clone} : () -> ()
- // expected-remark@+1 {{op 'func.return' is not legalizable}}
- return
-}
-
// CHECK-LABEL: func @remap_drop_region
func.func @remap_drop_region() {
// CHECK-NEXT: return
@@ -191,12 +160,9 @@ func.func @remap_drop_region() {
// -----
// CHECK-LABEL: func @dropped_input_in_use
-// CHECK-KIND-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
- // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
- // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg) : (i16) -> ()
}
@@ -452,11 +418,6 @@ func.func @test_multiple_1_to_n_replacement() {
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
-// CHECK-KIND-LABEL: func @test_lookup_without_converter
-// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
-// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
-// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
-// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
func.func @test_lookup_without_converter() {
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
@@ -487,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.invalid
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationModified: test.post_order_legalization
+
+// CHECK-LABEL: func @test_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
+// Note: The survival of a not-explicitly-invalid operation does *not* cause
+// a conversion failure when applying a partial conversion.
+// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
+// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
+// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
+// CHECK: }) {is_legal} : () -> ()
+func.func @test_preorder_legalization() {
+ "test.post_order_legalization"() ({
+ ^bb0(%arg0: i64):
+ // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
+ "test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
+ // expected-remark @+1 {{'func.return' is not legalizable}}
+ return
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
index 8e2f03b..99f72c6 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
@@ -56,6 +56,17 @@ struct TestLivenessAnalysisPass
liveness->print(os);
os << "\n";
}
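+ // Also print liveness information for each region's block arguments.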
+ for (auto [regionIndex, region] : llvm::enumerate(op->getRegions())) {
+ os << " region: #" << regionIndex << ":\n";
+ for (auto [argumentIndex, argument] :
+ llvm::enumerate(region.getArguments())) {
+ const Liveness *liveness = livenessAnalysis.getLiveness(argument);
+ assert(liveness && "expected a sparse lattice");
+ os << " argument: #" << argumentIndex << ": ";
+ liveness->print(os);
+ os << "\n";
+ }
+ }
});
}
};
diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
index 027b0a1..3ff0dc8 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
@@ -46,7 +46,7 @@ struct TestPointerLikeTypeInterfacePass
Pass::Option<std::string> testMode{
*this, "test-mode",
- llvm::cl::desc("Test mode: walk, alloc, copy, or free"),
+ llvm::cl::desc("Test mode: walk, alloc, copy, free, load, or store"),
llvm::cl::init("walk")};
StringRef getArgument() const override {
@@ -75,6 +75,10 @@ private:
void testGenCopy(Operation *srcOp, Operation *destOp, Value srcResult,
Value destResult, PointerLikeType pointerType,
OpBuilder &builder);
+ void testGenLoad(Operation *op, Value result, PointerLikeType pointerType,
+ OpBuilder &builder);
+ void testGenStore(Operation *op, Value result, PointerLikeType pointerType,
+ OpBuilder &builder, Value providedValue = {});
struct PointerCandidate {
Operation *op;
@@ -92,9 +96,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() {
auto func = getOperation();
OpBuilder builder(&getContext());
- if (testMode == "alloc" || testMode == "free") {
+ if (testMode == "alloc" || testMode == "free" || testMode == "load" ||
+ testMode == "store") {
// Collect all candidates first
SmallVector<PointerCandidate> candidates;
+ // For store mode, also look for a test value to use
+ Value testValue;
func.walk([&](Operation *op) {
if (op->hasAttr("test.ptr")) {
for (auto result : op->getResults()) {
@@ -105,6 +112,11 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() {
}
}
}
+ // Collect value marked with test.value for store tests
+ if (testMode == "store" && op->hasAttr("test.value")) {
+ if (op->getNumResults() > 0)
+ testValue = op->getResult(0);
+ }
});
// Now test all candidates
@@ -115,6 +127,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() {
else if (testMode == "free")
testGenFree(candidate.op, candidate.result, candidate.pointerType,
builder);
+ else if (testMode == "load")
+ testGenLoad(candidate.op, candidate.result, candidate.pointerType,
+ builder);
+ else if (testMode == "store")
+ testGenStore(candidate.op, candidate.result, candidate.pointerType,
+ builder, testValue);
}
} else if (testMode == "copy") {
// Collect all source and destination candidates
@@ -292,6 +310,105 @@ void TestPointerLikeTypeInterfacePass::testGenCopy(
}
}
+void TestPointerLikeTypeInterfacePass::testGenLoad(Operation *op, Value result,
+ PointerLikeType pointerType,
+ OpBuilder &builder) {
+ Location loc = op->getLoc();
+
+ // Create a new builder with the listener and set insertion point
+ OperationTracker tracker;
+ OpBuilder newBuilder(op->getContext());
+ newBuilder.setListener(&tracker);
+ newBuilder.setInsertionPointAfter(op);
+
+ // Call the genLoad API
+ auto typedResult = cast<TypedValue<PointerLikeType>>(result);
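+ // A null Type() is passed so the callee derives the loaded type itself
+ // (presumably from the pointee type).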
+ Value loadRes = pointerType.genLoad(newBuilder, loc, typedResult, Type());
+
+ if (loadRes) {
+ llvm::errs() << "Successfully generated load for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ llvm::errs() << "\tLoaded value type: ";
+ loadRes.getType().print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Print all operations that were inserted
+ for (Operation *insertedOp : tracker.insertedOps) {
+ llvm::errs() << "\tGenerated: ";
+ insertedOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ } else {
+ llvm::errs() << "Failed to generate load for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+}
+
+void TestPointerLikeTypeInterfacePass::testGenStore(Operation *op, Value result,
+ PointerLikeType pointerType,
+ OpBuilder &builder,
+ Value providedValue) {
+ Location loc = op->getLoc();
+
+ // Create a new builder with the listener and set insertion point
+ OperationTracker tracker;
+ OpBuilder newBuilder(op->getContext());
+ newBuilder.setListener(&tracker);
+ newBuilder.setInsertionPointAfter(op);
+
+ // Use provided value if available, otherwise create a constant
+ Value valueToStore = providedValue;
+ if (!valueToStore) {
+ // Create a test value to store - use a constant matching the element type
+ Type elementType = pointerType.getElementType();
+ if (!elementType) {
+ llvm::errs() << "Failed to generate store for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ return;
+ }
+
+ if (elementType.isIntOrIndex()) {
+ auto attr = newBuilder.getIntegerAttr(elementType, 42);
+ valueToStore =
+ arith::ConstantOp::create(newBuilder, loc, elementType, attr);
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+ auto attr = newBuilder.getFloatAttr(floatType, 42.0);
+ valueToStore =
+ arith::ConstantOp::create(newBuilder, loc, floatType, attr);
+ } else {
+ llvm::errs() << "Failed to generate store for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ return;
+ }
+ }
+
+ // Call the genStore API
+ auto typedResult = cast<TypedValue<PointerLikeType>>(result);
+ bool success =
+ pointerType.genStore(newBuilder, loc, valueToStore, typedResult);
+
+ if (success) {
+ llvm::errs() << "Successfully generated store for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Print all operations that were inserted
+ for (Operation *insertedOp : tracker.insertedOps) {
+ llvm::errs() << "\tGenerated: ";
+ insertedOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ } else {
+ llvm::errs() << "Failed to generate store for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+}
+
} // namespace
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp
index 35f092c..2506ca4 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp
@@ -93,6 +93,29 @@ void TestRecipePopulatePass::runOnOperation() {
if (!recipe) {
op->emitError("Failed to create firstprivate recipe for ") << varName;
}
+ } else if (recipeType == "private_from_firstprivate") {
+ // First create a firstprivate recipe, then use it to drive creation of a
+ // matching private recipe via the convenience overload. Give each recipe
+ // a stable, predictable name so tests can check both.
+ std::string firstprivName = "first_firstprivate_" + varName;
+ std::string privName = "private_from_firstprivate_" + varName;
+
+ auto firstpriv = FirstprivateRecipeOp::createAndPopulate(
+ builder, loc, firstprivName, var.getType(), varName, bounds);
+
+ if (!firstpriv) {
+ op->emitError("Failed to create firstprivate recipe for ") << varName;
+ return;
+ }
+
+ auto priv = PrivateRecipeOp::createAndPopulate(builder, loc, privName,
+ *firstpriv);
+
+ if (!priv) {
+ op->emitError(
+ "Failed to create private recipe (from firstprivate) for ")
+ << varName;
+ }
}
}
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21d75f5..43392d7 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -37,7 +37,6 @@
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 4d4ec02..8689265 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
}
//===----------------------------------------------------------------------===//
-// OpWithResultShapePerDimInterfaceOp
+// ReifyShapedTypeUsingReifyResultShapesOp
//===----------------------------------------------------------------------===//
-LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
@@ -345,6 +345,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
}
//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyShapeOfResultOp
+//===----------------------------------------------------------------------===//
+
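+// reifyResultShapes fails on purpose so that callers fall back to the
+// per-result reifyShapeOfResult hook below.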
+LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, loc, sourceOperand);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyDimOfResultOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex,
+ int dim) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifiableResultShapesOp
+//===----------------------------------------------------------------------===//
+
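+// Reifies a shape whose second entry is a null OpFoldResult, modeling a
+// result dimension that cannot be reified.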
+LogicalResult UnreifiableResultShapesOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ Location loc = getLoc();
+ shapes.resize(1);
+ shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0),
+ OpFoldResult()};
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifableResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ SmallVector<OpFoldResult> shape = {
+ tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()};
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifiableDimOfResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex, int dim) {
+ if (dim == 0)
+ return tensor::getMixedSize(builder, getLoc(), getOperand(), 0);
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
// SideEffectOp
//===----------------------------------------------------------------------===//
@@ -1052,6 +1149,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
}
//===----------------------------------------------------------------------===//
+// TilingNoDpsOp
+//===----------------------------------------------------------------------===//
+
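+// All TilingInterface hooks below intentionally return empty or failing
+// results so that callers exercise the corresponding error paths.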
+SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
+ return {};
+}
+
+SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
+ return {};
+}
+
+FailureOr<TilingResult>
+TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ return failure();
+}
+
+LogicalResult TilingNoDpsOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
//===----------------------------------------------------------------------===//
@@ -1514,3 +1637,14 @@ test::TestCreateTensorOp::getBufferType(
return convertTensorToBuffer(getOperation(), options, type);
}
+
+// Define a custom builder for ManyRegionsOp declared in TestOps.td.
+// OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$firstRegion,
+// "::std::unique_ptr<::mlir::Region>":$secondRegion)>
+void test::ManyRegionsOp::build(
+ mlir::OpBuilder &builder, mlir::OperationState &state,
+ llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &&regions) {
+  for (std::unique_ptr<mlir::Region> &regionPtr : regions)
+    state.addRegion(std::move(regionPtr));
+  // Moving the elements out leaves regions.size() unchanged.
+  ManyRegionsOp::build(builder, state, {}, regions.size());
+}
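
A short usage sketch for the move-only builder above. This is illustrative
only; it assumes an OpBuilder `builder` and a Location `loc` are in scope.

  // Regions are move-only, so they are handed to the op as unique_ptrs.
  llvm::SmallVector<std::unique_ptr<mlir::Region>> regions;
  regions.push_back(std::make_unique<mlir::Region>());
  regions.push_back(std::make_unique<mlir::Region>());
  builder.create<test::ManyRegionsOp>(loc, std::move(regions));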
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index 4201ade..6792743 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -42,6 +42,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
namespace test {
class TestDialect;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a3430ba..5417ae9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -119,6 +120,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> {
OptionalAttr<StrAttr>:$sym_visibility);
}
+def SymbolWithResultOp : TEST_Op<"symbol_with_result", [Symbol]> {
+ let summary = "invalid symbol operation that produces an SSA result";
+ let arguments = (ins StrAttr:$sym_name,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let results = (outs AnyType:$result);
+}
+
def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [
DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>,
]> {
@@ -914,13 +922,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
-def OpWithResultShapePerDimInterfaceOp :
- TEST_Op<"op_with_result_shape_per_dim_interface",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+def ReifyShapedTypeUsingReifyResultShapesOp :
+ TEST_Op<"reify_shaped_type_using_reify_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+  let description = [{
+    Test that resolving a single dimension of a result for an operation that
+    implements neither `reifyShapeOfResult` nor `reifyDimOfResult` falls back
+    to the implementation of `reifyResultShapes` to get the required value.
+    The op's semantics are that the first result has the same shape as the
+    second operand and the second result has the same shape as the first
+    operand.
+  }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def ReifyShapedTypeUsingReifyShapeOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_shape_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+  let description = [{
+    Test that when resolving a single dimension of a result for an operation
+    that doesn't implement `reifyDimOfResult` but does implement
+    `reifyShapeOfResult`, the latter is used to get the required value.
+    `reifyResultShapes` is implemented as a failure (which is also the
+    default implementation) to ensure it is not called.
+    The op's semantics are that the first result has the same shape as the
+    second operand and the second result has the same shape as the first
+    operand.
+  }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def ReifyShapedTypeUsingReifyDimOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_dim_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+  let description = [{
+    Test that when resolving a single dimension of a result for an operation
+    that implements `reifyDimOfResult`, that method is used to get the
+    required value. `reifyResultShapes` and `reifyShapeOfResult` are
+    implemented as failures to ensure they are not called.
+    The op's semantics are that the first result has the same shape as the
+    second operand and the second result has the same shape as the first
+    operand.
+  }];
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
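
For reference, a condensed sketch of the fallback order the three ops above
exercise. This is an illustrative helper, not MLIR's actual dispatch code;
the hook signatures are taken from TestOpDefs.cpp above, and everything else
(names, a `using namespace mlir;`) is assumed.

  // Resolve one dimension of one result, preferring the most specific hook.
  static FailureOr<OpFoldResult>
  resolveResultDim(OpBuilder &b, ReifyRankedShapedTypeOpInterface op,
                   int resultIndex, int dim) {
    // 1. Per-dimension hook.
    FailureOr<OpFoldResult> d = op.reifyDimOfResult(b, resultIndex, dim);
    if (succeeded(d))
      return *d;
    // 2. Per-result hook.
    FailureOr<SmallVector<OpFoldResult>> shape =
        op.reifyShapeOfResult(b, resultIndex);
    if (succeeded(shape))
      return (*shape)[dim];
    // 3. All-results hook.
    ReifiedRankedShapedTypeDims shapes;
    if (failed(op.reifyResultShapes(b, shapes)))
      return failure();
    return shapes[resultIndex][dim];
  }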
+def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+  let description = [{
+    Test handling of the case where some dimension of the result cannot be
+    reified. This tests the path where `reifyResultShapes` is implemented.
+
+    Dim 0 of `result` is expected to be reifiable as dim 0 of `operand`, but
+    dim 1 of `result` is not reifiable.
+  }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+  let description = [{
+    Test handling of the case where some dimension of the result cannot be
+    reified. This tests the path where `reifyShapeOfResult` is implemented
+    but `reifyDimOfResult` is not, with `reifyResultShapes` implemented as a
+    failure.
+
+    Dim 0 of `result` is expected to be reifiable as dim 0 of `operand`, but
+    dim 1 of `result` is not reifiable.
+  }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+  let description = [{
+    Test handling of the case where some dimension of the result cannot be
+    reified. This tests the path where `reifyDimOfResult` is implemented,
+    while `reifyShapeOfResult` and `reifyResultShapes` are implemented as
+    failures.
+
+    Dim 0 of `result` is expected to be reifiable as dim 0 of `operand`, but
+    dim 1 of `result` is not reifiable.
+  }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
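
The "unreifiable" ops above mark an unavailable dimension either by returning
failure from the per-dim hook or by leaving a null OpFoldResult in the shape.
A minimal sketch of how a caller can detect this, assuming an op handle `op`
and an OpBuilder `b` are in scope:

  // A null OpFoldResult means dim 1 of result 0 was not reified.
  ReifiedRankedShapedTypeDims shapes;
  if (succeeded(op.reifyResultShapes(b, shapes)) && !shapes[0][1]) {
    // Dimension 1 is unreifiable; handle it gracefully instead of
    // dereferencing the empty OpFoldResult.
  }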
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
@@ -1107,6 +1199,12 @@ def TestLocationDstNoResOp : TEST_Op<"loc_dst_no_res"> {
let results = (outs);
}
+def TestLocationAttrOp : TEST_Op<"op_with_loc_attr"> {
+ let arguments = (ins LocationAttr:$loc_attr);
+ let results = (outs );
+ let assemblyFormat = "$loc_attr attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Patterns
//===----------------------------------------------------------------------===//
@@ -2254,6 +2352,24 @@ def IsolatedGraphRegionOp : TEST_Op<"isolated_graph_region", [
let assemblyFormat = "attr-dict-with-keyword $region";
}
+def ManyRegionsOp : TEST_Op<"many_regions", []> {
+ let summary = "operation created with move-only objects";
+  let description = [{
+    Test op with multiple regions whose `create` functions take move-only
+    parameters.
+  }];
+
+ let regions = (region VariadicRegion<AnyRegion>:$regions);
+ let builders =
+ [OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$singleRegion), [{
+ $_state.addRegion(std::move(singleRegion));
+ build($_builder, $_state, {}, /*regionsCount=*/1);
+ }]>,
+   // Defined in TestOpDefs.cpp.
+ OpBuilder<(ins "::llvm::SmallVectorImpl<::std::unique_ptr<::mlir::"
+ "Region>>&&":$regions)>];
+}
+
def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
let summary = "affine scope operation";
let description = [{
@@ -2888,6 +3004,20 @@ def TestLinalgFillOp :
}
//===----------------------------------------------------------------------===//
+// Test TilingInterface.
+//===----------------------------------------------------------------------===//
+
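+// Implements the TilingInterface on an op without destination-passing-style
+// semantics. All interface hooks bail out (see TestOpDefs.cpp), exercising
+// the failure paths of tiling-based transforms.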
+def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
+ [Pure, DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
+ let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
+ let results = (outs AnyRankedTensor:$result);
+}
+
+//===----------------------------------------------------------------------===//
// Test NVVM RequiresSM trait.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index fd2b943..7eabaae 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern {
// Replace the first operand with 2x the second operand.
Value from = op->getOperand(0);
Value repl = op->getOperand(1);
- rewriter.replaceAllUsesWith(from, {repl, repl});
+ if (op->hasAttr("conditional")) {
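+      // Only replace uses whose owning op is tagged with "replace_uses".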
+ rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) {
+ return use.getOwner()->hasAttr("replace_uses");
+ });
+ } else {
+ rewriter.replaceAllUsesWith(from, {repl, repl});
+ }
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
@@ -1418,6 +1424,22 @@ public:
}
};
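+/// Pattern for "test.post_order_legalization": first recursively legalizes
+/// the regions of the matched op, then marks the op itself legal, exercising
+/// post-order (inside-out) legalization through the conversion rewriter.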
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+ TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ for (Region &r : op->getRegions())
+ if (failed(rewriter.legalize(&r)))
+ return failure();
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1554,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
- TestTypeConsumerOpPattern>(&getContext(), converter);
+ TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1553,14 +1576,16 @@ struct TestLegalizePatternDriver
[](Type type) { return type.isF32(); });
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.post_order_legalization", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
@@ -2156,8 +2181,7 @@ struct TestTypeConversionDriver
recursiveType.getName() == "outer_converted_type");
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
// Allow casts from F64 to F32.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597..9859bd0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -470,4 +470,11 @@ def TestMemrefType : Test_Type<"TestMemref",
}];
}
+// Test implementation of an interface with methods specifying a
+// method body
+def TestBaseBody : Test_Type<"TestBaseBody",
+ [DeclareTypeInterfaceMethods<TestBaseTypeInterfacePrintTypeA>]> {
+ let mnemonic = "test_base_body";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 614121f..9cf64a8 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
- auto testMemref = dyn_cast<TestMemrefType>(bufferType);
- if (!testMemref)
- return emitError() << "expected TestMemrefType";
+ if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
+ const bool valid = getShape() == testMemref.getShape() &&
+ getElementType() == testMemref.getElementType();
+ return mlir::success(valid);
+ }
+
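+  // Builtin memrefs are also accepted when shape and element type match.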
+ if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
+ const bool valid = getShape() == builtinMemref.getShape() &&
+ getElementType() == builtinMemref.getElementType();
+ return mlir::success(valid);
+ }
- const bool valid = getShape() == testMemref.getShape() &&
- getElementType() == testMemref.getElementType();
- return mlir::success(valid);
+ return emitError() << "expected MemRefType or TestMemrefType";
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bb..f834d0c 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -180,6 +180,34 @@ struct TestVectorUnrollingPatterns
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8, 8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::CreateMaskOp>(op));
+ }));
+ populateVectorUnrollPatterns(
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShapeFn(
+ [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
+ if (!shapeCast)
+ return std::nullopt;
+
+ auto resultShape = shapeCast.getResultVectorType().getShape();
+ // Special case with leading unit dims and different inner dim
+ // for result and target shape.
+ if (resultShape.size() == 2 && resultShape[0] == 1 &&
+ resultShape[1] == 32) {
+ return SmallVector<int64_t>{1, 16};
+ }
+ // Default case: [2,4] for all tests.
+ return SmallVector<int64_t>{2, 4};
+ })
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ShapeCastOp>(op));
+ }));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransposeOp>(op));
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d4611..93d5144 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeOffsets =
+ sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeOffsets))
return failure();
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 326fec3..583d68b 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -170,9 +171,71 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
+/// Fuse the consumer into the given loops and report both the fused consumer
+/// operations and the (possibly updated) loops.
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+ Operation *consumer,
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ TransformResults &transformResults) {
+ SmallVector<Operation *> fusedConsumerOps;
+ rewriter.setInsertionPoint(consumer);
+
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumer(rewriter, consumer, loops);
+ if (failed(fuseConsumerResults))
+ return consumer->emitOpError("failed to fuse consumer of slice");
+
+ // Report back the relevant handles to the transform op.
+ for (OpOperand *tiledAndFusedConsumerOperand :
+ fuseConsumerResults->tiledAndFusedConsumerOperands) {
+ fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
+ }
+ transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
+ for (auto [index, loop] : llvm::enumerate(loops)) {
+ transformResults.set(transformOp->getOpResult(index + 1), {loop});
+ }
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ Operation *consumer = *state.getPayloadOps(getConsumer()).begin();
+
+ SmallVector<LoopLikeOpInterface> loops;
+ // Since the matcher works inside-out, we need to iterate the loops in
+ // reverse.
+ for (auto loop : llvm::reverse(getLoops())) {
+ auto loopLikeOp =
+ dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin());
+ if (!loopLikeOp) {
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ loops.push_back(loopLikeOp);
+ }
+ LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer,
+ loops, transformResults);
+ return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+ : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getConsumerMutable(), effects);
+ consumesHandle(getLoopsMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerUsingSliceOp
+//===----------------------------------------------------------------------===//
+
/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.
-static LogicalResult applyFuseConsumer(
+static LogicalResult applyFuseConsumerUsingSlices(
RewriterBase &rewriter, Operation *transformOp,
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
uint32_t numConsumerToFuse, TransformResults &transformResults) {
@@ -204,10 +267,9 @@ static LogicalResult applyFuseConsumer(
return success();
}
-DiagnosedSilenceableFailure
-transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
- TransformResults &transformResults,
- TransformState &state) {
+DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply(
+ TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
SmallVector<Operation *> slices;
for (auto op : getTargets()) {
auto sliceOp = *state.getPayloadOps(op).begin();
@@ -224,13 +286,13 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
loops.push_back(loopLikeOp);
}
LogicalResult result =
- applyFuseConsumer(rewriter, getOperation(), slices, loops,
- getNumConsumerToFuse(), transformResults);
+ applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops,
+ getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
-void transform::TestFuseConsumerOp::getEffects(
+void transform::TestFuseConsumerUsingSliceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetsMutable(), effects);
consumesHandle(getLoopsMutable(), effects);
@@ -622,6 +684,110 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestQueryProducerFusability
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
+ TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp) {
+ return emitSilenceableError()
+ << "target operation does not implement TilingInterface";
+ }
+
+ // Collect operand numbers and their corresponding producer insert_slice
+ // offsets and sizes.
+ SmallVector<unsigned> operandNumbers;
+ SmallVector<SmallVector<OpFoldResult>> allOffsets;
+ SmallVector<SmallVector<OpFoldResult>> allSizes;
+
+ for (OpOperand &operand : target->getOpOperands()) {
+ Value operandValue = operand.get();
+ Operation *definingOp = operandValue.getDefiningOp();
+
+ // Look for a producer tensor.insert_slice. This is only for testing
+ // purposes and otherwise is not a useful transformation.
+ if (auto insertSliceOp =
+ dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
+ operandNumbers.push_back(operand.getOperandNumber());
+ allOffsets.push_back(insertSliceOp.getMixedOffsets());
+ allSizes.push_back(insertSliceOp.getMixedSizes());
+ }
+ }
+
+ if (!operandNumbers.empty()) {
+ bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
+ operandNumbers, allOffsets, allSizes);
+
+ if (isFusable) {
+ target->emitRemark()
+ << "can be fused with producer tensor.insert_slice ops";
+ } else {
+ target->emitRemark()
+ << "cannot be fused with producer tensor.insert_slice ops";
+ }
+ }
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestQueryProducerFusability::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// TestQueryConsumerFusability
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply(
+ TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp) {
+ return emitSilenceableError()
+ << "target operation does not implement TilingInterface";
+ }
+
+ // Look for tensor.extract_slice ops that consume results of the tilable op.
+ for (OpResult result : target->getResults()) {
+ for (OpOperand &use : result.getUses()) {
+ Operation *user = use.getOwner();
+
+ // Look for a consumer tensor.extract_slice. This is only for testing
+ // purposes and otherwise is not a useful transformation.
+ if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
+ bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice(
+ result.getResultNumber(), extractSliceOp.getMixedOffsets(),
+ extractSliceOp.getMixedSizes());
+
+ if (isFusable) {
+ target->emitRemark()
+ << "can be fused with consumer tensor.extract_slice op";
+ } else {
+ target->emitRemark()
+ << "cannot be fused with consumer tensor.extract_slice op";
+ }
+ }
+ }
+ }
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestQueryConsumerFusability::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsPayload(effects);
+}
+
#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 694c422..8c4f64d 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,14 +49,19 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}
-def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_using_slice",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Fuses the consumer of the operation pointed to by the target handle
- using the options provided as attributes.
+    For the `insert_slice`-like operations within the loop nests passed in as
+    `loops` (both typically generated through tiling), find the consumer that
+    these slices map to (it must be the same consumer for all slices) and
+    fuse that consumer into the loop.
+
+    Returns a handle to the original consumer operation and to the consumer
+    operation after fusion.
}];
let arguments = (ins
@@ -73,6 +78,32 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}];
}
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+    For the `consumer` that uses the result of the outermost loop of a loop
+    nest passed in as `loops` (typically generated through tiling), fuse the
+    consumer into the loop.
+
+    Returns a handle to the consumer operation after fusion and handles to
+    the loops that may have been modified.
+ }];
+
+ let arguments = (ins
+ TransformHandleTypeInterface:$consumer,
+ Variadic<TransformHandleTypeInterface>:$loops);
+ let results = (outs TransformHandleTypeInterface:$fused_consumer,
+ Variadic<TransformHandleTypeInterface>:$result_loops);
+
+ let assemblyFormat = [{
+ $consumer `into` `(` $loops `)`
+ attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -166,11 +197,55 @@ def TestTileUsingCustomLoopOp : Op<
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs TransformHandleTypeInterface:$tiled_ops,
Variadic<TransformHandleTypeInterface>:$loops);
-
+
let assemblyFormat = [{
$root_op `tile_sizes` `=` $tile_sizes
attr-dict `:` functional-type(operands, results)
}];
}
+def TestQueryProducerFusability : Op<
+ Transform_Dialect, "test.query_producer_fusability",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Test operation for the producer fusability query method in the
+ TilingInterface.
+
+    For each operation in the target handle, this looks for tensor.insert_slice
+    ops that produce operands of the tilable op. The offsets/sizes from those
+    inserts are used as the arguments to `isOpFusableWithProducerSlices`, and a
+    remark is emitted with the result of the query.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+}
+
+def TestQueryConsumerFusability
+ : Op<Transform_Dialect, "test.query_consumer_fusability",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Test operation for the consumer fusability query method in the
+ TilingInterface.
+
+    For each operation in the target handle, this looks for tensor.extract_slice
+    ops that consume results of the tilable op. The offsets/sizes from those
+    extracts are used as the arguments to `isOpFusableWithConsumerSlice`, and a
+    remark is emitted with the result of the query.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+}
+
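
For reference, a condensed sketch of the two interface queries these transform
ops wrap, adapted from TestTilingInterfaceTransformOps.cpp above. Illustrative
fragment only; `tilingInterfaceOp`, `insertSliceOp`, and `extractSliceOp` are
assumed to be in scope:

  // Producer side: can the op be tiled to match its insert_slice producer?
  bool prodFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
      /*operandNumbers=*/{0}, {insertSliceOp.getMixedOffsets()},
      {insertSliceOp.getMixedSizes()});
  // Consumer side: can the op be tiled to match an extract_slice consumer?
  bool consFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice(
      /*resultNumber=*/0, extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedSizes());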
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index 9b0a260..bc53b23 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -44,7 +44,9 @@ def TestMoveValueDefns :
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Moves all dependencies of on operation before another operation.
+ Moves all dependencies of a list of values before another operation.
+    Only pure operations are moved. If there is a side-effecting op in the
+    dependency chain, no operations are moved.
}];
let arguments =
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 6ff12d6..675ded3 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -44,7 +44,7 @@ config.suffixes = [
".test",
".pdll",
".c",
- ".spv",
+ ".spvasm",
]
# test_source_root: The root path where tests are located.
@@ -214,6 +214,11 @@ tools = [
"not",
]
+if "Linux" in config.host_os:
+ # TODO: Run only on Linux until we figure out how to build
+ # mlir_apfloat_wrappers in a platform-independent way.
+ tools.extend([add_runtime("mlir_apfloat_wrappers")])
+
if config.enable_vulkan_runner:
tools.extend([add_runtime("mlir_vulkan_runtime")])
diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td
index d51e1a5f..3f2e5cd 100644
--- a/mlir/test/mlir-tblgen/constraint-unique.td
+++ b/mlir/test/mlir-tblgen/constraint-unique.td
@@ -16,7 +16,7 @@ def AType : Type<ATypePred, "a type">;
def OtherType : Type<ATypePred, "another type">;
def AnAttrPred : CPred<"attrPred($_self, $_op)">;
-def AnAttr : Attr<AnAttrPred, "an attribute">;
+def AnAttr : Attr<AnAttrPred, "an attribute (got {{reformat($_self)}})">;
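+// Text inside {{...}} in a constraint summary is spliced into the generated
+// error message as a C++ stream expression (with $_self substituted); see the
+// CHECK lines below.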
def OtherAttr : Attr<AnAttrPred, "another attribute">;
def ASuccessorPred : CPred<"successorPred($_self, $_op)">;
@@ -24,7 +24,7 @@ def ASuccessor : Successor<ASuccessorPred, "a successor">;
def OtherSuccessor : Successor<ASuccessorPred, "another successor">;
def ARegionPred : CPred<"regionPred($_self, $_op)">;
-def ARegion : Region<ARegionPred, "a region">;
+def ARegion : Region<ARegionPred, "a region ({{find(foo)}})">;
def OtherRegion : Region<ARegionPred, "another region">;
// OpA and OpB have the same type, attribute, successor, and region constraints.
@@ -71,10 +71,10 @@ def OpC : NS_Op<"op_c"> {
// CHECK: static ::llvm::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op))))
// CHECK-NEXT: return emitError() << "attribute '" << attrName
-// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
+// CHECK-NEXT: << "' failed to satisfy constraint: an attribute (got " << reformat(attr) << ")";
/// Test that duplicate attribute constraint was not generated.
-// CHECK-NOT: << "' failed to satisfy constraint: an attribute";
+// CHECK-NOT: << "' failed to satisfy constraint: an attribute
/// Test that an attribute constraint with a different description was generated.
// CHECK: static ::llvm::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
@@ -103,7 +103,7 @@ def OpC : NS_Op<"op_c"> {
// CHECK: if (!((regionPred(region, *op)))) {
// CHECK-NEXT: return op->emitOpError("region #") << regionIndex
// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ")
-// CHECK-NEXT: << "failed to verify constraint: a region";
+// CHECK-NEXT: << "failed to verify constraint: a region (" << find(foo) << ")";
/// Test that duplicate region constraint was not generated.
// CHECK-NOT: << "failed to verify constraint: a region";
diff --git a/mlir/test/mlir-tblgen/cpp-class-comments.td b/mlir/test/mlir-tblgen/cpp-class-comments.td
index 9dcf975..0d3445d 100644
--- a/mlir/test/mlir-tblgen/cpp-class-comments.td
+++ b/mlir/test/mlir-tblgen/cpp-class-comments.td
@@ -36,6 +36,7 @@ def A_SomeOp1 : Op<A_Dialect, "some_op1", []>{
let cppNamespace = "OP1";
// OP: namespace OP1
+// OP-EMPTY:
// OP-NEXT: /// Some Op1 summary line1
// OP-NEXT: /// summary line2
// OP-NEXT: /// Some Op1 description
@@ -97,6 +98,7 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> {
let methods = [
];
// ATTR-INTERFACE: namespace mlir::a::traits {
+// ATTR-INTERFACE-EMPTY:
// ATTR-INTERFACE-NEXT: /// Common trait for all layouts.
// ATTR-INTERFACE-NEXT: class EncodingTrait;
}
@@ -104,6 +106,7 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> {
def SimpleEncodingTrait : AttrInterface<"SimpleEncodingTrait"> {
let cppNamespace = "a::traits";
// ATTR-INTERFACE: namespace a::traits {
+// ATTR-INTERFACE-EMPTY:
// ATTR-INTERFACE-NEXT: class SimpleEncodingTrait;
}
@@ -114,6 +117,7 @@ def SimpleOpInterface : OpInterface<"SimpleOpInterface"> {
Simple Op Interface description
}];
// OP-INTERFACE: namespace a::traits {
+// OP-INTERFACE-EMPTY:
// OP-INTERFACE-NEXT: /// Simple Op Interface description
// OP-INTERFACE-NEXT: class SimpleOpInterface;
}
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
new file mode 100644
index 0000000..ff39fd9
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -0,0 +1,65 @@
+// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+
+include "mlir/IR/Interfaces.td"
+
+def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
+ let description = [{
+ This is an example dialect interface without default method body.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
+// DECL: public:
+// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
+// DECL: virtual bool isExampleDialect() const {}
+// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {}
+
+def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
+ let description = [{
+ This is an example dialect interface with default method bodies.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return true;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: virtual bool isExampleDialect() const {
+// DECL-NEXT: return true;
+// DECL-NEXT: }
+
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 549830e..a3cb9a4 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -69,19 +69,19 @@ def AOp : NS_Op<"a_op", []> {
// DEF: ::llvm::LogicalResult AOpAdaptor::verify
// DEF-NEXT: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr;
-// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op requires attribute 'aAttr'");
// DEF-NEXT: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr;
// DEF-NEXT: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr;
// DEF-NEXT: auto tblgen_dAttr = getProperties().dAttr; (void)tblgen_dAttr;
// DEF: if (tblgen_aAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'aAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_bAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_cAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_dAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'dAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'dAttr' failed to satisfy constraint: some attribute kind");
// Test getter methods
// ---
@@ -219,13 +219,13 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
// DEF: ::llvm::LogicalResult AgetOpAdaptor::verify
// DEF: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr;
-// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
+// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op requires attribute 'aAttr'");
// DEF: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr;
// DEF: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr;
// DEF: if (tblgen_bAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_cAttr && !((some-condition)))
-// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// Test getter methods
// ---
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 87b41f9..80dedb84 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -65,6 +65,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: ::std::optional< ::llvm::APFloat > getSomeAttr2();
// CHECK: ::mlir::Region &getSomeRegion() {
// CHECK: ::mlir::RegionRange getSomeRegions() {
+// CHECK-NEXT: return odsRegions.drop_front(1);
+// CHECK: ::mlir::RegionRange getRegions() {
+// CHECK-NEXT: return odsRegions;
// CHECK: };
// CHECK: }
@@ -152,10 +155,6 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// Check that `getAttrDictionary()` is used when not using properties.
-// DECLS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
-// DECLS-NEXT: return odsRegions.drop_front(1);
-// DECLS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
-
// Check AttrSizedOperandSegments
// ---
@@ -236,14 +235,14 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint",
// DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) {
// DEFS: ::mlir::OperationState __state__(location, getOperationName());
-// DEFS: build(builder, __state__, a);
+// DEFS: build(builder, __state__, std::forward<decltype(a)>(a));
// DEFS: auto __res__ = ::llvm::dyn_cast<FOp>(builder.create(__state__));
// DEFS: assert(__res__ && "builder didn't return the right type");
// DEFS: return __res__;
// DEFS: }
// DEFS: FOp FOp::create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::Value a) {
-// DEFS: return create(builder, builder.getLoc(), a);
+// DEFS: return create(builder, builder.getLoc(), std::forward<decltype(a)>(a));
// DEFS: }
def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td
index af09ee7..7cc9633 100644
--- a/mlir/test/mlir-tblgen/op-properties-predicates.td
+++ b/mlir/test/mlir-tblgen/op-properties-predicates.td
@@ -74,7 +74,7 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> {
// Note: comprehensive emission of verifiers is tested in verifyINvariantsImpl() below
// CHECK: int64_t tblgen_scalar = this->getScalar();
// CHECK: if (!((tblgen_scalar >= 0)))
-// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t");
+// CHECK: return emitError(loc, "'test.op_with_predicates' op property 'scalar' failed to satisfy constraint: non-negative int64_t");
// CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl()
// Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused
diff --git a/mlir/test/mlir-tblgen/op-properties.td b/mlir/test/mlir-tblgen/op-properties.td
index a9c784c..cb9bd3d 100644
--- a/mlir/test/mlir-tblgen/op-properties.td
+++ b/mlir/test/mlir-tblgen/op-properties.td
@@ -32,7 +32,7 @@ def OpWithProps : NS_Op<"op_with_props"> {
ArrayProp<StringProp>:$strings,
DefaultValuedProp<I32Prop, "0">:$default_int,
OptionalProp<I64Prop>:$optional,
- DefaultI64Array:$intArray
+ DefaultI64Array:$value
);
}
@@ -94,10 +94,10 @@ def OpWithOptionalPropsAndAttrs :
// DECL: ::llvm::ArrayRef<std::string> getStrings()
// DECL: using default_intTy = int32_t;
// DECL: default_intTy default_int = 0;
-// DECL: intArrayTy intArray = ::llvm::SmallVector<int64_t>{};
-// DECL: ::llvm::ArrayRef<int64_t> getIntArray()
+// DECL: valueTy value = ::llvm::SmallVector<int64_t>{};
+// DECL: ::llvm::ArrayRef<int64_t> getValue()
// DECL: return ::llvm::ArrayRef<int64_t>{propStorage}
-// DECL: void setIntArray(::llvm::ArrayRef<int64_t> propValue)
+// DECL: void setValue(::llvm::ArrayRef<int64_t> propValue)
// DECL: propStorage.assign
// DECL-LABEL: class OpWithProps :
// DECL: setString(::llvm::StringRef newString)
@@ -111,14 +111,14 @@ def OpWithOptionalPropsAndAttrs :
// DECL-SAME: ::llvm::ArrayRef<std::string> strings,
// DECL-SAME: /*optional*/int32_t default_int = 0,
// DECL-SAME: /*optional*/std::optional<int64_t> optional = std::nullopt,
-// DECL-SAME: /*optional*/::llvm::ArrayRef<int64_t> intArray = ::llvm::ArrayRef<int64_t>{});
+// DECL-SAME: /*optional*/::llvm::ArrayRef<int64_t> value = ::llvm::ArrayRef<int64_t>{});
// DEFS-LABEL: OpWithProps::computePropertiesHash
-// DEFS: hash_intArray
+// DEFS: hash_value_
// DEFS: using ::llvm::hash_value;
// DEFS-NEXT: return hash_value(::llvm::ArrayRef<int64_t>{propStorage})
// DEFS: hash_value(prop.optional)
-// DEFS: hash_intArray(prop.intArray)
+// DEFS: hash_value_(prop.value)
// -----
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 42de7e4..ff16ad8 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def f32(self) -> _ods_ir.Value:
+ // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32, F32:$f32, I64);
// CHECK: @builtins.property
- // CHECK: def i32(self) -> _ods_ir.OpResult:
+ // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def i64(self) -> _ods_ir.OpResult:
+ // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[2]
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
@@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32(self) -> _ods_ir.Value:
+ // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def f32(self) -> _ods_ir.Value:
+ // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32:$i32, F32:$f32);
// CHECK: @builtins.property
- // CHECK: def i64(self) -> _ods_ir.OpResult:
+ // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def f64(self) -> _ods_ir.OpResult:
+ // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]:
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, AnyFloat:$f64);
}
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index c1fcd3f..41e041f 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -55,7 +55,7 @@ def OpF : NS_Op<"op_for_int_min_val", []> {
// CHECK-LABEL: OpFAdaptor::verify
// CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() >= 10)
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
def OpFX : NS_Op<"op_for_int_max_val", []> {
let arguments = (ins ConfinedAttr<I32Attr, [IntMaxValue<10>]>:$attr);
@@ -63,7 +63,7 @@ def OpFX : NS_Op<"op_for_int_max_val", []> {
// CHECK-LABEL: OpFXAdaptor::verify
// CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() <= 10)
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
def OpG : NS_Op<"op_for_arr_min_count", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [ArrayMinCount<8>]>:$attr);
@@ -71,7 +71,7 @@ def OpG : NS_Op<"op_for_arr_min_count", []> {
// CHECK-LABEL: OpGAdaptor::verify
// CHECK: (::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() >= 8)
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
def OpH : NS_Op<"op_for_arr_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
@@ -79,7 +79,7 @@ def OpH : NS_Op<"op_for_arr_value_at_index", []> {
// CHECK-LABEL: OpHAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() == 8)))))
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
@@ -87,7 +87,7 @@ def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
// CHECK-LABEL: OpIAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 8)))))
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
def OpJ: NS_Op<"op_for_arr_max_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMaxValue<0, 8>]>:$attr);
@@ -95,7 +95,7 @@ def OpJ: NS_Op<"op_for_arr_max_value_at_index", []> {
// CHECK-LABEL: OpJAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8)))))
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8"
def OpK: NS_Op<"op_for_arr_in_range_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemInRange<0, 4, 8>]>:$attr);
@@ -103,7 +103,7 @@ def OpK: NS_Op<"op_for_arr_in_range_at_index", []> {
// CHECK-LABEL: OpKAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 4)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8)))))
-// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8"
+// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8"
def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [
PredOpTrait<"operands indexed at 0, 2, 3 should all have "
@@ -121,7 +121,7 @@ def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [
// CHECK: ::llvm::all_equal(::llvm::map_range(
// CHECK-SAME: ::mlir::ArrayRef<unsigned>({0, 2, 3}),
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
-// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
+// CHECK: failed to verify that operands indexed at 0, 2, 3 should all have the same type"
def OpM : NS_Op<"op_for_AnyTensorOf", []> {
let arguments = (ins TensorOf<[F32, I32]>:$x);
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
index 2c12381..c81f75f 100644
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -11,7 +11,7 @@ add_public_tablegen_target(MLIRPythonTestIncGen)
add_subdirectory(lib)
-set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules mlir-runner)
+set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules mlir-runner mlir_c_runner_utils mlir_runner_utils)
if(NOT MLIR_STANDALONE_BUILD)
list(APPEND MLIR_PYTHON_TEST_DEPENDS FileCheck count not)
endif()
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 3945c99..1a009b7 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -133,9 +133,10 @@ def testGPUFuncOp():
), func.known_grid_size
func = gpu.GPUFuncOp(
- func_type,
+ ir.FunctionType.get(inputs=[T.index()], results=[]),
sym_name="non_kernel_func",
body_builder=builder,
+ arg_attrs=[{"gpu.some_attribute": ir.StringAttr.get("foo")}],
)
assert not func.is_kernel
assert func.known_block_size is None
@@ -154,10 +155,11 @@ def testGPUFuncOp():
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
# CHECK: gpu.return
# CHECK: }
- # CHECK: gpu.func @non_kernel_func() {
- # CHECK: %[[VAL_0:.*]] = gpu.global_id x
- # CHECK: gpu.return
- # CHECK: }
+ # CHECK: gpu.func @non_kernel_func(
+ # CHECK-SAME: %[[ARG0:.*]]: index {gpu.some_attribute = "foo"}) {
+ # CHECK: %[[GLOBAL_ID_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
# CHECK-LABEL: testGPULaunchFuncOp
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 709a1d2..92591cd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -1,7 +1,8 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.dialects import arith, func, linalg, tensor, memref
+from mlir.dialects import arith, func, linalg, tensor, memref, builtin
from mlir.dialects.linalg.opdsl.lang import *
+from mlir.extras import types as T
from mlir.ir import *
@@ -857,3 +858,76 @@ def testElementwiseOp():
)
print(module)
+
+
+@run
+def testReduceOp():
+ with Context(), Location.unknown():
+ f32 = T.f32()
+ tensor_type = T.tensor(10, f32)
+
+ @builtin.module
+ def module():
+ @func.func(tensor_type)
+ def reduce_op(input):
+ c1 = arith.constant(f32, 1.0)
+                single_result = RankedTensorType.get((), f32)
+                dims = DenseI64ArrayAttr.get([0])
+ init = tensor.splat(single_result, c1, [])
+
+ @linalg.reduce(
+ result=[single_result],
+ inputs=[input],
+ inits=[init],
+ dimensions=dims,
+ )
+ def reduced(element: f32, acc: f32):
+ return arith.mulf(acc, element)
+
+ return tensor.extract(reduced, [])
+
+ print(module)
+
+
+# CHECK-LABEL: func.func @reduce_op(
+# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32
+# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32>
+# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0]
+# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32>
+# CHECK: return %[[EXTRACT_0]] : f32
+# CHECK: }
+
+
+@run
+def testMapOp():
+ with Context(), Location.unknown():
+ f32 = T.f32()
+ tensor_type = T.tensor(10, f32)
+
+ @builtin.module
+ def module():
+ @func.func(tensor_type)
+ def map_op(input):
+ empty = tensor.empty(tensor_type.shape, f32)
+
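+            # linalg.map builds its elementwise body from the decorated
+            # function, which takes one parameter per input plus the init.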
+ @linalg.map(
+ result=[tensor_type],
+ inputs=[input, input],
+ init=empty,
+ )
+ def add(element: f32, acc: f32, init: f32):
+ return arith.addf(element, acc)
+
+ return add
+
+ module.verify()
+ print(module)
+
+
+# CHECK-LABEL: func.func @map_op(
+# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> {
+# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32>
+# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>)
+# CHECK: return %[[MAP_0]] : tensor<10xf32>
+# CHECK: }
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 5f7cb6a..8ab53b4 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -208,3 +208,43 @@ def test_get_indexing_maps_attr():
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map
+
+
+@run
+def test_infer_contraction_dimensions_from_maps():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # === Test valid contraction (matmul) ===
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
+ assert dims is not None
+
+ # Expect m=[0], n=[1], k=[2] as per standard matmul.
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"
+
+ # === Test invalid input (wrong number of maps) ===
+ invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
+ assert invalid_dims is None
+
+ # === Test element-wise operation ===
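+        # With identical maps every dim appears in all three operands,
+        # including the output, so inference classifies both as batch dims.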
+ dim_i = AffineDimExpr.get(0)
+ dim_j = AffineDimExpr.get(1)
+ elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
+ elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
+ [elementwise_map, elementwise_map, elementwise_map]
+ )
+ assert elementwise_dims is not None
+ assert len(elementwise_dims.m) == 0
+ assert len(elementwise_dims.n) == 0
+ assert len(elementwise_dims.k) == 0
+ assert list(elementwise_dims.batch) == [0, 1]
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 8ea0fdd..305ed9a 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -98,6 +98,9 @@ def testStructType():
assert opaque.opaque
# CHECK: !llvm.struct<"opaque", opaque>
+ typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>')
+ assert isinstance(typ, llvm.StructType)
+
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
@@ -120,6 +123,9 @@ def testPointerType():
# CHECK: !llvm.ptr<1>
print(ptr_with_addr)
+ typ = Type.parse("!llvm.ptr<1>")
+ assert isinstance(typ, llvm.PointerType)
+
# CHECK-LABEL: testConstant
@constructAndPrintInModule
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 3eb62be..d795524 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -15,7 +15,9 @@ def constructAndPrintInModule(f):
module = Module.create()
with InsertionPoint(module.body):
f()
+
print(module)
+ module.operation.verify()
return f
@@ -89,3 +91,133 @@ def test_inline_ptx():
arith.addf(a, b)
arith.addi(c, d)
arith.addf(wo0, wo1)
+
+
+@constructAndPrintInModule
+def test_barriers():
+ i32 = T.i32()
+ f32 = T.f32()
+
+ @func.FuncOp.from_py_func(i32, i32, f32)
+ def barriers(mask, vi32, vf32):
+ c0 = arith.constant(T.i32(), 0)
+ cffff = arith.constant(T.i32(), 0xFFFF)
+ res = nvvm.barrier(
+ res=i32,
+ barrier_id=c0,
+ number_of_threads=cffff,
+ )
+
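+        # Chain the reduction barriers: each consumes the previous result
+        # as its reduction predicate.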
+ for reduction in (
+ nvvm.BarrierReduction.AND,
+ nvvm.BarrierReduction.OR,
+ nvvm.BarrierReduction.POPC,
+ ):
+ res = nvvm.barrier(
+ res=i32,
+ reduction_op=reduction,
+ reduction_predicate=res,
+ )
+
+ nvvm.barrier0()
+ nvvm.bar_warp_sync(mask)
+ nvvm.cluster_arrive()
+ nvvm.cluster_arrive(aligned=True)
+ nvvm.cluster_arrive_relaxed()
+ nvvm.cluster_arrive_relaxed(aligned=True)
+ nvvm.cluster_wait()
+ nvvm.cluster_wait(aligned=True)
+ nvvm.fence_mbarrier_init()
+ nvvm.bar_warp_sync(mask)
+ return res
+
+
+# CHECK-LABEL: func.func @barriers(
+# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
+# CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
+# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
+# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
+# CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
+# CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
+# CHECK: nvvm.barrier0
+# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
+# CHECK: nvvm.cluster.arrive
+# CHECK: nvvm.cluster.arrive {aligned}
+# CHECK: nvvm.cluster.arrive.relaxed
+# CHECK: nvvm.cluster.arrive.relaxed {aligned}
+# CHECK: nvvm.cluster.wait
+# CHECK: nvvm.cluster.wait {aligned}
+# CHECK: nvvm.fence.mbarrier.init
+# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
+# CHECK: return %[[BARRIER_3]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def test_reductions():
+ i32 = T.i32()
+ f32 = T.f32()
+
+ @func.FuncOp.from_py_func(i32, i32, f32)
+ def reductions(mask, vi32, vf32):
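+        # The abs/nan modifiers apply only to the float kinds (fmin/fmax);
+        # the integer redux ops are emitted identically on every iteration.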
+ for abs in (True, False):
+ for nan in (True, False):
+ for kind in (
+ nvvm.ReduxKind.AND,
+ nvvm.ReduxKind.MAX,
+ nvvm.ReduxKind.MIN,
+ nvvm.ReduxKind.OR,
+ nvvm.ReduxKind.UMAX,
+ nvvm.ReduxKind.UMIN,
+ nvvm.ReduxKind.XOR,
+ ):
+ nvvm.redux_sync(i32, vi32, kind, vi32)
+
+ for kind in (
+ nvvm.ReduxKind.FMIN,
+ nvvm.ReduxKind.FMAX,
+ ):
+ nvvm.redux_sync(f32, vf32, kind, vi32, abs=abs, nan=nan)
+
+
+# CHECK-LABEL: func.func @reductions(
+# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) {
+# CHECK: %[[REDUX_0:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_1:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_2:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_3:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_4:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_5:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_6:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_7:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
+# CHECK: %[[REDUX_8:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
+# CHECK: %[[REDUX_9:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_10:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_11:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_12:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_13:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_14:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_15:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_16:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
+# CHECK: %[[REDUX_17:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
+# CHECK: %[[REDUX_18:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_19:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_20:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_21:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_22:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_23:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_24:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_25:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
+# CHECK: %[[REDUX_26:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
+# CHECK: %[[REDUX_27:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_28:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_29:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_30:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_31:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_32:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_33:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
+# CHECK: %[[REDUX_34:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] : f32 -> f32
+# CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
+# CHECK: return
+# CHECK: }
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 1194e32..f0f74eb 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -554,7 +554,7 @@ def testOptionalOperandOp():
)
assert (
typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
- is OpResult
+ == OpResult[IntegerType]
)
assert type(op1.result) is OpResult
@@ -663,6 +663,13 @@ def testCustomType():
@run
+# CHECK-LABEL: TEST: testValue
+def testValue():
+ # Check that Value is a generic class at runtime.
+ assert hasattr(Value, "__class_getitem__")
+
+
+@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index a4a50af..c73a536 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -29,13 +29,12 @@ def testSmoke():
a_frag = arith.constant(v16f32, f32_array)
b_frag = arith.constant(v16f32, f32_array)
c_frag = arith.constant(v16f32, f32_array)
- false = arith.constant(T.bool(), False)
- c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
- # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
+ c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, a_frag, b_frag, c_frag, opsel=False)
+ # CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
print(c_frag)
assert isinstance(c_frag, OpView)
- # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
- c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
+ # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
+ c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False)
print(c_frag)
assert isinstance(c_frag, Value)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5..0c0c9b9 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -1,10 +1,14 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-from mlir.dialects import arith
-from mlir.dialects import func
-from mlir.dialects import memref
-from mlir.dialects import scf
+from mlir.extras import types as T
+from mlir.dialects import (
+ arith,
+ func,
+ memref,
+ scf,
+ cf,
+)
from mlir.passmanager import PassManager
@@ -355,3 +359,117 @@ def testIfWithElse():
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return
+
+
+@constructAndPrintInModule
+def testIndexSwitch():
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(T.index(), results=[i32])
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ c0 = arith.constant(i32, 0)
+ value = arith.constant(i32, 5)
+ switch_op = scf.IndexSwitchOp([i32], index, range(3))
+
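+        # Region 0 is the default region; the case regions follow in order.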
+ assert switch_op.regions[0] == switch_op.default_region
+ assert switch_op.regions[1] == switch_op.case_regions[0]
+ assert switch_op.regions[1] == switch_op.case_region(0)
+ assert len(switch_op.case_regions) == 3
+ assert len(switch_op.regions) == 4
+
+ with InsertionPoint(switch_op.default_block):
+ cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+ scf.yield_([c1])
+
+ for i, block in enumerate(switch_op.case_blocks):
+ with InsertionPoint(block):
+ scf.yield_([arith.constant(i32, i)])
+
+ func.return_([switch_op.results[0]])
+
+ return index_switch
+
+
+# CHECK-LABEL: func.func @index_switch(
+# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK: case 0 {
+# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK: scf.yield %[[CONSTANT_3]] : i32
+# CHECK: }
+# CHECK: case 1 {
+# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK: scf.yield %[[CONSTANT_4]] : i32
+# CHECK: }
+# CHECK: case 2 {
+# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK: scf.yield %[[CONSTANT_5]] : i32
+# CHECK: }
+# CHECK: default {
+# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK: scf.yield %[[CONSTANT_0]] : i32
+# CHECK: }
+# CHECK: return %[[INDEX_SWITCH_0]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def testIndexSwitchWithBodyBuilders():
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(T.index(), results=[i32])
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ c0 = arith.constant(i32, 0)
+ value = arith.constant(i32, 5)
+
+ def default_body_builder(switch_op):
+ cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+ scf.yield_([c1])
+
+ def case_body_builder(switch_op, case_index: int, case_value: int):
+ scf.yield_([arith.constant(i32, case_value)])
+
+ result = scf.index_switch(
+ results=[i32],
+ arg=index,
+ cases=range(3),
+ case_body_builder=case_body_builder,
+ default_body_builder=default_body_builder,
+ )
+
+ func.return_([result])
+
+ return index_switch
+
+
+# CHECK-LABEL: func.func @index_switch(
+# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK: case 0 {
+# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK: scf.yield %[[CONSTANT_3]] : i32
+# CHECK: }
+# CHECK: case 1 {
+# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK: scf.yield %[[CONSTANT_4]] : i32
+# CHECK: }
+# CHECK: case 2 {
+# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK: scf.yield %[[CONSTANT_5]] : i32
+# CHECK: }
+# CHECK: default {
+# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK: scf.yield %[[CONSTANT_0]] : i32
+# CHECK: }
+# CHECK: return %[[INDEX_SWITCH_0]] : i32
+# CHECK: }
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6c5e4e5..f58442d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -51,6 +51,26 @@ def testSequenceOp(module: Module):
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
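+        # Both the CamelCase op class and the snake_case helper build the
+        # same cast op.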
+ res = transform.CastOp(transform.AnyOpType.get(), sequence.bodyTarget)
+ res2 = transform.cast(transform.any_op_t(), res.result)
+ transform.YieldOp([res2])
+ # CHECK-LABEL: TEST: testSequenceOp
+ # CHECK: transform.sequence
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op
+ # CHECK: %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op
+ # CHECK: yield %[[RES2]] : !transform.any_op
+ # CHECK: }
+
+
+@run
+def testSequenceOp(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
transform.YieldOp([sequence.bodyTarget])
# CHECK-LABEL: TEST: testSequenceOp
# CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
@@ -58,6 +78,7 @@ def testSequenceOp(module: Module):
# CHECK: yield %[[ARG0]] : !transform.any_op
# CHECK: }
+
@run
def testNestedSequenceOp(module: Module):
sequence = transform.SequenceOp(
@@ -103,55 +124,65 @@ def testSequenceOpWithExtras(module: Module):
# CHECK-LABEL: TEST: testSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+ sequence = transform.sequence(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ transform.yield_()
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
@run
def testNestedSequenceOpWithExtras(module: Module):
- sequence = transform.SequenceOp(
+ sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
)
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
sequence.bodyTarget,
sequence.bodyExtraArgs,
)
- with InsertionPoint(nested.body):
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
@run
def testTransformPDLOps(module: Module):
- withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(withPdl.body):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- withPdl.bodyTarget,
- )
- with InsertionPoint(sequence.body):
- match = transform_pdl.PDLMatchOp(
- transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
- )
- transform.YieldOp(match)
- # CHECK-LABEL: TEST: testTransformPDLOps
- # CHECK: transform.with_pdl_patterns {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
- # CHECK: yield %[[RES]] : !transform.any_op
- # CHECK: }
- # CHECK: }
+ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(withPdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ withPdl.bodyTarget,
+ )
+ with InsertionPoint(sequence.body):
+ match = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+ )
+ transform.YieldOp(match)
+ # CHECK-LABEL: TEST: testTransformPDLOps
+ # CHECK: transform.with_pdl_patterns {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+ # CHECK: yield %[[RES]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
@run
@@ -161,32 +192,53 @@ def testNamedSequenceOp(module: Module):
"__transform_main",
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
- arg_attrs = [{"transform.consumed": UnitAttr.get()}])
+ arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+ )
with InsertionPoint(named_sequence.body):
transform.YieldOp([named_sequence.bodyTarget])
# CHECK-LABEL: TEST: testNamedSequenceOp
# CHECK: module attributes {transform.with_named_sequence} {
- # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
- # CHECK: yield %[[ARG0]] : !transform.any_op
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
+ # CHECK: yield %[[ARG0]] : !transform.any_op
+ named_sequence = transform.named_sequence(
+ "other_seq",
+ [transform.AnyOpType.get()],
+ [transform.AnyOpType.get()],
+ arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+ )
+ with InsertionPoint(named_sequence.body):
+ transform.yield_([named_sequence.bodyTarget])
+ # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
+ # CHECK: yield %[[ARG1]] : !transform.any_op
@run
def testGetParentOp(module: Module):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- transform.GetParentOp(
- transform.AnyOpType.get(),
- sequence.bodyTarget,
- isolated_from_above=True,
- nth_parent=2,
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
- transform.YieldOp()
- # CHECK-LABEL: TEST: testGetParentOp
- # CHECK: transform.sequence
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
+ with InsertionPoint(sequence.body):
+ transform.GetParentOp(
+ transform.AnyOpType.get(),
+ sequence.bodyTarget,
+ isolated_from_above=True,
+ nth_parent=2,
+ )
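+        # The snake_case builder also exposes allow_empty_results, op_name,
+        # and deduplicate.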
+ transform.get_parent_op(
+ transform.AnyOpType.get(),
+ sequence.bodyTarget,
+ isolated_from_above=True,
+ nth_parent=2,
+ allow_empty_results=True,
+ op_name="func.func",
+ deduplicate=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testGetParentOp
+ # CHECK: transform.sequence
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
+ # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"}
@run
@@ -195,43 +247,58 @@ def testMergeHandlesOp(module: Module):
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
- transform.MergeHandlesOp([sequence.bodyTarget])
+ res = transform.MergeHandlesOp([sequence.bodyTarget])
+ transform.merge_handles([res.result], deduplicate=True)
transform.YieldOp()
# CHECK-LABEL: TEST: testMergeHandlesOp
# CHECK: transform.sequence
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = merge_handles %[[ARG1]]
+ # CHECK: %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op
+ # CHECK: = merge_handles deduplicate %[[RES1]] : !transform.any_op
@run
def testApplyPatternsOpCompact(module: Module):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
- transform.ApplyCanonicalizationPatternsOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testApplyPatternsOpCompact
- # CHECK: apply_patterns to
- # CHECK: transform.apply_patterns.canonicalization
- # CHECK: !transform.any_op
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+ transform.ApplyCanonicalizationPatternsOp()
+ with InsertionPoint(
+ transform.apply_patterns(
+ sequence.bodyTarget,
+ apply_cse=True,
+ max_iterations=3,
+ max_num_rewrites=5,
+ ).patterns
+ ):
+ transform.ApplyCanonicalizationPatternsOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyPatternsOpCompact
+ # CHECK: apply_patterns to
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: } : !transform.any_op
+ # CHECK: apply_patterns to
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op
@run
def testApplyPatternsOpWithType(module: Module):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [],
- transform.OperationType.get('test.dummy')
- )
- with InsertionPoint(sequence.body):
- with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
- transform.ApplyCanonicalizationPatternsOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testApplyPatternsOp
- # CHECK: apply_patterns to
- # CHECK: transform.apply_patterns.canonicalization
- # CHECK: !transform.op<"test.dummy">
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("test.dummy"),
+ )
+ with InsertionPoint(sequence.body):
+ with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+ transform.ApplyCanonicalizationPatternsOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyPatternsOp
+ # CHECK: apply_patterns to
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: !transform.op<"test.dummy">
@run
@@ -249,11 +316,13 @@ def testReplicateOp(module: Module):
transform.AnyOpType.get(), sequence.bodyTarget, "second"
)
transform.ReplicateOp(m1, [m2])
+ transform.replicate(m1, [m2])
transform.YieldOp()
# CHECK-LABEL: TEST: testReplicateOp
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+ # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 819a3be..ca9ce5d 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -32,6 +32,20 @@ def print_self():
@test_in_context
+def print_self_via_apply_method():
+ m = ir.Module.parse(
+ print_root_module.replace("from interpreter", "print_self_via_apply_method")
+ )
+ m.body.operations[0].apply(m)
+
+
+# CHECK-LABEL: print_self_via_apply_method
+# CHECK: transform.named_sequence @__transform_main
+# CHECK: transform.print
+# CHECK: transform.yield
+
+
+@test_in_context
def print_other():
transform = ir.Module.parse(
print_root_module.replace("from interpreter", "print_other")
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index d6b70dc..e58b764 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -627,12 +627,16 @@ def testVectorizeChildrenAndApplyPatternsAllAttrs(target):
disable_transfer_permutation_map_lowering_patterns=True,
vectorize_nd_extract=True,
vectorize_padding=True,
+ flatten_1d_depthwise_conv=True,
+ fold_type_extensions_into_contract=True,
)
# CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsAllAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK-SAME: disable_multi_reduction_to_contract_patterns
# CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
+ # CHECK-SAME: flatten_1d_depthwise_conv
+ # CHECK-SAME: fold_type_extensions_into_contract
# CHECK-SAME: vectorize_nd_extract
# CHECK-SAME: vectorize_padding
@@ -646,12 +650,16 @@ def testVectorizeChildrenAndApplyPatternsNoAttrs(target):
disable_transfer_permutation_map_lowering_patterns=False,
vectorize_nd_extract=False,
vectorize_padding=False,
+ flatten_1d_depthwise_conv=False,
+ fold_type_extensions_into_contract=False,
)
# CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsNoAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK-NOT: disable_multi_reduction_to_contract_patterns
# CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
+ # CHECK-NOT: flatten_1d_depthwise_conv
+ # CHECK-NOT: fold_type_extensions_into_contract
# CHECK-NOT: vectorize_nd_extract
# CHECK-NOT: vectorize_padding
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
new file mode 100644
index 0000000..2b11acb0
--- /dev/null
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -0,0 +1,296 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import xegpu
+from mlir.dialects.transform import AnyValueType
+
+
+def run(f):
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ print("\nTEST:", f.__name__)
+ f()
+ print(module)
+ return f
+
+
+@run
+def getDescOpDefaultIndex():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ desc_handle = xegpu.get_desc_op(operand)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: getDescOpDefaultIndex
+ # CHECK: transform.xegpu.get_desc_op %
+
+
+@run
+def setDescLayoutMinimal():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutMinimal
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+
+
+@run
+def setDescLayoutInstData():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(
+ sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutInstData
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+
+
+@run
+def setDescLayoutSlice():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(
+ sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutSlice
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: slice_dims = [0]
+
+
+@run
+def setOpLayoutAttrOperandMinimal():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ )
+ transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrOperandMinimal
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # NO-CHECK: index = 0
+ # NO-CHECK: result
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # NO-CHECK: inst_data
+
+
+@run
+def setOpLayoutAttrResult():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ result=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrResult
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # NO-CHECK: index = 0
+ # CHECK: result
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+
+
+@run
+def setOpLayoutAttrResultSlice():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ slice_dims=[0],
+ result=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # NO-CHECK: index = 0
+ # CHECK: result
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+ # CHECK: slice_dims = [0]
+
+
+@run
+def setGPULaunchThreadsOp():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("gpu.launch"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+ # CHECK: transform.xegpu.set_gpu_launch_threads
+ # CHECK: threads = [8, 4, 1]
+
+
+@run
+def insertPrefetch0():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.insert_prefetch(
+ operand,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetch0
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.insert_prefetch(
+ operand,
+ nb_prefetch=2,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
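+        # nb_prefetch also accepts a transform param handle in place of a
+        # static integer.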
+ int32_t = IntegerType.get_signless(32)
+ param_int32_t = transform.ParamType.get(int32_t)
+ nb_param = transform.ParamConstantOp(
+ param_int32_t,
+ IntegerAttr.get(int32_t, 2),
+ )
+ xegpu.insert_prefetch(
+ operand,
+ nb_prefetch=nb_param,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
+
+
+@run
+def ConvertLayoutMinimal():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.convert_layout(
+ operand,
+ input_sg_layout=[6, 4],
+ input_sg_data=[32, 16],
+ target_sg_layout=[6, 4],
+ target_sg_data=[8, 16],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: ConvertLayoutMinimal
+ # CHECK: transform.xegpu.convert_layout %
+ # CHECK: input_sg_layout = [6, 4]
+ # CHECK: input_sg_data = [32, 16]
+ # CHECK: target_sg_layout = [6, 4]
+ # CHECK: target_sg_data = [8, 16]
+
+
+@run
+def ConvertLayout():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
+ xegpu.convert_layout(
+ operand,
+ input_sg_layout=[6, 4],
+ input_sg_data=[32, 32],
+ input_inst_data=[32, 16],
+ target_sg_layout=[6, 4],
+ target_sg_data=[32, 32],
+ target_inst_data=[8, 16],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: ConvertLayout
+ # CHECK: transform.xegpu.convert_layout %
+ # CHECK: input_sg_layout = [6, 4]
+ # CHECK: input_sg_data = [32, 32]
+ # CHECK: input_inst_data = [32, 16]
+ # CHECK: target_sg_layout = [6, 4]
+ # CHECK: target_sg_data = [32, 32]
+ # CHECK: target_inst_data = [8, 16]
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 146e213a..b11340f 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -71,6 +71,7 @@ def testInvalidModule():
func.func @foo() { return }
"""
)
+ # CHECK: error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: func.func
# CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
try:
execution_engine = ExecutionEngine(module)
@@ -806,6 +807,7 @@ def testDumpToObjectFile():
# because RTDyldObjectLinkingLayer::emit will try to resolve symbols before dumping
# (see the jitLinkForORC call at the bottom there).
shared_libs=[MLIR_C_RUNNER_UTILS],
+ enable_pic=True,
)
# CHECK: Object file exists: True
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 8f20231..8eff573 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} {
%O1 = memref.alloc() : memref<16xi32>
%O2 = memref.alloc() : memref<4x16xi32>
- %val0 = arith.constant 1.0 : f32
- %val1 = arith.constant 2.0 : f32
- %val2 = arith.constant 3.0 : f32
+ %val0 = arith.constant 1 : i32
+ %val1 = arith.constant 2 : i32
+ %val2 = arith.constant 3 : i32
- call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
- call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
- call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+ call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
+ call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
+ call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
%c0 = arith.constant 0 : index
%res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
- f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out])
@@ -184,19 +183,18 @@ test_fill_builtin()
def test_fill_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
- f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index 8316890..1747c66a 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -15,17 +15,10 @@ def run(f):
assert Context._get_live_count() == 0
-@contextmanager
-def with_infer_location():
- _cext.globals.set_loc_tracebacks_enabled(True)
- yield
- _cext.globals.set_loc_tracebacks_enabled(False)
-
-
# CHECK-LABEL: TEST: testInferLocations
@run
def testInferLocations():
- with Context() as ctx, with_infer_location():
+ with Context() as ctx, loc_tracebacks():
ctx.allow_unregistered_dialects = True
op = Operation.create("custom.op1")
@@ -34,24 +27,26 @@ def testInferLocations():
two = arith.constant(IndexType.get(), 2)
# fmt: off
- # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":31:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
+ # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))
# fmt: on
print(op.location)
- # fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
- # fmt: on
- print(one.location)
+ # Test nesting of loc_tracebacks().
+ with loc_tracebacks():
+ # fmt: off
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
+ # fmt: on
+ print(one.location)
# fmt: off
- # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":34:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
+ # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))
# fmt: on
print(two.location)
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
three = arith.constant(IndexType.get(), 3)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))
# fmt: on
print(three.location)
@@ -60,7 +55,7 @@ def testInferLocations():
print(four.location)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
# fmt: on
foo()
@@ -86,13 +81,13 @@ def testInferLocations():
_cext.globals.set_loc_tracebacks_frame_limit(2)
# fmt: off
- # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":80:16 to :22)))
+ # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :22)))
# fmt: on
bar1()
_cext.globals.set_loc_tracebacks_frame_limit(1)
# fmt: off
- # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61))
+ # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61))
# fmt: on
bar1()
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index ced5fce..e876c00 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -191,3 +191,18 @@ def testBlockEraseArgs():
blocks[0].erase_argument(0)
# CHECK: ^bb0:
op.print(enable_debug_info=True)
+
+
+# CHECK-LABEL: TEST: testBlockArgSetLocation
+# CHECK: ^bb0(%{{.+}}: f32 loc("new_loc")):
+@run
+def testBlockArgSetLocation():
+ with Context() as ctx, Location.unknown(ctx) as loc:
+ ctx.allow_unregistered_dialects = True
+ f32 = F32Type.get()
+ op = Operation.create("test", regions=1, loc=Location.unknown())
+ blocks = op.regions[0].blocks
+ blocks.append(f32)
+ arg = blocks[0].arguments[0]
+ arg.set_location(Location.name("new_loc"))
+ op.print(enable_debug_info=True)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f5fa4da..d124c28 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -2,12 +2,12 @@
import gc
import io
-import itertools
from tempfile import NamedTemporaryFile
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
-from mlir.dialects import arith
+from mlir.dialects import arith, func, scf, shape
from mlir.dialects._ods_common import _cext
+from mlir.extras import types as T
def run(f):
@@ -43,6 +43,10 @@ def testTraverseOpRegionBlockIterators():
)
op = module.operation
assert op.context is ctx
+    # Note: __nb_signature__ stores the fully-qualified signature; the actual type stub emitted is
+ # class RegionSequence(Sequence[Region])
+ # CHECK: class RegionSequence(collections.abc.Sequence[mlir._mlir_libs._mlir.ir.Region])
+ print(RegionSequence.__nb_signature__)
# Get the block using iterators off of the named collections.
regions = list(op.regions[:])
blocks = list(regions[0].blocks)
@@ -774,6 +778,21 @@ def testKnownOpView():
print(repr(constant))
+# CHECK-LABEL: TEST: testFailedGenericOperationCreationReportsError
+@run
+def testFailedGenericOperationCreationReportsError():
+ with Context(), Location.unknown():
+ c0 = shape.const_shape([])
+ c1 = shape.const_shape([1, 2, 3])
+ try:
+ shape.MeetOp.build_generic(operands=[c0, c1])
+ except MLIRError as e:
+ # CHECK: unequal shape cardinality
+ print(e)
+ else:
+ assert False, "Expected exception"
+
+
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
@@ -1199,3 +1218,25 @@ def testGetOwnerConcreteOpview():
r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
for u in a.result.uses:
assert isinstance(u.owner, arith.AddIOp)
+
+
+# CHECK-LABEL: TEST: testIndexSwitch
+@run
+def testIndexSwitch():
+ with Context() as ctx, Location.unknown():
+ i32 = T.i32()
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(T.index())
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3))
+
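+                # Region sequences support len(), iteration, and slicing.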
+ assert len(switch_op.regions) == 4
+ assert len(switch_op.regions[2:]) == 2
+ assert len([i for i in switch_op.regions[2:]]) == 2
+ assert len(switch_op.caseRegions) == 3
+ assert len([i for i in switch_op.caseRegions]) == 3
+ assert len(switch_op.caseRegions[1:]) == 2
+ assert len([i for i in switch_op.caseRegions[1:]]) == 2