Diffstat (limited to 'mlir/test')
-rw-r--r--  mlir/test/CMakeLists.txt | 6
-rw-r--r--  mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir | 58
-rw-r--r--  mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir | 72
-rw-r--r--  mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 17
-rw-r--r--  mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir | 108
-rw-r--r--  mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir | 9
-rw-r--r--  mlir/test/Conversion/ConvertToSPIRV/vector.mlir | 102
-rw-r--r--  mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | 54
-rw-r--r--  mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir | 40
-rw-r--r--  mlir/test/Conversion/GPUToSPIRV/rotate.mlir | 38
-rw-r--r--  mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir | 27
-rw-r--r--  mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir | 72
-rw-r--r--  mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 24
-rw-r--r--  mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (renamed from mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir) | 66
-rw-r--r--  mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir (renamed from mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir) | 40
-rw-r--r--  mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir | 2
-rw-r--r--  mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir | 142
-rw-r--r--  mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 122
-rw-r--r--  mlir/test/Dialect/AMDGPU/canonicalize.mlir | 29
-rw-r--r--  mlir/test/Dialect/Arith/mesh-spmdize.mlir | 17
-rw-r--r--  mlir/test/Dialect/Arith/shard-partition.mlir | 17
-rw-r--r--  mlir/test/Dialect/Arith/sharding-propagation.mlir | 60
-rw-r--r--  mlir/test/Dialect/Async/canonicalize.mlir | 10
-rw-r--r--  mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir | 33
-rw-r--r--  mlir/test/Dialect/Bufferization/canonicalize.mlir | 20
-rw-r--r--  mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir | 2
-rw-r--r--  mlir/test/Dialect/GPU/invalid.mlir | 44
-rw-r--r--  mlir/test/Dialect/GPU/ops.mlir | 5
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir | 146
-rw-r--r--  mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 18
-rw-r--r--  mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir | 55
-rw-r--r--  mlir/test/Dialect/Linalg/invalid.mlir | 37
-rw-r--r--  mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir | 42
-rw-r--r--  mlir/test/Dialect/Linalg/shard-partition.mlir (renamed from mlir/test/Dialect/Linalg/mesh-spmdization.mlir) | 118
-rw-r--r--  mlir/test/Dialect/Linalg/sharding-propagation.mlir | 42
-rw-r--r--  mlir/test/Dialect/Linalg/transform-lower-pack.mlir | 16
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir | 149
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir | 24
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir | 51
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir | 13
-rw-r--r--  mlir/test/Dialect/Mesh/canonicalization.mlir | 248
-rw-r--r--  mlir/test/Dialect/Mesh/folding.mlir | 22
-rw-r--r--  mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir | 49
-rw-r--r--  mlir/test/Dialect/Mesh/inlining.mlir | 15
-rw-r--r--  mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir | 23
-rw-r--r--  mlir/test/Dialect/Mesh/resharding-spmdization.mlir | 168
-rw-r--r--  mlir/test/Dialect/Mesh/sharding-propagation.mlir | 301
-rw-r--r--  mlir/test/Dialect/Mesh/spmdization.mlir | 317
-rw-r--r--  mlir/test/Dialect/OpenMP/ops.mlir | 10
-rw-r--r--  mlir/test/Dialect/SCF/canonicalize.mlir | 34
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 18
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/types.mlir | 6
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir | 24
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir | 2
-rw-r--r--  mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir (renamed from mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir) | 40
-rw-r--r--  mlir/test/Dialect/Shard/backward-sharding-propagation.mlir (renamed from mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir) | 12
-rw-r--r--  mlir/test/Dialect/Shard/canonicalization.mlir | 248
-rw-r--r--  mlir/test/Dialect/Shard/folding.mlir | 22
-rw-r--r--  mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir (renamed from mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir) | 14
-rw-r--r--  mlir/test/Dialect/Shard/forward-sharding-propagation.mlir | 49
-rw-r--r--  mlir/test/Dialect/Shard/inlining.mlir | 15
-rw-r--r--  mlir/test/Dialect/Shard/invalid.mlir (renamed from mlir/test/Dialect/Mesh/invalid.mlir) | 442
-rw-r--r--  mlir/test/Dialect/Shard/ops.mlir (renamed from mlir/test/Dialect/Mesh/ops.mlir) | 350
-rw-r--r--  mlir/test/Dialect/Shard/partition.mlir | 317
-rw-r--r--  mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir | 23
-rw-r--r--  mlir/test/Dialect/Shard/resharding-partition.mlir | 168
-rw-r--r--  mlir/test/Dialect/Shard/sharding-propagation-failed.mlir (renamed from mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir) | 0
-rw-r--r--  mlir/test/Dialect/Shard/sharding-propagation.mlir | 301
-rw-r--r--  mlir/test/Dialect/Shard/simplifications.mlir (renamed from mlir/test/Dialect/Mesh/simplifications.mlir) | 78
-rw-r--r--  mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 134
-rw-r--r--  mlir/test/Dialect/Tensor/mesh-spmdization.mlir | 52
-rw-r--r--  mlir/test/Dialect/Tensor/shard-partition.mlir | 52
-rw-r--r--  mlir/test/Dialect/Tosa/availability.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/canonicalize.mlir | 31
-rw-r--r--  mlir/test/Dialect/Tosa/controlflow.mlir | 35
-rw-r--r--  mlir/test/Dialect/Tosa/error_if_check.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/invalid.mlir | 24
-rw-r--r--  mlir/test/Dialect/Tosa/invalid_extension.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/level_check.mlir | 18
-rw-r--r--  mlir/test/Dialect/Tosa/ops.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 40
-rw-r--r--  mlir/test/Dialect/Tosa/verifier.mlir | 85
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize.mlir | 158
-rw-r--r--  mlir/test/Dialect/Vector/int-range-interface.mlir | 38
-rw-r--r--  mlir/test/Dialect/Vector/invalid.mlir | 56
-rw-r--r--  mlir/test/Dialect/Vector/ops.mlir | 32
-rw-r--r--  mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir | 16
-rw-r--r--  mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir | 24
-rw-r--r--  mlir/test/Dialect/Vector/vector-sink.mlir | 139
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-unroll.mlir | 18
-rw-r--r--  mlir/test/Dialect/XeGPU/invalid.mlir | 68
-rw-r--r--  mlir/test/Dialect/XeGPU/ops.mlir | 29
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 139
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 242
-rw-r--r--  mlir/test/Examples/transform/Ch3/ops.mlir | 21
-rw-r--r--  mlir/test/Examples/transform/Ch3/sequence.mlir | 11
-rw-r--r--  mlir/test/IR/diagnostic-nosplit.mlir | 13
-rw-r--r--  mlir/test/IR/test-pattern-logging-listener.mlir | 10
-rw-r--r--  mlir/test/IR/top-level.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir | 12
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir | 8
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir | 6
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir | 4
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir | 16
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/compress.mlir | 5
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir | 5
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir | 3
-rw-r--r--  mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir | 2
-rw-r--r--  mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir | 16
-rw-r--r--  mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir | 16
-rw-r--r--  mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir | 127
-rw-r--r--  mlir/test/Target/LLVMIR/Import/intrinsic.ll | 20
-rw-r--r--  mlir/test/Target/LLVMIR/Import/module-asm.ll | 5
-rw-r--r--  mlir/test/Target/LLVMIR/invalid-module.mlir | 12
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir | 24
-rw-r--r--  mlir/test/Target/LLVMIR/module-asm.mlir | 6
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 39
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir.mlir | 23
-rw-r--r--  mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir | 121
-rw-r--r--  mlir/test/Target/LLVMIR/xevm.mlir | 21
-rw-r--r--  mlir/test/Target/SPIRV/constant.mlir | 33
-rw-r--r--  mlir/test/Target/SPIRV/lit.local.cfg | 4
-rw-r--r--  mlir/test/Target/SPIRV/logical-ops.mlir | 2
-rw-r--r--  mlir/test/Target/SPIRV/memory-ops.mlir | 20
-rw-r--r--  mlir/test/Target/SPIRV/struct.mlir | 38
-rw-r--r--  mlir/test/Target/SPIRV/undef.mlir | 6
-rw-r--r--  mlir/test/Transforms/compose-subview.mlir | 70
-rw-r--r--  mlir/test/Transforms/remove-dead-values.mlir | 23
-rw-r--r--  mlir/test/Transforms/test-legalize-type-conversion.mlir | 2
-rw-r--r--  mlir/test/Transforms/test-legalizer-analysis.mlir | 2
-rw-r--r--  mlir/test/Transforms/test-legalizer-full.mlir | 2
-rw-r--r--  mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp | 108
-rw-r--r--  mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp | 7
-rw-r--r--  mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp | 2
-rw-r--r--  mlir/test/lib/Dialect/Bufferization/CMakeLists.txt | 1
-rw-r--r--  mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp | 57
-rw-r--r--  mlir/test/lib/Dialect/CMakeLists.txt | 2
-rw-r--r--  mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp | 6
-rw-r--r--  mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 12
-rw-r--r--  mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp | 16
-rw-r--r--  mlir/test/lib/Dialect/Shard/CMakeLists.txt (renamed from mlir/test/lib/Dialect/Mesh/CMakeLists.txt) | 10
-rw-r--r--  mlir/test/lib/Dialect/Shard/TestOpLowering.cpp (renamed from mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp) | 18
-rw-r--r--  mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp (renamed from mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp) | 49
-rw-r--r--  mlir/test/lib/Dialect/Shard/TestSimplifications.cpp (renamed from mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp) | 24
-rw-r--r--  mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp | 27
-rw-r--r--  mlir/test/lib/Dialect/Test/TestAttrDefs.td | 1
-rw-r--r--  mlir/test/lib/Dialect/Test/TestAttributes.cpp | 10
-rw-r--r--  mlir/test/lib/Dialect/Test/TestDialect.cpp | 2
-rw-r--r--  mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp | 10
-rw-r--r--  mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp | 8
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 46
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOps.td | 22
-rw-r--r--  mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp | 4
-rw-r--r--  mlir/test/lib/Dialect/Test/TestPatterns.cpp | 83
-rw-r--r--  mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp | 4
-rw-r--r--  mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp | 6
-rw-r--r--  mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 14
-rw-r--r--  mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp | 7
-rw-r--r--  mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 28
-rw-r--r--  mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2
-rw-r--r--  mlir/test/lib/IR/TestPrintInvalid.cpp | 7
-rw-r--r--  mlir/test/lib/IR/TestSlicing.cpp | 6
-rw-r--r--  mlir/test/lib/Pass/TestPassManager.cpp | 4
-rw-r--r--  mlir/test/lib/Transforms/TestDialectConversion.cpp | 2
-rw-r--r--  mlir/test/lib/Transforms/TestInliningCallback.cpp | 12
-rw-r--r--  mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp | 2
-rw-r--r--  mlir/test/lib/Transforms/TestTransformsOps.cpp | 4
-rw-r--r--  mlir/test/lit.cfg.py | 1
-rw-r--r--  mlir/test/lit.site.cfg.py.in | 3
-rw-r--r--  mlir/test/mlir-runner/simple.mlir | 6
-rw-r--r--  mlir/test/mlir-tblgen/attrdefs.td | 5
-rw-r--r--  mlir/test/mlir-tblgen/op-properties-predicates.td | 6
-rw-r--r--  mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 2
-rw-r--r--  mlir/test/mlir-tblgen/rewriter-indexing.td | 2
-rw-r--r--  mlir/test/python/ir/array_attributes.py | 22
177 files changed, 4451 insertions, 3678 deletions
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index ac8b44f5..89568e7 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -68,6 +68,7 @@ endif()
llvm_canonicalize_cmake_booleans(
LLVM_BUILD_EXAMPLES
LLVM_HAS_NVPTX_TARGET
+ LLVM_INCLUDE_SPIRV_TOOLS_TESTS
MLIR_ENABLE_BINDINGS_PYTHON
MLIR_ENABLE_CUDA_RUNNER
MLIR_ENABLE_ROCM_CONVERSIONS
@@ -217,6 +218,11 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
)
endif()
+if (LLVM_INCLUDE_SPIRV_TOOLS_TESTS)
+ list(APPEND MLIR_TEST_DEPENDS spirv-as)
+ list(APPEND MLIR_TEST_DEPENDS spirv-val)
+endif()
+
# This target can be used to just build the dependencies
# for the check-mlir target without executing the tests.
# This is useful for bots when splitting the build step
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b980451..1d36be1 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
// CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
// CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: %[[IN_SLICE_CAST:.+]] = vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT: %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT: vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT: vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
%bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
%cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32>
func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,27 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
%ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
return %ext : f32
}
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+ return %ext : vector<32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @long_fp8_broadcast
+// CHECK-COUNT-8: amdgpu.scaled_ext_packed %{{.+}}[1]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp8_broadcast(%in: vector<32xf8E4M3FN>, %scale: f32) -> vector<32xf32> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %ext = arith.scaling_extf %in, %splat : vector<32xf8E4M3FN>, vector<32xf32> to vector<32xf32>
+ return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75c..90a8608 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[P1:.+]] = amdgpu.packed_scaled_trunc
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT: %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT: vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> {
%bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
%cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32>
@@ -122,7 +114,7 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
// -----
// CHECK-LABEL: @conversion_broadcast_odd
-// CHECK-NEXT: %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2>
+// CHECK-NEXT: %[[CST4:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2>
// CHECK-NEXT: %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2>
// CHECK-NEXT: %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
// CHECK-NEXT: %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
@@ -130,24 +122,18 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
// CHECK-NEXT: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK-NEXT: %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32>
// CHECK-NEXT: %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into %[[CST4]][0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into %[[PACKED0_PART0]][1], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK-NEXT: %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32>
// CHECK-NEXT: %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into %[[CST4]][0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into %[[PACKED1_PART0]][1], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<6xf8E5M2>
func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> {
@@ -165,14 +151,10 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F
// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
+// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into %[[CST]][0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
-// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf8E5M2>
+// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into %[[PACKED0]][1], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: return %[[PACKED1]] : vector<4xf8E5M2>
func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
%splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
%ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2>
@@ -191,3 +173,27 @@ func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 {
%ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2
return %ext : f8E5M2
}
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[3]
+// CHECK-NOT: amdgpu.packed_scaled_trunc
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf4E2M1FN> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf4E2M1FN>
+ return %trunc : vector<32xf4E2M1FN>
+}
+
+// -----
+
+// CHECK-LABEL: @long_fp8_broadcast
+// CHECK-COUNT-8: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[1]
+// CHECK-NOT: amdgpu.packed_scaled_trunc
+// CHECK: return
+func.func @long_fp8_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf8E4M3FN> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf8E4M3FN>
+ return %trunc : vector<32xf8E4M3FN>
+}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd..6e2352e 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,23 @@ func.func @constant() {
return
}
+// CHECK-LABEL: @constant_8bit_float
+func.func @constant_8bit_float() {
+ // CHECK: spirv.Constant 56 : i8
+ %cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant 56 : i8
+ %cst_i8 = arith.bitcast %cst : f8E4M3 to i8
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
+ return
+}
+
// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
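An aside on the constants in the new @constant_8bit_float test above: for a float format with E exponent bits and M mantissa bits, 1.0 has a zero mantissa and an unbiased exponent of 0, so its encoding is just the bias (2^(E-1) - 1) shifted past the mantissa. That gives 56 for f8E4M3 (bias 7) and 60 for f8E5M2 (bias 15), matching the i8 constants the CHECK lines expect. A minimal standalone check (plain C, not part of the patch; encode_one is our own helper name):

    #include <assert.h>

    /* The stored exponent field of 1.0 equals the format's bias. */
    static int encode_one(int exp_bits, int man_bits) {
      int bias = (1 << (exp_bits - 1)) - 1;
      return bias << man_bits; /* sign bit is 0, mantissa is 0 */
    }

    int main(void) {
      assert(encode_one(4, 3) == 56); /* f8E4M3 -> spirv.Constant 56 : i8 */
      assert(encode_one(5, 2) == 60); /* f8E5M2 -> spirv.Constant 60 : i8 */
      return 0;
    }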
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c59..ae59f28 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
//CHECK-LABEL: @abs_caller
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
return %rf, %rd : f32, f64
}
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
+ %af = complex.angle %f : complex<f32>
+ // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
+ %ad = complex.angle %d : complex<f64>
+ // CHECK: return %[[AF]], %[[AD]]
+ return %af, %ad : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
+ %cf = complex.cos %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
+ %cd = complex.cos %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf, %cd : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @exp_caller
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
// CHECK: return %[[EF]], %[[ED]]
return %ef, %ed : complex<f32>, complex<f64>
}
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
+ %lf = complex.log %f : complex<f32>
+ // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
+ %ld = complex.log %d : complex<f64>
+ // CHECK: return %[[LF]], %[[LD]]
+ return %lf, %ld : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
+ %cf2 = complex.conj %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
+ %cd2 = complex.conj %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf2, %cd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+ %pf = complex.pow %f, %f : complex<f32>
+ // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+ %pd = complex.pow %d, %d : complex<f64>
+ // CHECK: return %[[PF]], %[[PD]]
+ return %pf, %pd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
+ %sf2 = complex.sin %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
+ %sd2 = complex.sin %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf2, %sd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+ %sf = complex.sqrt %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+ %sd = complex.sqrt %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf, %sd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
+ %tf2 = complex.tan %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
+ %td2 = complex.tan %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf2, %td2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+ %tf = complex.tanh %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+ %td = complex.tanh %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf, %td : complex<f32>, complex<f64>
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
index 00bbd1c..96ad107 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
@@ -85,11 +85,10 @@ module attributes {
// CHECK: spirv.Load "StorageBuffer"
%val = memref.load %arg0[%idx0] : memref<2xi32>
// CHECK: spirv.CompositeInsert
- %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32>
+ %vec = vector.insert %val, %vec0[%idx0] : i32 into vector<2xi32>
// CHECK: spirv.VectorShuffle
%shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32>
- // CHECK: spirv.CompositeExtract
- %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
+ %res = vector.extract %shuffle[%idx0] : i32 from vector<4xi32>
// CHECK: spirv.AccessChain
// CHECK: spirv.Store "StorageBuffer"
memref.store %res, %arg1[%idx0]: memref<4xi32>
@@ -102,9 +101,9 @@ module attributes {
// CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32>
// CHECK: arith.constant
// CHECK: memref.load
- // CHECK: vector.insertelement
+ // CHECK: vector.insert
// CHECK: vector.shuffle
- // CHECK: vector.extractelement
+ // CHECK: vector.extract
// CHECK: memref.store
// CHECK: gpu.return
}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index fb14feb..eb9feaa 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -51,108 +51,6 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
// -----
-// CHECK-LABEL: @extract_element
-// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
-func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
- %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_cst
-// CHECK-SAME: %[[V:.*]]: vector<4xf32>
-// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
-func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
- %idx = arith.constant 1 : i32
- %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_index
-func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
- // CHECK: spirv.VectorExtractDynamic
- %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_size1_vector
-// CHECK-SAME:(%[[S:.+]]: f32,
-func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
- %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
- %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
- // CHECK: spirv.ReturnValue %[[S]]
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_0d_vector
-// CHECK-SAME: (%[[S:.+]]: f32)
-func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
- %bcast = vector.broadcast %arg0 : f32 to vector<f32>
- %0 = vector.extractelement %bcast[] : vector<f32>
- // CHECK: spirv.ReturnValue %[[S]]
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element
-// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
-func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
- %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_cst
-// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
-// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
-func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
- %idx = arith.constant 2 : i32
- %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_index
-func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
- // CHECK: spirv.VectorInsertDynamic
- %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_size1_vector
-// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
-func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
- %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
- // CHECK: spirv.ReturnValue %[[S]]
- return %0: vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_0d_vector
-// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
-func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
- %0 = vector.insertelement %scalar, %vector[] : vector<f32>
- // CHECK: spirv.ReturnValue %[[S]]
- return %0: vector<f32>
-}
-
-// -----
-
// CHECK-LABEL: @insert_size1_vector
// CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32>
// CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a..0c77c88 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===//
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK: spirv.func @float8_to_integer8
+ // CHECK-SAME: (%arg0: i8
+ // CHECK-SAME: %arg1: i8
+ // CHECK-SAME: %arg2: i8
+ // CHECK-SAME: %arg3: i8
+ // CHECK-SAME: %arg4: i8
+ // CHECK-SAME: %arg5: i8
+ // CHECK-SAME: %arg6: i8
+ // CHECK-SAME: %arg7: i8
+ // CHECK-SAME: %arg8: vector<4xi8>
+ // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+ // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+ // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+ // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+ // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+ // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+ // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+ // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+ // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+ // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+ // UNSUPPORTED_FLOAT-SAME: ) {
+
+ func.func @float8_to_integer8(
+ %arg0: f8E5M2, // CHECK-NOT: f8E5M2
+ %arg1: f8E4M3, // CHECK-NOT: f8E4M3
+ %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
+ %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
+ %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
+ %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
+ %arg6: f8E3M4, // CHECK-NOT: f8E3M4
+ %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
+ %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+ %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+ %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
+ ) {
+ // CHECK: spirv.Return
+ return
+ }
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir
new file mode 100644
index 0000000..983747b
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s
+
+module attributes {gpu.container_module} {
+ // CHECK-LABEL: spirv.module @{{.*}} GLSL450
+ gpu.module @kernels [#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>] {
+ // CHECK: spirv.func @load_kernel
+ // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+ gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
+ // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
+ %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+// Checks that the `-convert-gpu-to-spirv` pass selects the first
+// `spirv.target_env` from the `targets` array attribute attached to `gpu.module`.
+module attributes {gpu.container_module} {
+ // CHECK-LABEL: spirv.module @{{.*}} GLSL450
+ // CHECK-SAME: #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>
+ gpu.module @kernels [
+ #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>,
+ #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>,
+ #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>] {
+ // CHECK: spirv.func @load_kernel
+ // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+ gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
+ // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
+ %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
index b96dd37..c71d220 100644
--- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -10,16 +10,14 @@ gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate()
gpu.func @rotate() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- %offset = arith.constant 4 : i32
- %width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32
+ // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
- // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %{{.+}} = spirv.Constant true
- %result, %valid = gpu.rotate %val, %offset, %width : f32
+ %result, %valid = gpu.rotate %val, 4, 16 : f32
gpu.return
}
}
@@ -38,18 +36,16 @@ gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
gpu.func @rotate_width_less_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- %offset = arith.constant 4 : i32
- %width = arith.constant 8 : i32
%val = arith.constant 42.0 : f32
+ // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
// CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
// CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
- %result, %valid = gpu.rotate %val, %offset, %width : f32
+ %result, %valid = gpu.rotate %val, 4, 8 : f32
gpu.return
}
}
@@ -67,34 +63,10 @@ module attributes {
gpu.module @kernels {
gpu.func @rotate_with_bigger_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- %offset = arith.constant 4 : i32
- %width = arith.constant 32 : i32
%val = arith.constant 42.0 : f32
// expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
- %result, %valid = gpu.rotate %val, %offset, %width : f32
- gpu.return
- }
-}
-
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
- #spirv.resource_limits<subgroup_size = 16>>
-} {
-
-gpu.module @kernels {
- gpu.func @rotate_non_const_width(%width: i32) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- %offset = arith.constant 4 : i32
- %val = arith.constant 42.0 : f32
-
- // expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
- %result, %valid = gpu.rotate %val, %offset, %width : f32
+ %result, %valid = gpu.rotate %val, 4, 32 : f32
gpu.return
}
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
new file mode 100644
index 0000000..3e5f592
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK-LABEL: @fpclassify
+ func.func @fpclassify(%x: f32, %v: vector<4xf32>) {
+ // CHECK: spirv.IsFinite %{{.*}} : f32
+ %0 = math.isfinite %x : f32
+ // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32>
+ %1 = math.isfinite %v : vector<4xf32>
+
+ // CHECK: spirv.IsNan %{{.*}} : f32
+ %2 = math.isnan %x : f32
+ // CHECK: spirv.IsNan %{{.*}} : vector<4xf32>
+ %3 = math.isnan %v : vector<4xf32>
+
+ // CHECK: spirv.IsInf %{{.*}} : f32
+ %4 = math.isinf %x : f32
+ // CHECK: spirv.IsInf %{{.*}} : vector<4xf32>
+ %5 = math.isinf %v : vector<4xf32>
+
+ return
+ }
+
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir
new file mode 100644
index 0000000..e391a89
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+
+func.func @alloc() {
+ %alloc = memref.alloc() : memref<999xi32>
+ return
+}
+
+// CPP: module {
+// CPP-NEXT: emitc.include <"cstdlib">
+// CPP-LABEL: alloc()
+// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CPP-NEXT: return
+
+// NOCPP: module {
+// NOCPP-NEXT: emitc.include <"stdlib.h">
+// NOCPP-LABEL: alloc()
+// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// NOCPP-NEXT: return
+
+func.func @alloc_aligned() {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32>
+ return
+}
+
+// CPP-LABEL: alloc_aligned
+// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
+// CPP-NEXT: return
+
+// NOCPP-LABEL: alloc_aligned
+// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
+// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
+// NOCPP-NEXT: return
+
+func.func @allocating_multi() {
+ %alloc_5 = memref.alloc() : memref<7x999xi32>
+ return
+}
+
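+// Note: the 2-D memref is flattened assuming row-major layout, so the element
+// count below is 7 * 999 = 6993.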
+// CPP-LABEL: allocating_multi
+// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
+// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CPP-NEXT: return
+
+// NOCPP-LABEL: allocating_multi
+// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
+// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
+// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// NOCPP-NEXT: return
+
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8d720ce..580b09d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {
// -----
-// CHECK-LABEL: @stmatrix(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
-llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
- llvm.return
-}
-
-// -----
-
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index d54d003..5e20b5a 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -1,14 +1,14 @@
-// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-shard-to-mpi -canonicalize -split-input-file | FileCheck %s
// -----
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank
// CHECK-DAG: %[[v4:.*]] = arith.remsi
// CHECK-DAG: %[[v0:.*]] = arith.remsi
// CHECK-DAG: %[[v1:.*]] = arith.remsi
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@@ -17,7 +17,7 @@ func.func @process_multi_index() -> (index, index, index) {
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
// CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
- %0 = mesh.process_linear_index on @mesh0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[cast]] : index
return %0 : index
}
@@ -29,7 +29,7 @@ func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
// CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [0] : index, index
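// Note: the linear rank of [1, 0, 4] in shape 3x4x5 is 1*20 + 0*5 + 4 = 24 and
// the axis-0 stride is 4*5 = 20, hence the neighbors 24 - 20 = 4 and 24 + 20 = 44.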
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
@@ -41,7 +41,7 @@ func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
// CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [1] : index, index
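// Note: along axis 1 this rank sits at index 0, so there is no lower neighbor
// (down = -1); the upper neighbor is 24 + 5 = 29.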
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
@@ -53,20 +53,20 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
// CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [2] : index, index
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
// -----
-// CHECK: mesh.mesh @mesh0
+// CHECK: shard.grid @grid0
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@@ -74,7 +74,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[c24:.*]] = arith.constant 24 : index
- %0 = mesh.process_linear_index on @mesh0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[c24]] : index
return %0 : index
}
@@ -82,7 +82,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func.func @allreduce_tensor(
func.func @allreduce_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@@ -97,7 +97,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
// CHECK: return [[v2]] : tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
@@ -114,7 +114,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
// CHECK: return [[valloc]] : memref<3x4xf32>
return %0 : memref<3x4xf32>
}
@@ -131,14 +131,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
// CHECK: return [[valloc]] : memref<3x4xf64>
return %0 : memref<3x4xf64>
}
}
// -----
-mesh.mesh @mesh0(shape = 3x4x5)
+shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func @update_halo_1d_first
func.func @update_halo_1d_first(
// CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
@@ -155,14 +155,14 @@ func.func @update_halo_1d_first(
// CHECK: mpi.recv(
// CHECK-SAME: : memref<3x120x120xi8>, i32, i32
// CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
// CHECK: return [[res:%.*]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
- mesh.mesh @mesh0(shape = 4)
+ shard.grid @grid0(shape = 4)
// CHECK-LABEL: func @update_halo_1d_with_zero
func.func @update_halo_1d_with_zero (
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -179,7 +179,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
@@ -187,7 +187,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func @update_halo_3d
func.func @update_halo_3d(
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -236,7 +236,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
@@ -291,18 +291,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
// CHECK: return [[v4]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
}
// -----
-mesh.mesh @mesh0(shape = 2x2x4)
+shard.grid @grid0(shape = 2x2x4)
// CHECK-LABEL: func.func @return_sharding(
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
@@ -316,13 +316,13 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh
// CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<2x4xf32>, !shard.sharding
}
// CHECK-LABEL: func.func @return_sharding_halos(
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
@@ -336,13 +336,13 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m
// CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
// CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<6x8xf32>, !shard.sharding
}
// CHECK-LABEL: func.func @return_sharding_offs(
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
@@ -362,5 +362,5 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !me
// CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding
}
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
index 156bbfb..9729d2b 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
@@ -1,21 +1,21 @@
-// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- // CHECK: mesh.mesh @mesh0
- mesh.mesh @mesh0(shape = 3x4x5)
+ // CHECK: shard.grid @grid0
+ shard.grid @grid0(shape = 3x4x5)
- // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0
+ // Notice: comm_world_rank/linear index 24 is the multi-index [1, 0, 4] in @grid0
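+ // (row-major decomposition: 24 / (4*5) = 1, (24 / 5) % 4 = 0, 24 % 5 = 4)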
// all shards are equal
// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
func.func @shard_shape_equal() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -23,13 +23,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// last shard in last dim gets an extra element
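// (16 floordiv 5 = 3 with remainder 1, so the shard sizes are [3, 3, 3, 3, 4];
// this rank is device 4 along that axis and gets 4)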
// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
func.func @shard_shape_odd_1() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -37,11 +37,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// In the second dimension the shard sizes are now [3 4 4 4]
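// (15 floordiv 4 = 3 with remainder 3; this rank is device 0 along axis 1 and
// keeps size 3)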
// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
func.func @shard_shape_odd_2() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -49,11 +49,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// In the first dimension the shard sizes are now [3 4 4]
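// (11 floordiv 3 = 3 with remainder 2; this rank is device 1 along axis 0 and
// gets 4)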
// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
func.func @shard_shape_odd_3() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
// CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -61,14 +61,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// extract from sharded_dims_offsets
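// (each shard size is the difference of consecutive per-axis offsets: device
// [1, 0, 4] yields 4 - 1 = 3, 2 - 0 = 2, and 15 - 12 = 3)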
// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
- sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
// CHECK: [[vc2:%.*]] = arith.constant 2 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index fa7a91c..b6f2383 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
// CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
// CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// CHECK: scf.yield [[ARG0]]
tosa.yield %arg0 : tensor<f32>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 8c135d5..31e17fb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -274,73 +274,6 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3
// -----
//===----------------------------------------------------------------------===//
-// vector.extractelement
-//===----------------------------------------------------------------------===//
-
-func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 {
- %1 = vector.extractelement %arg0[] : vector<f32>
- return %1 : f32
-}
-// CHECK-LABEL: @extractelement_from_vec_0d_f32
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
-
-// -----
-
-func.func @extractelement_from_vec_1d_f32_idx_as_i32(%arg0: vector<16xf32>) -> f32 {
- %0 = arith.constant 15 : i32
- %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
- return %1 : f32
-}
-// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
-// CHECK: %[[C:.*]] = arith.constant 15 : i32
-// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<16xf32>
-// CHECK: return %[[X]] : f32
-
-// -----
-
-func.func @extractelement_from_vec_1d_f32_idx_as_i32_scalable(%arg0: vector<[16]xf32>) -> f32 {
- %0 = arith.constant 15 : i32
- %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32>
- return %1 : f32
-}
-// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32_scalable(
-// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
-// CHECK: %[[C:.*]] = arith.constant 15 : i32
-// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<[16]xf32>
-// CHECK: return %[[X]] : f32
-
-// -----
-func.func @extractelement_from_vec_1d_f32_idx_as_index(%arg0: vector<16xf32>) -> f32 {
- %0 = arith.constant 15 : index
- %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
- return %1 : f32
-}
-// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index(
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
-// CHECK: %[[C:.*]] = arith.constant 15 : index
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64
-// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<16xf32>
-// CHECK: return %[[X]] : f32
-
-// -----
-
-func.func @extractelement_from_vec_1d_f32_idx_as_index_scalable(%arg0: vector<[16]xf32>) -> f32 {
- %0 = arith.constant 15 : index
- %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32>
- return %1 : f32
-}
-// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index_scalable(
-// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
-// CHECK: %[[C:.*]] = arith.constant 15 : index
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64
-// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<[16]xf32>
-// CHECK: return %[[X]] : f32
-
-// -----
-
-//===----------------------------------------------------------------------===//
// vector.extract
//===----------------------------------------------------------------------===//
@@ -592,81 +525,6 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg :
// -----
//===----------------------------------------------------------------------===//
-// vector.insertelement
-//===----------------------------------------------------------------------===//
-
-func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {
- %1 = vector.insertelement %arg0, %arg1[] : vector<f32>
- return %1 : vector<f32>
-}
-// CHECK-LABEL: @insertelement_into_vec_0d_f32
-// CHECK-SAME: %[[A:.*]]: f32,
-// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
-// CHECK: vector<f32> to vector<1xf32>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
-
-// -----
-
-func.func @insertelement_into_vec_1d_f32_idx_as_i32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
- %0 = arith.constant 3 : i32
- %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
- return %1 : vector<4xf32>
-}
-// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32(
-// CHECK-SAME: %[[A:.*]]: f32,
-// CHECK-SAME: %[[B:.*]]: vector<4xf32>)
-// CHECK: %[[C:.*]] = arith.constant 3 : i32
-// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<4xf32>
-// CHECK: return %[[X]] : vector<4xf32>
-
-// -----
-
-func.func @insertelement_into_vec_1d_f32_idx_as_i32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
- %0 = arith.constant 3 : i32
- %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32>
- return %1 : vector<[4]xf32>
-}
-// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32_scalable(
-// CHECK-SAME: %[[A:.*]]: f32,
-// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>)
-// CHECK: %[[C:.*]] = arith.constant 3 : i32
-// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<[4]xf32>
-// CHECK: return %[[X]] : vector<[4]xf32>
-
-// -----
-
-func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
- %0 = arith.constant 3 : index
- %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32>
- return %1 : vector<4xf32>
-}
-// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index(
-// CHECK-SAME: %[[A:.*]]: f32,
-// CHECK-SAME: %[[B:.*]]: vector<4xf32>)
-// CHECK: %[[C:.*]] = arith.constant 3 : index
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64
-// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<4xf32>
-// CHECK: return %[[X]] : vector<4xf32>
-
-// -----
-
-func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
- %0 = arith.constant 3 : index
- %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32>
- return %1 : vector<[4]xf32>
-}
-// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(
-// CHECK-SAME: %[[A:.*]]: f32,
-// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>)
-// CHECK: %[[C:.*]] = arith.constant 3 : index
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64
-// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<[4]xf32>
-// CHECK: return %[[X]] : vector<[4]xf32>
-
-// -----
-
-//===----------------------------------------------------------------------===//
// vector.insert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f43a41a..8918f91 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -400,67 +400,6 @@ func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32>
// -----
-// CHECK-LABEL: @extract_element
-// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
-func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
- %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_cst
-// CHECK-SAME: %[[V:.*]]: vector<4xf32>
-// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
-func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
- %idx = arith.constant 1 : i32
- %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_index
-func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
- // CHECK: spirv.VectorExtractDynamic
- %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_size5_vector
-func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 {
- // CHECK: vector.extractelement
- %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_size1_vector
-// CHECK-SAME: (%[[S:.+]]: f32
-func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
- %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
- %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
- // CHECK: return %[[S]]
- return %0: f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_element_0d_vector
-// CHECK-SAME: (%[[S:.+]]: f32)
-func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
- %bcast = vector.broadcast %arg0 : f32 to vector<f32>
- %0 = vector.extractelement %bcast[] : vector<f32>
- // CHECK: return %[[S]]
- return %0: f32
-}
-
-// -----
-
// CHECK-LABEL: @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
// CHECK: spirv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]], %[[ARG]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
@@ -473,67 +412,6 @@ func.func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector
// -----
-// CHECK-LABEL: @insert_element
-// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
-func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
- %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_cst
-// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
-// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
-func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
- %idx = arith.constant 2 : i32
- %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_index
-func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
- // CHECK: spirv.VectorInsertDynamic
- %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
- return %0: vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_size5_vector
-func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> {
- // CHECK: vector.insertelement
- %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
- return %0 : vector<5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_size1_vector
-// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
-func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
- %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
- // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32>
- // CHECK: return %[[V]]
- return %0: vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @insert_element_0d_vector
-// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
-func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
- %0 = vector.insertelement %scalar, %vector[] : vector<f32>
- // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32>
- // CHECK: return %[[V]]
- return %0: vector<f32>
-}
-
-// -----
-
// CHECK-LABEL: @insert_strided_slice
// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
// CHECK: spirv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]], %[[PART]] : vector<4xf32>, vector<2xf32> -> vector<4xf32>
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39..5501ad4 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+ %c0 = arith.constant 0 : index
+ %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
+ // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+ // CHECK-SAME: : f32, memref<128x72xf32, 1>
+ amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+ : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
+func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
+ %c0 = arith.constant 0 : index
+ %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
+ // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
+ // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
+ amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
+ : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
+ func.return
+}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
deleted file mode 100644
index 6b55dd5..0000000
--- a/mlir/test/Dialect/Arith/mesh-spmdize.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh4x4(shape = 4x4)
-
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
-// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
-// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
- %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
- %ci = arith.constant 434 : i32
- return %sharding_annotated_1 : tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir
new file mode 100644
index 0000000..be89427
--- /dev/null
+++ b/mlir/test/Dialect/Arith/shard-partition.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid4x4(shape = 4x4)
+
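+// Splitting dim 0 of the 1024x1024 tensor across the 4 processes on grid
+// axis 0 leaves a 1024/4 = 256 row shard per process.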
+// CHECK-LABEL: func @test_partition_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_partition_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+ %ci = arith.constant 434 : i32
+ return %sharded_1 : tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
index 19eb340..762620d 100644
--- a/mlir/test/Dialect/Arith/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -1,54 +1,54 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
-mesh.mesh @mesh4x4(shape = 4x4)
+shard.grid @grid4x4(shape = 4x4)
// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
-// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: return [[vsharded_8]] : tensor<1024x1024xf32>
func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
- %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+ %res = linalg.add ins(%sharded_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %res : tensor<1024x1024xf32>
}
// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
- return %sharding_annotated_1 : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharded_1 = shard.shard %res to %sharding_1 : tensor<1024x1024xf32>
+ return %sharded_1 : tensor<1024x1024xf32>
}
diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir
new file mode 100644
index 0000000..1a74eaa
--- /dev/null
+++ b/mlir/test/Dialect/Async/canonicalize.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-NOT: async.execute
+
+func.func @empty_execute() {
+ %token = async.execute {
+ async.yield
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
new file mode 100644
index 0000000..e2ab876
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s
+
+"test.symbol_scope_isolated"() ({
+ // CHECK-LABEL: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+ // CHECK-NOT: copy
+ %f = arith.constant 1.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: memref.store %{{.*}}, %[[arg0]]
+ %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+ %1 = tensor.extract %0[%c1] : tensor<?xf32>
+ // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
+ return %0, %1 : tensor<?xf32>, f32
+ }
+
+ // CHECK-LABEL: func @call_func_with_non_tensor_return(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+ func.func @call_func_with_non_tensor_return(
+ %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
+ // CHECK-NOT: alloc
+ // CHECK-NOT: copy
+ // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
+ %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+ // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+ return %1, %0 : f32, tensor<?xf32>
+ }
+ "test.finish" () : () -> ()
+}) : () -> ()
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e290..2acd194 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
memref<?x?x16x32xi8> {
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
- %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]]
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
// -----
+// CHECK-LABEL: func @tensor_cast_to_buffer_layout_and_memspace
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+ memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+ %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+ %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+ return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
// Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index c67a0c1..029fa78 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
module attributes { } {
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 162ff06..35381da 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
// -----
func.func @rotate_mismatching_type(%arg0 : f32) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1)
return
}
// -----
func.func @rotate_unsupported_type(%arg0 : index) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
- %rotate, %valid = gpu.rotate %arg0, %offset, %width : index
+ %rotate, %valid = gpu.rotate %arg0, 4, 16 : index
return
}
@@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
- %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
+ %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32>
return
}
// -----
func.func @rotate_unsupported_width(%arg0 : f32) {
- %offset = arith.constant 4 : i32
- %width = arith.constant 15 : i32
- // expected-error@+1 {{op width must be a power of two}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}}
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset(%arg0 : f32) {
- %offset = arith.constant 16 : i32
- %width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 } : (f32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
- %offset = arith.constant -1 : i32
- %width = arith.constant 16 : i32
- // expected-error@+1 {{op offset must be in the range [0, 16)}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
- return
-}
-
-// -----
-
-func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
- %width = arith.constant 16 : i32
- // expected-error@+1 {{op offset is not a constant value}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
- return
-}
-
-// -----
-
-func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
- %offset = arith.constant 0 : i32
- // expected-error@+1 {{op width is not a constant value}}
- %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
+ // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}}
+ %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1)
return
}
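The hunks above track gpu.rotate's offset and width moving from SSA operands to verified integer attributes; a minimal well-formed use under the new syntax (hypothetical function, mirroring the form checked in ops.mlir below):

    func.func @rotate_ok(%v : f32) -> f32 {
      // width must be a power of two > 0; offset must lie in [0, width)
      %r, %valid = gpu.rotate %v, 4, 16 : f32
      return %r : f32
    }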
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 2aef80f..ee1fdfa 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -140,9 +140,8 @@ module attributes {gpu.container_module} {
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
- // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
- %rotate_width = arith.constant 16 : i32
- %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
+ // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32
+ %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32
"gpu.barrier"() : () -> ()
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7..5c5f7e8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x3xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x3xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x3xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [2]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x4xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x4xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x4xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [1]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
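A sketch of the composition these two tests pin down: chained linalg.broadcast ops fold into a single broadcast from the original input, with the inner op's dimension indices remapped through the outer op before taking the union.

    // tensor<2xf32> -[dims = [1]]-> tensor<2x3xf32> -[dims = [2]]-> tensor<2x3x4xf32>
    // folds to the single op:
    %b = linalg.broadcast ins(%input : tensor<2xf32>)
                          outs(%init2 : tensor<2x3x4xf32>) dimensions = [1, 2]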
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
%transpose = linalg.transpose
@@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
-
// -----
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
@@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
- into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
@@ -1889,31 +1936,84 @@ func.func @fold_cast_unpack_dynamic_tile_size(
// linalg.unpack + tensor.extract_slice
//===----------------------------------------------------------------------===//
-func.func @fold_extract_slice_into_unpack(
- %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
- into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
- [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
- return %extracted_slice : tensor<28x28x?xf32>
+ [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+ return %extracted_slice : tensor<28x28x10xf32>
}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME: %[[SIZE:.+]]: index
+// -----
+
+// Folding applies because the sliced size along dim 1 stays in [17, 32]: CeilDiv(17, 16) == 2 still matches the two source tiles.
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> {
+ %unpack = linalg.unpack %src
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+ return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME: [0, 0, 0] [28, 17, 15] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]
// -----
+// The foldable size range along dim 1 is [17, 32]; a slice of 16 gives CeilDiv(16, 16) == 1, so a whole source tile would be dropped.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> {
+ %unpack = linalg.unpack %src
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+ return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
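The three tests above delimit the guard for this fold; restating the arithmetic (tile size 16, two source tiles along dim 1):

    // dest dim 1 = 28: two tiles of 16, the last one 4 elements short.
    // slice of 17..32:  CeilDiv(size, 16) == 2 matches the tile count -> fold.
    // slice of 16:      CeilDiv(16, 16) == 1 drops a whole tile       -> no fold.
    // dynamic %size:    the tile count cannot be proven               -> no fold.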
+// -----
+
func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f..cc26fa4 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
- %empty = tensor.empty() : tensor<8x4x16x8xf32>
- %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
- %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
- return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a00c798..5f42938 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// -----
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?xf32> to tensor<1x16xf32>
+ return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?xf32> to tensor<16xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
+
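A compact restatement of the rewrite the CHECK lines above encode (shapes taken from the test): the leading unit dim is collapsed away, the pad runs at rank 1, and the unit dim is re-expanded afterwards.

    %c = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
    %p = tensor.pad %c low[1] high[0] {
    ^bb0(%i: index):
      tensor.yield %cst : f32
    } : tensor<?xf32> to tensor<16xf32>
    %e = tensor.expand_shape %p [[0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>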
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+ func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+ %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+ %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x1x61x1xf32>
+ return %1 : tensor<?x1x61x1xf32>
+ }
+}
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32>
// CHECK: }
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
- func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
- %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
- %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x1x61x1xf32>
- return %1 : tensor<?x1x61x1xf32>
- }
-}
-
// -----
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7..40bf4d1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}
// -----
+
func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+ %cst = arith.constant 0.0 : f32
+ // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+ inner_tiles = [8] into %output
+ : tensor<9xf32> -> tensor<3x8xf32>
+ return %0 : tensor<3x8xf32>
+}
+
+// -----
+
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+ // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
// -----
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
- return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+ // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+ %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+ return %0 : tensor<8x7x16x32xf32>
+}
+
+// -----
+
+func.func @unpack_with_artificial_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+ // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+ : tensor<3x8xf32> -> tensor<9xf32>
+ return %0 : tensor<9xf32>
}
// -----
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+ // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
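The updated diagnostics above all derive the expected packed type from the unpacked extents and the tiles; for the first new test, the arithmetic is:

    // tensor<9xf32> with inner_tiles = [8]: outer = CeilDiv(9, 8) = 2, so the
    // tight packed type is tensor<2x8xf32>; an outer extent of 3 would carry a
    // whole tile of artificial padding and is rejected.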
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
deleted file mode 100644
index 5297eeb..0000000
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --verify-each \
-// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-
-// CHECK-LABEL: func @matmul_shard_prallel_axis
-func.func @matmul_shard_prallel_axis(
- // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
- %arg0 : tensor<2x3xf32>,
- // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
- %arg1 : tensor<3x2xf32>,
- // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
- %out_dps: tensor<2x2xf32>
-) -> tensor<2x2xf32> {
- // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
- // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
- // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
- // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
- %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
-
- // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
- // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
- %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
- outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-
- // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
- // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
- // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
- %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
- %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
-
- // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
- return %res_sharded : tensor<2x2xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir
index ce12b29..aee9707 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/shard-partition.mlir
@@ -1,15 +1,15 @@
// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s
// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
#map_identity_1d = affine_map<(d0) -> (d0)>
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
-// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
-func.func @elementwise_static_1d_mesh_static_1d_tensor(
+// CHECK-LABEL: func @elementwise_static_1d_grid_static_1d_tensor
+func.func @elementwise_static_1d_grid_static_1d_tensor(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
%in1: tensor<2xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
@@ -18,13 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%dps_out: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8>
- %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
- %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
- %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
- %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
- %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in1_sharded1 = shard.shard %in1 to %sharding : tensor<2xi8>
+ %in1_sharded2 = shard.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %in2_sharded1 = shard.shard %in2 to %sharding : tensor<2xi8>
+ %in2_sharded2 = shard.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %dps_out_sharded1 = shard.shard %dps_out to %sharding : tensor<2xi8>
+ %dps_out_shared2 = shard.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: %[[RES:.*]] = linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
// CHECK-SAME: iterator_types = ["parallel"]}
@@ -39,18 +39,18 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
linalg.yield %res_scalar : i8
} -> tensor<2xi8>
- %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
- %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %res_sharded1 = shard.shard %res to %sharding : tensor<2xi8>
+ %res_shared2 = shard.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<1xi8>
return %res_shared2 : tensor<2xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_sharding(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
%in1: tensor<4x3xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
@@ -59,32 +59,32 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<1x8xi8> {
) -> tensor<4x8xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
- %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x3xi8>
+ %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<3x8xi8>
+ %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
+ %dps_out_shared1 = shard.shard %dps_out to %sharding : tensor<4x8xi8>
+ %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: %[[RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
// CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
// CHECK-SAME: -> tensor<1x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+ %res_shared1 = shard.shard %res to %sharding : tensor<4x8xi8>
+ %res_shared2 = shard.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[RES]] : tensor<1x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 3)
+shard.grid @grid_1d(shape = 3)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
@@ -93,19 +93,19 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
- %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x6xi8>
+ %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<6x8xi8>
+ %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+ %sharding3 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %dps_out_shared1 = shard.shard %dps_out to %sharding3 : tensor<4x8xi8>
+ %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
- // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+ // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
// CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
// CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
// CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
@@ -117,21 +117,21 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
// CHECK: }
// CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
// CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
- // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+ // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_MATMUL]] on @grid_1d grid_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+ %res_shared1 = shard.shard %res to %sharding3 : tensor<4x8xi8>
+ %res_shared2 = shard.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
@@ -140,25 +140,25 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
- %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
- %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
- // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
- // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
- %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
+ %sharding1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+ %in1_replicated1 = shard.shard %in1 to %sharding1 : tensor<4x6xi8>
+ %in1_replicated2 = shard.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
+ // CHECK: %[[ALL_SLICE1:.*]] = shard.all_slice %[[IN2]] on @grid_1d grid_axes = [0] slice_axis = 1
+ %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x8xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
+ // CHECK: %[[ALL_SLICE2:.*]] = shard.all_slice %[[DPS_OUT]] on @grid_1d grid_axes = [0] slice_axis = 1
+ %dps_out_replicated = shard.shard %dps_out to %sharding1 : tensor<4x8xi8>
+ %dps_out_sharded = shard.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
// CHECK-SAME: -> tensor<4x2xi8>
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
- %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
- %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[MATMUL_RES]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
+ %res_sharded = shard.shard %res to %sharding2 : tensor<4x8xi8>
+ %res_replicated = shard.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
return %res_replicated : tensor<4x8xi8>
}
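The renames in this hunk are mechanical: mesh.mesh becomes shard.grid, mesh.sharding and !mesh.sharding become shard.sharding and !shard.sharding, mesh.shard becomes shard.shard, and the mesh_axes/mesh_shape spellings become grid_axes/grid_shape. One representative line, before and after:

    %s = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
    %s = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding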
diff --git a/mlir/test/Dialect/Linalg/sharding-propagation.mlir b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
new file mode 100644
index 0000000..e0ecefc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt \
+// RUN: --verify-each \
+// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+ %arg0 : tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+ %arg1 : tensor<3x2xf32>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+ // CHECK: %[[SIN1_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = shard.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
+ // CHECK: %[[SIN1_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = shard.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[SIN2_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+ // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = shard.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = shard.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
+
+ // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // CHECK: %[[SRES_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = shard.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
+ // CHECK: %[[SRES_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = shard.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
+ %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+ %res_sharded = shard.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
+
+ // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8..9e7681d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
- -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+ -> tensor<265x12x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
- // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+ // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
- // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+ // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
- // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
- : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
- return %0 : tensor<265x16x16x1xf32>
+ : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+ return %0 : tensor<265x12x16x1xf32>
}
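The corrected dest shape follows from the tile arithmetic (restating the hunk above):

    // outer dim 0: CeilDiv(4225, 16) = 265, padding 4225 up to 4240.
    // outer dim 1: CeilDiv(12, 1) = 12, so the old dest of 265x16x16x1 carried
    // four whole tiles of artificial padding; 265x12x16x1 is the tight shape.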
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b6..981f5dc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
// CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
- // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+ // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+ // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+ // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+ return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+ // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
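A sketch of the padding arithmetic behind the conv tests above, assuming the usual size relation in = (out - 1) * stride + (k - 1) * dilation + 1 with k = 3, and the output extent rounded up to the requested multiple of 16:

    // pad_conv:          out 14 -> 16; in = 15*1 + 2*1 + 1 = 18; 18 - 16 = high 2.
    // pad_conv_strided:  out 14 -> 16; in = 15*3 + 2*1 + 1 = 48; 48 - 42 = high 6.
    // pad_conv_dilated:  out 14 -> 16; in = 15*1 + 2*2 + 1 = 20; 20 - 18 = high 2.
    // channel dim: 4 is padded to the multiple of 16, i.e. high 12 everywhere.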
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed..f741876 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
// CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
- // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
// CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
// CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
- // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
//
// CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
// CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
- // CHECK: } -> tensor<8x14x13xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+ // CHECK: } -> tensor<8x14x12xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
//
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
^bb0(%in: f32, %out: f32):
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee892..d7722ea 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 98e8f50..d41d861 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32>
func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
// CHECK: %[[C01:.*]] = arith.constant 0
// CHECK: %[[C02:.*]] = arith.constant 0
-// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST14:.*]] = arith.constant 1
-// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32>
// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
-// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
deleted file mode 100644
index aff07bb..0000000
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 2x4)
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes
-func.func @all_reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_reduce
- %0 = mesh.all_reduce %arg0 on @mesh0
- mesh_axes = []
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
-func.func @all_reduce_empty_mesh_axes_different_return_type(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.all_reduce
- %0 = mesh.all_reduce %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
- mesh_axes = []
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_reduce_default_reduction
-func.func @all_reduce_default_reduction(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
- %0 = mesh.all_reduce %arg0 on @mesh0
- mesh_axes = [0]
-// CHECK-NOT: reduction
- reduction = sum
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_to_all_empty_mesh_axes
-func.func @all_to_all_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
- %arg0 : tensor<8xf32>) -> tensor<8xf32> {
-// CHECK-NOT: mesh.all_to_all
- %0 = mesh.all_to_all %arg0 on @mesh0
- mesh_axes = []
- split_axis = 0
- concat_axis = 0
- : tensor<8xf32> -> tensor<8xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<8xf32>
-}
-
-// CHECK-LABEL: func @all_gather_empty_mesh_axes
-func.func @all_gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_gather
- %0 = mesh.all_gather %arg0 on @mesh0
- mesh_axes = []
- gather_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_slice_empty_mesh_axes
-func.func @all_slice_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_slice
- %0 = mesh.all_slice %arg0 on @mesh0
- mesh_axes = []
- slice_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @broadcast_empty_mesh_axes
-func.func @broadcast_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.broadcast
- %0 = mesh.broadcast %arg0 on @mesh0
- mesh_axes = []
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @gather_empty_mesh_axes
-func.func @gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.gather
- %0 = mesh.gather %arg0 on @mesh0
- mesh_axes = []
- gather_axis = 0
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @receive_empty_mesh_axes
-func.func @receive_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.recv
- %0 = mesh.recv %arg0 on @mesh0
- mesh_axes = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_empty_mesh_axes
-func.func @reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce
- %0 = mesh.reduce %arg0 on @mesh0
- mesh_axes = []
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
-func.func @reduce_scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce_scatter
- %0 = mesh.reduce_scatter %arg0 on @mesh0
- mesh_axes = []
- scatter_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
-func.func @reduce_scatter_empty_mesh_axes_different_return_type(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.reduce_scatter
- %0 = mesh.reduce_scatter %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
- mesh_axes = []
- scatter_axis = 0
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @reduce_scatter_default_reduction
-func.func @reduce_scatter_default_reduction(
- %arg0 : tensor<4xf32>) -> tensor<2xf64> {
- %0 = mesh.reduce_scatter %arg0 on @mesh0
- mesh_axes = [0]
-// CHECK-NOT: reduction
- reduction = sum
- scatter_axis = 0
- : tensor<4xf32> -> tensor<2xf64>
- return %0 : tensor<2xf64>
-}
-
-// CHECK-LABEL: func @scatter_empty_mesh_axes
-func.func @scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.scatter
- %0 = mesh.scatter %arg0 on @mesh0
- mesh_axes = []
- scatter_axis = 0
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @send_empty_mesh_axes
-func.func @send_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.send
- %0 = mesh.send %arg0 on @mesh0
- mesh_axes = []
- destination = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_halo_sizes
-func.func @test_halo_sizes() -> !mesh.sharding {
- %c2_i64 = arith.constant 2 : i64
- // CHECK: mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
- return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_shard_offs
-func.func @test_shard_offs() -> !mesh.sharding {
- %c2_i64 = arith.constant 2 : i64
- // CHECK: mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
- return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops
-func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
- %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
- // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
- return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops_diff
-func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
- %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
- %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
- %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
- // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
- return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
deleted file mode 100644
index 369f316d..0000000
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-mesh.mesh @mesh1(shape = 2x3)
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding
-func.func @mesh_shape_op_folding() -> (index, index) {
- // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
- // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
- %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
- // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
-func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
- // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
- // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
- %0:2 = mesh.mesh_shape @mesh1 : index, index
- // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
- return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 6ab711b..0000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
- func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
- %c1_i32 = arith.constant 1 : i32
- // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
- %0 = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[v1:%.*]] = linalg.fill ins
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
- %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
- // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
- %3 = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
- // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
- // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
- // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
- %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
- : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
- ^bb0(%in: i32, %in_2: i32, %out: i32):
- %9 = arith.addi %in, %in_2 : i32
- linalg.yield %9 : i32
- } -> tensor<6x6xi32>
- %c0_i32 = arith.constant 0 : i32
- %6 = tensor.empty() : tensor<i32>
- %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
- // CHECK: [[vreduced:%.*]] = linalg.reduce ins
- // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding
- // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
- %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
- (%in: i32, %init: i32) {
- %9 = arith.addi %in, %init : i32
- linalg.yield %9 : i32
- }
- // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
- %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
- %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
- return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
- }
-}
diff --git a/mlir/test/Dialect/Mesh/inlining.mlir b/mlir/test/Dialect/Mesh/inlining.mlir
deleted file mode 100644
index c41a709..0000000
--- a/mlir/test/Dialect/Mesh/inlining.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt -inline %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-
-func.func private @mesh_to_inline() -> (index, index) {
- %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
- return %0#0, %0#1 : index, index
-}
-// CHECK-LABEL: func.func @main
-func.func @main() -> (index, index) {
- // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index
- %0:2 = func.call @mesh_to_inline() : () -> (index, index)
- // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
- return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
deleted file mode 100644
index e23cfd7..0000000
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
-
-mesh.mesh @mesh2d(shape = ?x?)
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh
-func.func @multi_index_2d_mesh() -> (index, index) {
- // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
- %0:2 = mesh.process_multi_index on @mesh2d : index, index
- // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
-func.func @multi_index_2d_mesh_single_inner_axis() -> index {
- // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
- %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
- // CHECK: return %[[MULTI_IDX]]#0 : index
- return %0 : index
-}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
deleted file mode 100644
index 5e62c92..0000000
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ /dev/null
@@ -1,168 +0,0 @@
-// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-mesh.mesh @mesh_1d_dynamic(shape = ?)
-
-// CHECK-LABEL: func @same_source_and_target_sharding
-func.func @same_source_and_target_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
- %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
- // CHECK: return %[[ARG]]
- return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @identical_source_and_target_sharding
-func.func @identical_source_and_target_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
- %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
- %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
- // CHECK: return %[[ARG]]
- return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis
-func.func @split_replicated_tensor_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
- %arg0: tensor<3x14xf32>
-) -> tensor<3x14xf32> {
- // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
- // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
- // CHECK: return %[[RESULT]] : tensor<3x14xf32>
- return %1 : tensor<3x14xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
-func.func @split_replicated_tensor_axis_dynamic(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
- %arg0: tensor<?x3x?xf32>
-) -> tensor<?x3x?xf32> {
- // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
- // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
- return %1 : tensor<?x3x?xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
- // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis_dynamic_mesh
-func.func @move_split_axis_dynamic_mesh(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
- // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
- // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_dynamic_axis
-func.func @move_split_dynamic_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
- %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
- // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
- // CHECK: return %[[RES]] : tensor<?x14xf32>
- return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis
-func.func @unshard_static_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_last_axis
-func.func @unshard_static_last_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_dynamic_axis
-func.func @unshard_dynamic_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
- %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
- return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
-func.func @unshard_static_axis_on_dynamic_mesh_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
- // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
deleted file mode 100644
index 0881d994..0000000
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ /dev/null
@@ -1,301 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-mesh.mesh @mesh_1d(shape = ?)
-mesh.mesh @mesh_2d(shape = 2x4)
-mesh.mesh @mesh_3d(shape = ?x?x?)
-
-// CHECK-LABEL: func.func @element_wise_empty_sharding_info
-func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: tosa.sigmoid
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: return
- return %0 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_def
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_use
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_output
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_input
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @arrow_structure
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
- %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
- %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]] : tensor<8x16xf32>
- %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
- %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V6]], %[[V8]]
- return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
- %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
- // CHECK-NEXT: return [[vsharded_5]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
- %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
- %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32>
- // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
- // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- // CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
- %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK: return [[vsharded_6]]
- return %0 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
- %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %2 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @resolve_conflicting_annotations
-func.func @resolve_conflicting_annotations(
- // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
- %arg0: tensor<2x3xf32>,
- // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
- %arg1: tensor<3x2xf32>,
- // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
- %out_dps: tensor<2x2xf32>
-// CHECK-SAME: ) -> tensor<2x2xf32> {
-) -> tensor<2x2xf32> {
- // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
- // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
- // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
- // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
- %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
- // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
- // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
- %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
- outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
- %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
- %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32>
- // CHECK: return %[[RES]] : tensor<2x2xf32>
- return %res_sharded : tensor<2x2xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(a)
-// The sharding propagation results in unnecessary reshards;
-// an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32>
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
- // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
- // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
- %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
- %sharded2 = mesh.shard %arg2 to %sharding : tensor<2x32x8xf32>
- // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK: [[v2:%.*]] = tosa.matmul
- %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
- %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
- // CHECK: return [[vsharded_12]]
- return %4 : tensor<2x4x8xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(b)
-// The sharding propagation results in unnecessary reshards;
-// an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
- %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
- %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
- %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
- %arg1_s = mesh.shard %arg1 to %s1 : tensor<2x8x32xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
- %2 = mesh.shard %1 to %s0 : tensor<2x4x32xf32>
- // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[v1:%.*]] = tosa.sigmoid
- // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
- %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
- %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding
- // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
- %arg2_s = mesh.shard %arg2 to %s2 : tensor<2x32x8xf32>
- // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK: [[v2:%.*]] = tosa.matmul
- %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to %s0 : tensor<2x4x8xf32>
- // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
- %6 = mesh.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32>
- // CHECK: return [[vsharded_14]]
- return %6 : tensor<2x4x8xf32>
-}
-
-// CHECK-LABEL: func.func @elementwise_duplicated_chain
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]] : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
- %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V5]]
- return %2 : tensor<8x16xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
deleted file mode 100644
index 701898c..0000000
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ /dev/null
@@ -1,317 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-
-// CHECK-LABEL: func @return_sharding
-func.func @return_sharding(
- // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
- %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
-) -> (tensor<2xf32>, !mesh.sharding) {
- %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
- %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
- // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
- return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
-}
-
-// CHECK-LABEL: func @full_replication
-func.func @full_replication(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[ARG]] : tensor<2xi8>
- return %1 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @sharding_triplet
-func.func @sharding_triplet(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
- %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> tensor<2xf32> {
-) -> tensor<2xf32> {
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
- %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
- %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32>
- %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
- return %sharding_annotated_1 : tensor<2xf32>
-}
-
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
- %arg0: tensor<2x2xi8>
-// CHECK-SAME: -> tensor<2x1xi8> {
-) -> tensor<2x2xi8> {
- // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
- // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
- return %1 : tensor<2x2xi8>
-}
-
-// CHECK-LABEL: func @non_tensor_value
-func.func @non_tensor_value(
- // CHECK-SAME: %[[ARG:.*]]: i8
- %arg0: i8
-// CHECK-SAME: -> i8 {
-) -> i8 {
- // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
- %0 = arith.addi %arg0, %arg0 : i8
- // CHECK: return %[[RES]] : i8
- return %0 : i8
-}
-
-// CHECK-LABEL: func @unary_elementwise
-func.func @unary_elementwise(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<1xi8>
- return %4 : tensor<2xi8>
-}
-
-// full replication -> shard axis -> abs -> shard axis -> full replication
-// CHECK-LABEL: func @unary_elementwise_with_resharding
-func.func @unary_elementwise_with_resharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
- // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<2xi8>
- return %4 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @binary_elementwise
-func.func @binary_elementwise(
- // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
- %arg0: tensor<2xi8>,
- // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
- %arg1: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8>
- %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
- %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8>
- %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
- %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
- %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8>
- %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<1xi8>
- return %res : tensor<2xi8>
-}
-
-// reshard
-// abs
-// reshard
-// abs
-// reshard
-// CHECK-LABEL: func @multiple_chained_ops
-func.func @multiple_chained_ops(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
- %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %6 = mesh.shard %5 to %s6 : tensor<2xi8>
- %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RESHARD3]] : tensor<1xi8>
- return %7 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @incomplete_sharding
-func.func @incomplete_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
- %arg0: tensor<8x16xf32>
-// CHECK-SAME: -> tensor<4x16xf32> {
-) -> tensor<8x16xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
- // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %2 = mesh.shard %1 to %s2 : tensor<8x16xf32>
- // CHECK: return %[[RES]] : tensor<4x16xf32>
- return %2 : tensor<8x16xf32>
-}
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @ew_chain_with_halo
-func.func @ew_chain_with_halo(
- // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
- %arg0: tensor<8x16xf32>,
- // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
- %arg1: tensor<1xf32>,
- // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
- %arg2: tensor<1xf32>)
- // CHECK-SAME: -> tensor<5x16xf32>
- -> tensor<8x16xf32> {
- %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32>
- // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
- %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32>
- %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
- %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
- %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
- %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding
- %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
- %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
- %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
- %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
- return %sharding_annotated_6 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func @test_shard_update_halo
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
-func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
- // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
- // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
- %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
- %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
- %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
- // CHECK: return %[[UH]] : tensor<304x1200xi64>
- return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_shard_update_halo2d
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
-func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
- // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
- // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
- %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
- %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
- %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
- // CHECK: return %[[UH]] : tensor<303x307xi64>
- return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh(shape = 2)
-// CHECK-LABEL: func.func @test_reduce_0d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
- %4 = tensor.empty() : tensor<i32>
- %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
- %sharded_out = mesh.shard %4 to %sharding_out : tensor<i32>
- %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
- // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
- %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
- (%in: i32, %init: i32) {
- %6 = arith.addi %in, %init : i32
- linalg.yield %6 : i32
- }
- // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor<i32> -> tensor<i32>
- %sharded_red = mesh.shard %reduced to %sharding_out : tensor<i32>
- %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
- // CHECK: return %[[all_reduce]] : tensor<i32>
- return %sharded_ret : tensor<i32>
-}
-
-// CHECK-LABEL: func.func @test_reduce_1d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
- %4 = tensor.empty() : tensor<6xi32>
- %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32>
- %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
- // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
- %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
- (%in: i32, %init: i32) {
- %6 = arith.addi %in, %init : i32
- linalg.yield %6 : i32
- }
- // CHECK-NOT: mesh.all_reduce
- %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32>
- %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
- // CHECK: return %[[reduced]] : tensor<3xi32>
- return %sharded_ret : tensor<6xi32>
-}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4c50ed3..8c846cd 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
// CHECK-NEXT: (%[[XVAL:.*]]: i1):
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1
// CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1)
- // }
+ // CHECK-NEXT: }
omp.atomic.update %xBool : memref<i1> {
^bb0(%xval: i1):
%newval = llvm.icmp "eq" %xval, %exprBool : i1
@@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
omp.yield(%newval : i32)
}
+ // CHECK: omp.atomic.update %[[X]] : memref<i32> {
+ // CHECK-NEXT: (%[[XVAL:.*]]: i32):
+ // CHECK-NEXT: omp.yield(%{{.+}} : i32)
+ // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>}
+ omp.atomic.update %x : memref<i32> {
+ ^bb0(%xval:i32):
+ omp.yield(%const:i32)
+ } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>}
return
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 12d30e17..308cf150 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() {
// -----
-// CHECK-LABEL: func @execute_region_elim
-func.func @execute_region_elim() {
+// CHECK-LABEL: func @execute_region_inline
+func.func @execute_region_inline() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
@@ -1461,8 +1461,30 @@ func.func @execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim
-func.func @func_execute_region_elim() {
+// CHECK-LABEL: func @execute_region_no_inline
+func.func @execute_region_no_inline() {
+ affine.for %i = 0 to 100 {
+ "test.foo"() : () -> ()
+ %v = scf.execute_region -> i64 no_inline {
+ %x = "test.val"() : () -> i64
+ scf.yield %x : i64
+ }
+ "test.bar"(%v) : (i64) -> ()
+ }
+ return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: scf.execute_region
+// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
+// CHECK-NEXT: scf.yield %[[VAL]] : i64
+// CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_inline
+func.func @func_execute_region_inline() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
@@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim_multi_yield
-func.func @func_execute_region_elim_multi_yield() {
+// CHECK-LABEL: func @func_execute_region_inline_multi_yield
+func.func @func_execute_region_inline_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d6c3464..58b8288 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto
// -----
//===----------------------------------------------------------------------===//
+// spirv.IsFinite
+//===----------------------------------------------------------------------===//
+
+func.func @isfinite_scalar(%arg0: f32) -> i1 {
+ // CHECK: spirv.IsFinite {{.*}} : f32
+ %0 = spirv.IsFinite %arg0 : f32
+ return %0 : i1
+}
+
+func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+ // CHECK: spirv.IsFinite {{.*}} : vector<2xf32>
+ %0 = spirv.IsFinite %arg0 : vector<2xf32>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.IsInf
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a654..6d321af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----
// expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07..f3a3218 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
} // end spirv.module
+
+// -----
+
+module {
+ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} {
+ // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ // CHECK: spirv.func @read_image
+ spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ %cst0_i32 = spirv.Constant 0 : i32
+ // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]]
+ %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+ %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32>
+ %cst0_i32_0 = spirv.Constant 0 : i32
+ %cst0_i32_1 = spirv.Constant 0 : i32
+ %cst1_i32 = spirv.Constant 1 : i32
+ %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ spirv.Store "StorageBuffer" %4, %3 : f32
+ spirv.Return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b23766..8d7f3da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes {
// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
// implicitly by v1.5.
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>>
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
index 4f54607a..bc91121 100644
--- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
@@ -1,43 +1,43 @@
-// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
-mesh.mesh @mesh_1d(shape = ?)
+shard.grid @grid_1d(shape = ?)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
-func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid
+func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid(
// CHECK: %[[ARG:.*]]: tensor<?xf16>
%arg0: tensor<?xf16>
// CHECK-SAME: -> tensor<?xf16> {
) -> tensor<?xf16> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+ // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
// CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor<?xf16>
- // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+ // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
// CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
// CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+ // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
// CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
- %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+ %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
// CHECK: return %[[RESULT]] : tensor<?xf16>
return %0 : tensor<?xf16>
}
// -----
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
-func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid
+func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid(
// CHECK: %[[ARG:.*]]: tensor<2xf16>
%arg0: tensor<2xf16>
// CHECK-SAME: -> tensor<1xf16> {
) -> tensor<1xf16> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
- %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+ %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
// CHECK: return %[[RESULT]] : tensor<1xf16>
return %0 : tensor<1xf16>
}
@@ -46,18 +46,18 @@ func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-mesh.mesh @mesh_4d(shape = ?x?x?x?)
+shard.grid @grid_4d(shape = ?x?x?x?)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
-func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid
+func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
// CHECK: %[[ARG:.*]]: tensor<?x?xf16>
%arg0 : tensor<?x?xf16>
// CHECK-SAME: -> tensor<?x?xf16> {
) -> tensor<?x?xf16> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
- // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+ // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
+ // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
// CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
// CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
@@ -68,7 +68,7 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
// CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
- %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+ %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
return %0 : tensor<?x?xf16>
}
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
index 4223d01..8894c4a 100644
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
@@ -2,17 +2,17 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- // CHECK-COUNT-2: mesh.shard
- %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
- %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ // CHECK-COUNT-2: shard.shard
+ %sharded = shard.shard %0 to %sharding : tensor<6x6xi32>
+ %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharded : tensor<6x6xi32>) -> tensor<6x6xi32>
// CHECK: tensor.empty()
- // CHECK-NOT: mesh.shard @
+ // CHECK-NOT: shard.shard @
%2 = tensor.empty() : tensor<6x6xi32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir
new file mode 100644
index 0000000..ed40dfb
--- /dev/null
+++ b/mlir/test/Dialect/Shard/canonicalization.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
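+// Canonicalization of shard-dialect collectives: an op with empty grid_axes is
+// a no-op and folds to its operand unless the result type differs; default
+// attributes such as reduction = sum and an empty grid_axes list are dropped
+// from the printed form.
+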
+shard.grid @grid0(shape = 2x4)
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes
+func.func @all_reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_reduce
+ %0 = shard.all_reduce %arg0 on @grid0
+ grid_axes = []
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes_different_return_type
+func.func @all_reduce_empty_grid_axes_different_return_type(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.all_reduce
+ %0 = shard.all_reduce %arg0 on @grid0
+// CHECK-NOT: grid_axes
+ grid_axes = []
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ %0 = shard.all_reduce %arg0 on @grid0
+ grid_axes = [0]
+// CHECK-NOT: reduction
+ reduction = sum
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_to_all_empty_grid_axes
+func.func @all_to_all_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
+ %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: shard.all_to_all
+ %0 = shard.all_to_all %arg0 on @grid0
+ grid_axes = []
+ split_axis = 0
+ concat_axis = 0
+ : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @all_gather_empty_grid_axes
+func.func @all_gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_gather
+ %0 = shard.all_gather %arg0 on @grid0
+ grid_axes = []
+ gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_slice_empty_grid_axes
+func.func @all_slice_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_slice
+ %0 = shard.all_slice %arg0 on @grid0
+ grid_axes = []
+ slice_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_empty_grid_axes
+func.func @broadcast_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.broadcast
+ %0 = shard.broadcast %arg0 on @grid0
+ grid_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_grid_axes
+func.func @gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.gather
+ %0 = shard.gather %arg0 on @grid0
+ grid_axes = []
+ gather_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_grid_axes
+func.func @receive_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.recv
+ %0 = shard.recv %arg0 on @grid0
+ grid_axes = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_grid_axes
+func.func @reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce
+ %0 = shard.reduce %arg0 on @grid0
+ grid_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes
+func.func @reduce_scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce_scatter
+ %0 = shard.reduce_scatter %arg0 on @grid0
+ grid_axes = []
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes_different_return_type
+func.func @reduce_scatter_empty_grid_axes_different_return_type(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.reduce_scatter
+ %0 = shard.reduce_scatter %arg0 on @grid0
+// CHECK-NOT: grid_axes
+ grid_axes = []
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+ %0 = shard.reduce_scatter %arg0 on @grid0
+ grid_axes = [0]
+// CHECK-NOT: reduction
+ reduction = sum
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// CHECK-LABEL: func @scatter_empty_grid_axes
+func.func @scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.scatter
+ %0 = shard.scatter %arg0 on @grid0
+ grid_axes = []
+ scatter_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_grid_axes
+func.func @send_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.send
+ %0 = shard.send %arg0 on @grid0
+ grid_axes = []
+ destination = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+shard.grid @grid4x4(shape = 4x4)
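+// The next two tests fold constant operands of shard.sharding into the static
+// halo_sizes / sharded_dims_offsets attribute lists.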
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !shard.sharding {
+ %c2_i64 = arith.constant 2 : i64
+ // CHECK: shard.sharding @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 2, 22] : !shard.sharding
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !shard.sharding
+ return %sharding : !shard.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !shard.sharding {
+ %c2_i64 = arith.constant 2 : i64
+ // CHECK: shard.sharding @grid4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !shard.sharding
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !shard.sharding
+ return %sharding : !shard.sharding
+}
+
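+// Equivalent shard.sharding/shard.shard pairs on the same value are merged into
+// a single chain, while differing shardings (see the test after this one) stay
+// distinct.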
+// CHECK-LABEL: func @test_duplicate_shardops
+func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharded]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops_diff
+func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_0:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+ %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
+ %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharded_1]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
new file mode 100644
index 0000000..5a0f35b
--- /dev/null
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+shard.grid @grid1(shape = 2x3)
+
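+// grid_shape folds to arith.constant for every axis with a static size; axes
+// with a dynamic size keep the grid_shape op.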
+// CHECK-LABEL: func.func @grid_shape_op_folding
+func.func @grid_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index
+ %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid
+func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) {
+ // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+ %0:2 = shard.grid_shape @grid1 : index, index
+ // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
index dd2eee2..0d8d997 100644
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
@@ -2,25 +2,25 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
- %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+ // CHECK-COUNT-3: shard.sharding @grid split_axes = {{\[\[0}}]]
+ %sharding_row = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %annotated_row = shard.shard %0 to %sharding_row : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
%2 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+ // CHECK-COUNT-4: shard.sharding @grid split_axes = {{\[\[1}}]]
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%9 = arith.addi %in, %in_2 : i32
linalg.yield %9 : i32
} -> tensor<6x6xi32>
- %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
- %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+ %sharding_col = shard.sharding @grid split_axes = [[1]] : !shard.sharding
+ %annotated_col = shard.shard %3 to %sharding_col : tensor<6x6xi32>
// CHECK: return
return %annotated_col : tensor<6x6xi32>
}
diff --git a/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
new file mode 100644
index 0000000..3cda9ea
--- /dev/null
+++ b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
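+ // With traversal=forward, shardings flow only from producers to users: the
+ // annotation on the fill result below is propagated to its consumers.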
+ func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+ %c1_i32 = arith.constant 1 : i32
+ // CHECK: [[v0:%.*]] = tensor.empty() : tensor<6x6xi32>
+ %0 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[v1:%.*]] = linalg.fill ins
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+ %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %1 to %sharding : tensor<6x6xi32>
+ // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+ %3 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ // CHECK-SAME: ins([[vsharded_3]], [[vsharded_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharded_5]] : tensor<6x6xi32>) {
+ // CHECK: [[vsharding_6:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharded_7:%.*]] = shard.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharded, %sharded
+ : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+ ^bb0(%in: i32, %in_2: i32, %out: i32):
+ %9 = arith.addi %in, %in_2 : i32
+ linalg.yield %9 : i32
+ } -> tensor<6x6xi32>
+ %c0_i32 = arith.constant 0 : i32
+ %6 = tensor.empty() : tensor<i32>
+ %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+ // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+ // CHECK: [[vsharding_12:%.*]] = shard.sharding @grid split_axes = [] : !shard.sharding
+ // CHECK: [[vsharded_13:%.*]] = shard.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+ %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %9 = arith.addi %in, %init : i32
+ linalg.yield %9 : i32
+ }
+ // CHECK: [[vsharding_14:%.*]] = shard.sharding @grid split_axes = {{\[\[}}]] : !shard.sharding
+ %sharding_0 = shard.sharding @grid split_axes = [[]] : !shard.sharding
+ // CHECK: [[vsharded_15:%.*]] = shard.shard [[vsharded_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+ %sharded_1 = shard.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+ return %sharded, %4, %sharded_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+ }
+}
diff --git a/mlir/test/Dialect/Shard/inlining.mlir b/mlir/test/Dialect/Shard/inlining.mlir
new file mode 100644
index 0000000..ce664b3
--- /dev/null
+++ b/mlir/test/Dialect/Shard/inlining.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -inline %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+
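+// Checks that shard ops are legal to inline: the private callee is inlined
+// into @main with the grid_shape op intact.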
+func.func private @grid_to_inline() -> (index, index) {
+ %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+ return %0#0, %0#1 : index, index
+}
+// CHECK-LABEL: func.func @main
+func.func @main() -> (index, index) {
+ // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = shard.grid_shape @grid0 axes = [2, 1] : index
+ %0:2 = func.call @grid_to_inline() : () -> (index, index)
+ // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir
index 2656332..6acac97 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Shard/invalid.mlir
@@ -1,55 +1,55 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
-// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
-mesh.mesh @mesh0(shape = [])
+// expected-error@+1 {{rank of grid is expected to be a positive integer}}
+shard.grid @grid0(shape = [])
// -----
-// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.mesh @mesh0(shape = -1)
+// expected-error@+1 {{custom op 'shard.grid' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+shard.grid @grid0(shape = -1)
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_duplicated_different_subarray(
+func.func @grid_axis_duplicated_different_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis duplicated}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis duplicated}}
+ %s = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_duplicated_same_subarray(
+func.func @grid_axis_duplicated_same_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis duplicated}}
- %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis duplicated}}
+ %s = shard.sharding @grid0 split_axes = [[0, 0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_negtive_in_split_part(
+func.func @grid_axis_negative_in_split_part(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis is expected to be non-negative}}
- %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis is expected to be non-negative}}
+ %s = shard.sharding @grid0 split_axes = [[-1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
- // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
- %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{custom op 'shard.sharding' invalid kind of attribute specified}}
+ %s = shard.sharding @a::@b split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
@@ -57,8 +57,8 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
// expected-error@+1 {{halo sizes must be specified for all split axes}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
@@ -66,292 +66,292 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
// expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh_dyn(shape = ?x?)
-func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
- // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
- %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+shard.grid @grid_dyn(shape = ?x?)
+func.func @sharding_dyn_grid_and_sizes(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+1 {{sharded dims offsets are not allowed for device grids with dynamic shape}}
+ %s = shard.sharding @grid_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
// expected-error@+1 {{sharded dims offsets has wrong size}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
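// For reference, the size rule this test exercises: each split tensor
// dimension contributes (number of shards along its grid axes) + 1 offsets,
// so @grid0(shape = 2x4) with split_axes = [[0], [1]] needs 3 + 5 = 8 entries.
// A well-formed sketch for a 4x8 tensor (%ok is hypothetical):
//   %ok = shard.sharding @grid0 split_axes = [[0], [1]]
//         sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6, 8] : !shard.sharding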
// -----
-mesh.mesh @mesh0(shape = 4)
+shard.grid @grid0(shape = 4)
func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
// expected-error@+1 {{sharded dims offsets must be non-decreasing}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
+func.func @grid_shape_grid_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0:2 = shard.grid_shape @grid0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @grid_shape_duplicate_grid_axis() -> (index, index, index) {
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0:3 = shard.grid_shape @grid0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
- %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
+ %0:2 = shard.grid_shape @grid0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results_empty_grid_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
- %0:2 = mesh.mesh_shape @mesh0 : index, index
+ %0:2 = shard.grid_shape @grid0 : index, index
return %0#0, %0#1 : index, index
}
// -----
-func.func @mesh_shape_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
+func.func @grid_shape_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.grid_shape @this_grid_symbol_does_not_exist : index
return %0#0 : index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
+func.func @process_multi_index_grid_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0:2 = shard.process_multi_index on @grid0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @process_multi_index_duplicate_grid_axis() -> (index, index, index) {
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0:3 = shard.process_multi_index on @grid0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
+ %0:2 = shard.process_multi_index on @grid0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results_empty_grid_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
- %0:2 = mesh.process_multi_index on @mesh0 : index, index
+ %0:2 = shard.process_multi_index on @grid0 : index, index
return %0#0, %0#1 : index, index
}
// -----
-func.func @process_multi_index_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_multi_index_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.process_multi_index on @this_grid_symbol_does_not_exist : index
return %0 : index
}
// -----
-func.func @process_linear_index_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_linear_index_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.process_linear_index on @this_grid_symbol_does_not_exist : index
return %0 : index
}
// -----
-func.func @all_reduce_invalid_mesh_symbol(
+func.func @all_reduce_invalid_grid_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_reduce %arg0 on @this_grid_symbol_does_not_exist reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_invalid_mesh_axis(
+func.func @all_reduce_invalid_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [2] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_reduce_duplicate_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1, 0] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
- // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
- %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
+ // expected-error@+1 {{'shard.all_reduce' op requires the same shape for all operands and results}}
+ %0 = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<5xf64>
return %0 : tensor<5xf64>
}
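// Note: shard.all_reduce requires matching operand and result shapes, while
// the element type may widen for accumulation. A valid sketch on the same
// grid (%ok and %arg0 hypothetical):
//   %ok = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<4xf64>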
// -----
-func.func @all_gather_invalid_mesh_symbol(
+func.func @all_gather_invalid_grid_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_gather %arg0 on @this_grid_symbol_does_not_exist gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_gather_invalid_mesh_axis(
+func.func @all_gather_invalid_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_gather_duplicate_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2, 2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
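// The arithmetic behind the expected size: the gathered dimension scales by
// the product of the grid dimensions selected by grid_axes. Axis 1 of
// @grid0(shape = 1x2) has size 2, so dim 1 must be 4 * 2 = 8. A valid sketch
// (%ok and %arg0 hypothetical):
//   %ok = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1
//         : tensor<3x4xf32> -> tensor<3x8xf32>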
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 gather_axis = 0
: tensor<?xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @all_slice_duplicate_mesh_axis(
+func.func @all_slice_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0, 0]
slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -359,12 +359,12 @@ func.func @all_slice_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.all_slice %arg0 on @mesh0
+ %0 = shard.all_slice %arg0 on @grid0
slice_axis = 0
: tensor<?xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -372,12 +372,12 @@ func.func @all_slice_invalid_dynamic_dimension(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
slice_axis = 0
: tensor<3xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -385,12 +385,12 @@ func.func @all_slice_invalid_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
slice_axis = 0
: tensor<4xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -398,10 +398,10 @@ func.func @all_slice_invalid_operand_static_dimension_size(
// -----
-func.func @all_to_all_invalid_mesh_symbol(
+func.func @all_to_all_invalid_grid_symbol(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_to_all %arg0 on @this_grid_symbol_does_not_exist
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
@@ -409,12 +409,12 @@ func.func @all_to_all_invalid_mesh_symbol(
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
-func.func @all_to_all_duplicate_mesh_axis(
+func.func @all_to_all_duplicate_grid_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 0]
split_axis = 0 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
@@ -422,12 +422,12 @@ func.func @all_to_all_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = ?x1)
+shard.grid @grid0(shape = ?x1)
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
@@ -435,12 +435,12 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
// -----
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<?x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
}
@@ -448,12 +448,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
// -----
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<3x?xi8> -> tensor<?x3xi8>
return %0 : tensor<?x3xi8>
}
@@ -461,12 +461,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<1x7xi8>
return %0 : tensor<1x7xi8>
}
@@ -474,12 +474,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<2x6xi8>
  return %0 : tensor<2x6xi8>
}
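// For a device group of size 3 (@grid0(shape = 3), grid_axes = [0]),
// all_to_all divides the split_axis dimension and multiplies the concat_axis
// dimension by the group size, so tensor<3x2xi8> maps to tensor<1x6xi8>.
// A valid sketch (%ok and %arg0 hypothetical):
//   %ok = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
//         split_axis = 0 concat_axis = 1
//         : tensor<3x2xi8> -> tensor<1x6xi8>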
@@ -487,12 +487,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -500,12 +500,12 @@ func.func @broadcast_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}
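// The "root" multi-index takes one coordinate per entry in grid_axes, each
// within the size of that grid axis; with grid_axes = [0] on
// @grid0(shape = 3x?), a valid sketch (%ok and %arg0 hypothetical) is:
//   %ok = shard.broadcast %arg0 on @grid0 grid_axes = [0] root = [1]
//         : (tensor<2xi8>) -> tensor<2xi8>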
@@ -513,12 +513,12 @@ func.func @broadcast_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.broadcast' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
}
@@ -526,84 +526,84 @@ func.func @broadcast_different_input_and_result_type(
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_wrong_return_element_type(
%arg0 : tensor<1xf32>) -> tensor<1xi8> {
- // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ // expected-error@+1 {{'shard.gather' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
: (tensor<1xf32>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
: (tensor<3x4xf32>) -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
func.func @gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 root = [0]
: (tensor<3x4xf32>) -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+ %0 = shard.gather %arg0 on @grid0 gather_axis = 0 root = []
: (tensor<?xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 root = [0]
: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 root = [0]
: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @gather_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<6xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
root = [3]
: (tensor<2xi8>) -> tensor<6xi8>
return %0 : tensor<6xi8>
}
@@ -611,12 +611,12 @@ func.func @gather_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @gather_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -624,12 +624,12 @@ func.func @gather_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_source_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -637,12 +637,12 @@ func.func @receive_source_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_source_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -650,12 +650,12 @@ func.func @receive_source_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.recv' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
}
@@ -663,12 +663,12 @@ func.func @receive_different_input_and_result_type(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -676,12 +676,12 @@ func.func @reduce_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -689,12 +689,12 @@ func.func @reduce_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_different_input_and_result_shape(
%arg0 : tensor<2xi8>) -> tensor<3xi16> {
- // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.reduce' op failed to verify that all of {input, result} have same shape}}
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [2]
: (tensor<2xi8>) -> tensor<3xi16>
return %0 : tensor<3xi16>
}
@@ -702,60 +702,60 @@ func.func @reduce_different_input_and_result_shape(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @reduce_scatter_duplicate_mesh_axis(
+func.func @reduce_scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
: tensor<?xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
: tensor<3xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
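// The divisibility rule: the scattered operand dimension must divide evenly
// by the collective device group size, here 3. A valid sketch (%ok and %arg0
// hypothetical):
//   %ok = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
//         : tensor<6xf32> -> tensor<2xf64>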
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @scatter_duplicate_mesh_axis(
+func.func @scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
scatter_axis = 0 root = [0, 0]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -763,12 +763,12 @@ func.func @scatter_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.scatter %arg0 on @mesh0
+ %0 = shard.scatter %arg0 on @grid0
scatter_axis = 0 root = []
: (tensor<?xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -776,12 +776,12 @@ func.func @scatter_invalid_dynamic_dimension(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
: (tensor<3xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -789,12 +789,12 @@ func.func @scatter_invalid_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
: (tensor<4xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -802,12 +802,12 @@ func.func @scatter_invalid_operand_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [3]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -815,12 +815,12 @@ func.func @scatter_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [2, 2]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -828,12 +828,12 @@ func.func @scatter_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_destination_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -841,12 +841,12 @@ func.func @send_destination_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_destination_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
// expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -854,12 +854,12 @@ func.func @send_destination_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.send' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
}
@@ -867,10 +867,10 @@ func.func @send_different_input_and_result_type(
// -----
-func.func @shift_invalid_mesh_symbol(
+func.func @shift_invalid_grid_symbol(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.shift %arg0 on @this_grid_symbol_does_not_exist
shift_axis = 0 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
}
@@ -878,12 +878,12 @@ func.func @shift_invalid_mesh_symbol(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @shift_invalid_mesh_axis(
+func.func @shift_invalid_grid_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [2]
shift_axis = 2 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
}
@@ -891,12 +891,12 @@ func.func @shift_invalid_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @shift_duplicate_mesh_axis(
+func.func @shift_duplicate_grid_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 1, 0]
shift_axis = 0 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
}
@@ -904,12 +904,12 @@ func.func @shift_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @shift_invalid_tensor_dimension_size(
%arg0 : tensor<4xi8>) -> tensor<5xi8> {
- // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.shift' op requires the same shape for all operands and results}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
shift_axis = 0 offset = 2
: tensor<4xi8> -> tensor<5xi8>
return %0 : tensor<5xi8>
}
@@ -917,12 +917,12 @@ func.func @shift_invalid_tensor_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @shift_invalid_shift_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping grid axes.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
shift_axis = 1 offset = 2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir
index c354de5..5265dad 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Shard/ops.mlir
@@ -1,176 +1,176 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 2x2x4)
-// CHECK: mesh.mesh @mesh1(shape = 4x?)
-mesh.mesh @mesh1(shape = 4x?)
+// CHECK: shard.grid @grid1(shape = 4x?)
+shard.grid @grid1(shape = 4x?)
-// CHECK: mesh.mesh @mesh2(shape = ?x4)
-mesh.mesh @mesh2(shape = ?x4)
+// CHECK: shard.grid @grid2(shape = ?x4)
+shard.grid @grid2(shape = ?x4)
-// CHECK: mesh.mesh @mesh3(shape = ?x?)
-mesh.mesh @mesh3(shape = ?x?)
+// CHECK: shard.grid @grid3(shape = ?x?)
+shard.grid @grid3(shape = ?x?)
-mesh.mesh @mesh4(shape = 3)
+shard.grid @grid4(shape = 3)
-// CHECK: mesh.mesh @mesh5(shape = ?)
-mesh.mesh @mesh5(shape = ?)
+// CHECK: shard.grid @grid5(shape = ?)
+shard.grid @grid5(shape = ?)
-// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-LABEL: func @grid_shard_op_fully_replicated
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_1st_dim
+// CHECK-LABEL: func @grid_shard_op_1st_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+func.func @grid_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_2nd_dim
+// CHECK-LABEL: func @grid_shard_op_2nd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
- %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid1 split_axes = {{\[\[}}], [0]] : !shard.sharding
+ %s = shard.sharding @grid1 split_axes = [[], [0]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
-func.func @mesh_shard_op_1st_and_3rd_dim(
+// CHECK-LABEL: func @grid_shard_op_1st_and_3rd_dim
+func.func @grid_shard_op_1st_and_3rd_dim(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
%arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid3 split_axes = {{\[\[}}0], [], [1]] : !shard.sharding
+ %s = shard.sharding @grid3 split_axes = [[0], [], [1]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_two_users
+// CHECK-LABEL: func @grid_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
+func.func @grid_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
(tensor<4x8xf32>, tensor<4x8xf32>) {
- // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
- %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
- // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
- %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
- // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
- %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
- %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
+ // CHECK-NEXT: %[[V0:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+ %s0 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<4x8xf32>
+ // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}1]] : !shard.sharding
+ %s1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
+ // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}2]] : !shard.sharding
+ %s2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
+ %2 = shard.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_halo_sizes
-func.func @mesh_shard_halo_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_halo_sizes
+func.func @grid_shard_halo_sizes() -> () {
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
%c3 = arith.constant 3 : i64
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
- %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !shard.sharding
+ %sharding1 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [1, 4] : !shard.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !shard.sharding
+ %sharding2 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [4, %c3] : !shard.sharding
return
}
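// halo_sizes are read as (lower, upper) pairs, one pair per split tensor
// dimension, and entries may be dynamic SSA values such as %c3 above.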
-// CHECK-LABEL: func @mesh_shard_dims_sizes
-func.func @mesh_shard_dims_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_dims_sizes
+func.func @grid_shard_dims_sizes() -> () {
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
%c3 = arith.constant 3 : i64
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
- %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+ %sharding1 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !shard.sharding
+ %sharding2 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !shard.sharding
return
}
-// CHECK-LABEL: func @mesh_shard_shape
-func.func @mesh_shard_shape() {
+// CHECK-LABEL: func @grid_shard_shape
+func.func @grid_shard_shape() {
// CHECK: %[[C3:.*]] = arith.constant 3 : index
%c3 = arith.constant 3 : index
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+ // CHECK-NEXT: shard.shard_shape dims = [8, %[[C3]]
// CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
// CHECK-SAME: ] : index, index
- %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
- // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
- %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
+ %shp:2 = shard.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+ // CHECK-NEXT: shard.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+ %shp1:2 = shard.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
return
}
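// shard.shard_shape computes the local per-shard extents from the global
// dims, the sharding, and the device (multi-)index; as the two calls above
// show, both dims and device entries may be SSA values or constants.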
-// CHECK-LABEL: func @mesh_get_sharding
+// CHECK-LABEL: func @grid_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
- // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
- %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
- return %0 : !mesh.sharding
+func.func @grid_get_sharding(%arg0 : tensor<4x8xf32>) -> !shard.sharding {
+ // CHECK-NEXT: shard.get_sharding %[[ARG]] : tensor<4x8xf32> -> !shard.sharding
+ %0 = shard.get_sharding %arg0 : tensor<4x8xf32> -> !shard.sharding
+ return %0 : !shard.sharding
}
-// CHECK-LABEL: func @mesh_shape
-func.func @mesh_shape() -> (index, index) {
- // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
- %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @grid_shape
+func.func @grid_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
+ %0:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
-// CHECK-LABEL: func @mesh_shape_default_axes
-func.func @mesh_shape_default_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
- %0:3 = mesh.mesh_shape @mesh0 : index, index, index
+// CHECK-LABEL: func @grid_shape_default_axes
+func.func @grid_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+ %0:3 = shard.grid_shape @grid0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
-// CHECK-LABEL: func @mesh_shape_empty_axes
-func.func @mesh_shape_empty_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
- %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @grid_shape_empty_axes
+func.func @grid_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+ %0:3 = shard.grid_shape @grid0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index
func.func @process_multi_index() -> (index, index) {
- // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: %[[RES:.*]]:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
+ %0:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @process_multi_index_default_axes
func.func @process_multi_index_default_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index_empty_axes
func.func @process_multi_index_empty_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
- // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
- %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: %[[RES:.*]] = shard.process_linear_index on @grid0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[RES]] : index
return %0 : index
}
@@ -179,9 +179,9 @@ func.func @process_linear_index() -> index {
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
- // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
+ // CHECK-NEXT: shard.all_reduce %[[ARG]] on @grid0 grid_axes = [1, 0] reduction = max
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1, 0] reduction = max
: tensor<3x4xf32> -> tensor<3x4xf64>
return %0 : tensor<3x4xf64>
}
@@ -190,9 +190,9 @@ func.func @all_reduce(
func.func @all_gather(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x16xf32>
return %0 : tensor<3x16xf32>
}
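// Worked shape arithmetic: grid_axes = [2] selects the axis of size 4 in
// @grid0(shape = 2x2x4), so the gathered dim 1 grows from 4 to 4 * 4 = 16.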
@@ -201,20 +201,20 @@ func.func @all_gather(
func.func @all_gather_dynamic_dims_in_tensor(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
: tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
-func.func @all_gather_dynamic_dims_in_mesh(
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_grid
+func.func @all_gather_dynamic_dims_in_grid(
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
%arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid3 grid_axes = [1] gather_axis = 1
// CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
- %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid3 grid_axes = [1] gather_axis = 1
: tensor<5x6xf32> -> tensor<5x?xf32>
return %0 : tensor<5x?xf32>
}
@@ -223,10 +223,10 @@ func.func @all_gather_dynamic_dims_in_mesh(
func.func @all_slice_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
- // CHECK-NEXT: mesh.all_slice %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
+ // CHECK-NEXT: shard.all_slice %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2] slice_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
}
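// all_slice is the inverse of all_gather: the sliced dimension shrinks by the
// device group size, here 4 / 4 = 1 along grid axis 2 of @grid0(shape = 2x2x4).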
@@ -235,10 +235,10 @@ func.func @all_slice_static_dimensions(
func.func @all_slice_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // CHECK-NEXT: mesh.all_slice %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+ // CHECK-NEXT: shard.all_slice %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1] slice_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
- %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+ %0 = shard.all_slice %arg0 on @grid3 grid_axes = [0, 1] slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -247,10 +247,10 @@ func.func @all_slice_dynamic_dimensions(
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
@@ -260,10 +260,10 @@ func.func @all_to_all(
func.func @all_to_all_dynamic_dims_in_result(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
}
@@ -273,10 +273,10 @@ func.func @all_to_all_dynamic_dims_in_result(
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
%arg0 : tensor<3xi8>) -> tensor<3xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: @grid4 split_axis = 0 concat_axis = 0
// CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 0 concat_axis = 0
: tensor<3xi8> -> tensor<3xi8>
return %0 : tensor<3xi8>
@@ -286,10 +286,10 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
func.func @all_to_all_non_divisible_split_axis_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
%arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1
// CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 1]
split_axis = 0 concat_axis = 1
: tensor<2x3xi8> -> tensor<?x12xi8>
return %0 : tensor<?x12xi8>
@@ -299,11 +299,11 @@ func.func @all_to_all_non_divisible_split_axis_size(
func.func @broadcast_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.broadcast %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.broadcast %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<3x6xi8>) -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -316,11 +316,11 @@ func.func @broadcast_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.broadcast %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.broadcast %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
root = [1, %arg1]
: (tensor<3x6xi8>, index) -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -330,12 +330,12 @@ func.func @broadcast_dynamic_root(
func.func @gather_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
- // CHECK-NEXT: mesh.gather %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.gather %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: gather_axis = 0
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
gather_axis = 0
root = [0, 1]
: (tensor<3x6xi8>) -> tensor<24x6xi8>
@@ -349,12 +349,12 @@ func.func @gather_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<24x6xi8> {
- // CHECK-NEXT: mesh.gather %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.gather %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: gather_axis = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
gather_axis = 0
root = [1, %arg1]
: (tensor<3x6xi8>, index) -> tensor<24x6xi8>
@@ -365,11 +365,11 @@ func.func @gather_dynamic_root(
func.func @receive_static_source(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.recv %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: source = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
source = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -382,11 +382,11 @@ func.func @receive_dynamic_source(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.recv %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: source = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
source = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -396,9 +396,9 @@ func.func @receive_dynamic_source(
func.func @receive_no_source(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG]]
+ // CHECK-NEXT: shard.recv %[[ARG]]
// CHECK-NOT: source
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -407,11 +407,11 @@ func.func @receive_no_source(
func.func @reduce_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.reduce %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -424,11 +424,11 @@ func.func @reduce_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.reduce %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -438,11 +438,11 @@ func.func @reduce_dynamic_root(
func.func @reduce_different_return_element_type(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // CHECK-NEXT: mesh.reduce %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
@@ -452,10 +452,10 @@ func.func @reduce_different_return_element_type(
func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
- // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
+ // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
reduction = max scatter_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
@@ -465,10 +465,10 @@ func.func @reduce_scatter_static_dimensions(
func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
- // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
- %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@@ -477,11 +477,11 @@ func.func @reduce_scatter_dynamic_dimensions(
func.func @scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
- // CHECK-NEXT: mesh.scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2]
+ // CHECK-NEXT: shard.scatter %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2]
// CHECK-SAME: scatter_axis = 1 root = [1]
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
scatter_axis = 1 root = [1]
: (tensor<3x4xf32>) -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
@@ -491,11 +491,11 @@ func.func @scatter_static_dimensions(
func.func @scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // CHECK-NEXT: mesh.scatter %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
+ // CHECK-NEXT: shard.scatter %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1]
// CHECK-SAME: scatter_axis = 0 root = [1, 2]
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
- %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
+ %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
scatter_axis = 0 root = [1, 2]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -508,12 +508,12 @@ func.func @scatter_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<1xi8> {
- // CHECK-NEXT: mesh.scatter %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.scatter %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: scatter_axis = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
scatter_axis = 0
root = [1, %arg1]
: (tensor<8xi8>, index) -> tensor<1xi8>
@@ -524,11 +524,11 @@ func.func @scatter_dynamic_root(
func.func @send_static_destination(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.send %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.send %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: destination = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
destination = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -541,11 +541,11 @@ func.func @send_dynamic_destination(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.send %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.send %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: destination = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
destination = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -555,11 +555,11 @@ func.func @send_dynamic_destination(
func.func @shift(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.shift %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.shift %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: shift_axis = 2 offset = -2 rotate
// CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 2]
shift_axis = 2 offset = -2 rotate
: tensor<2xi8> -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -570,16 +570,16 @@ func.func @update_halo(
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
- // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-NEXT: %[[UH1:.*]] = shard.update_halo %[[ARG]] on @grid0
// CHECK-SAME: split_axes = {{\[\[}}0]]
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
%c2 = arith.constant 2 : i64
- %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+ %uh1 = shard.update_halo %arg0 on @grid0 split_axes = [[0]]
halo_sizes = [2, %c2] : memref<12x12xi8>
- // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
+ // CHECK-NEXT: %[[UH2:.*]] = shard.update_halo %[[UH1]] on @grid0
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
- %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+ %uh2 = shard.update_halo %uh1 on @grid0 split_axes = [[0], [1]]
halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
return
}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
new file mode 100644
index 0000000..c2572cc
--- /dev/null
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -0,0 +1,317 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN: %s | FileCheck %s
+
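+// Tests for the shard-partition pass: tensors annotated with shard.shard are
+// rewritten into their per-process local shards, and differing source/target
+// shardings are bridged with collectives such as all_gather, all_slice and
+// all_to_all (all of which are exercised below).
+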
+shard.grid @grid_1d(shape = 2)
+
+// CHECK-LABEL: func @return_sharding
+func.func @return_sharding(
+ // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> (tensor<1xf32>, !shard.sharding) {
+) -> (tensor<2xf32>, !shard.sharding) {
+ %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding
+ %r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding
+ // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding
+ return %sharded, %r : tensor<2xf32>, !shard.sharding
+}
+
+// CHECK-LABEL: func @full_replication
+func.func @full_replication(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[ARG]] : tensor<2xi8>
+ return %1 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+ %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %ssharded : tensor<2xf32>
+ %ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharded_0 = shard.shard %sharded to %ssharded_0 annotate_for_users : tensor<2xf32>
+ %ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %sharded_1 = shard.shard %sharded_0 to %ssharded_1 : tensor<2xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+ return %sharded_1 : tensor<2xf32>
+}
+
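+// Moving the split from tensor axis 0 to axis 1 lowers to a single
+// shard.all_to_all that splits the local shard along the new axis and
+// concatenates the received pieces along the previously split one.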
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
+ %arg0: tensor<2x2xi8>
+// CHECK-SAME: -> tensor<2x1xi8> {
+) -> tensor<2x2xi8> {
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[ARG]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2x2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
+ // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
+ return %1 : tensor<2x2xi8>
+}
+
+// CHECK-LABEL: func @non_tensor_value
+func.func @non_tensor_value(
+ // CHECK-SAME: %[[ARG:.*]]: i8
+ %arg0: i8
+// CHECK-SAME: -> i8 {
+) -> i8 {
+ // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
+ %0 = arith.addi %arg0, %arg0 : i8
+ // CHECK: return %[[RES]] : i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func @unary_elementwise
+func.func @unary_elementwise(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %4 : tensor<2xi8>
+}
+
+// full replication -> shard axis -> abs -> shard axis -> full replication
+// CHECK-LABEL: func @unary_elementwise_with_resharding
+func.func @unary_elementwise_with_resharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RES:.*]] = shard.all_gather %[[ABS]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<2xi8>
+ return %4 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @binary_elementwise
+func.func @binary_elementwise(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
+ %arg0: tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
+ %arg1: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %sarg0_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2xi8>
+ %sop_arg0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_arg0 = shard.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
+ %sarg1_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %arg1_sharded = shard.shard %arg1 to %sarg1_sharded : tensor<2xi8>
+ %sop_arg1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_arg1 = shard.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
+ %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
+ %sop_res_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_res_sharded = shard.shard %op_res to %sop_res_sharded : tensor<2xi8>
+ %sres = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %res = shard.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %res : tensor<2xi8>
+}
+
+// reshard
+// abs
+// reshard
+// abs
+// reshard
+// CHECK-LABEL: func @multiple_chained_ops
+func.func @multiple_chained_ops(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ // CHECK: %[[RESHARD1:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD2:.*]] = shard.all_gather %[[ABS1]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
+ %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD3:.*]] = shard.all_slice %[[ABS2]] on @grid_1d grid_axes = [0] slice_axis = 0 :
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s6 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %6 = shard.shard %5 to %s6 : tensor<2xi8>
+ %s7 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %7 = shard.shard %6 to %s7 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RESHARD3]] : tensor<1xi8>
+ return %7 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %2 = shard.shard %1 to %s2 : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}
+
+shard.grid @grid_1d_4(shape = 4)
+
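+// On @grid_1d_4 (4 processes) with halo_sizes = [2, 1], each local shard of
+// the 8-row tensor holds 8/4 = 2 own rows plus 2 + 1 halo rows, i.e. 5 rows.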
+// CHECK-LABEL: func @ew_chain_with_halo
+func.func @ew_chain_with_halo(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
+ %arg0: tensor<8x16xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg1: tensor<1xf32>,
+ // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg2: tensor<1xf32>)
+ // CHECK-SAME: -> tensor<5x16xf32>
+ -> tensor<8x16xf32> {
+ %ssharded = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded = shard.shard %arg0 to %ssharded annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %0 = tosa.tanh %sharded : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharded_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_0 = shard.shard %0 to %ssharded_0 : tensor<8x16xf32>
+ %ssharded_1 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_1 = shard.shard %sharded_0 to %ssharded_1 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %1 = tosa.abs %sharded_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharded_2 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_2 = shard.shard %1 to %ssharded_2 : tensor<8x16xf32>
+ %ssharded_4 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_4 = shard.shard %sharded_2 to %ssharded_4 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
+ %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+ %zero_point_1 = shard.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %zero_point_2 = shard.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %2 = tosa.negate %sharded_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+ %ssharded_5 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_5 = shard.shard %2 to %ssharded_5 : tensor<8x16xf32>
+ %ssharded_6 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharded_6 = shard.shard %sharded_5 to %ssharded_6 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
+ return %sharded_6 : tensor<8x16xf32>
+}
+
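+// Resharding from a halo-free to a haloed sharding materializes the larger
+// local buffer (1200/4 = 300 rows plus 2 + 2 halo rows = 304) with
+// tensor.empty + tensor.insert_slice before shard.update_halo refreshes it.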
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+ // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
+ %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !shard.sharding
+ %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64>
+ %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<304x1200xi64>
+ return %sharded_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid4x4(shape = 4x4)
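+// 2D variant: halo_sizes = [1, 2, 3, 4] grows the 300x300 local shard to
+// (300 + 1 + 2) x (300 + 3 + 4) = 303x307.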
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] : !shard.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+ // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
+ %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !shard.sharding
+ %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64>
+ %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<303x307xi64>
+ return %sharded_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid(shape = 2)
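+// Reducing across the sharded dimension (axis 0) leaves every process with a
+// partial sum, so partition must insert a shard.all_reduce over grid axis 0.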
+// CHECK-LABEL: func.func @test_reduce_0d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<i32>
+ %sharding_out = shard.sharding @grid split_axes = [[]] : !shard.sharding
+ %sharded_out = shard.shard %4 to %sharding_out : tensor<i32>
+ %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK: %[[all_reduce:.*]] = shard.all_reduce %[[reduced]] on @grid grid_axes = [0] : tensor<i32> -> tensor<i32>
+ %sharded_red = shard.shard %reduced to %sharding_out : tensor<i32>
+ %sharded_ret = shard.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
+ // CHECK: return %[[all_reduce]] : tensor<i32>
+ return %sharded_ret : tensor<i32>
+}
+
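+// Only the unsharded dimension (axis 1) is reduced here, so each process
+// already holds its final 3-element slice and no all_reduce is required.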
+// CHECK-LABEL: func.func @test_reduce_1d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<6xi32>
+ %sharded_out = shard.shard %4 to %sharding : tensor<6xi32>
+ %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK-NOT: shard.all_reduce
+ %sharded_red = shard.shard %reduced to %sharding : tensor<6xi32>
+ %sharded_ret = shard.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
+ // CHECK: return %[[reduced]] : tensor<3xi32>
+ return %sharded_ret : tensor<6xi32>
+}
diff --git a/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
new file mode 100644
index 0000000..33c7a8f
--- /dev/null
+++ b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -test-grid-process-multi-index-op-lowering %s | FileCheck %s
+
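+// shard.process_multi_index lowers to the linear process index delinearized
+// over the grid shape via affine.delinearize_index.
+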
+shard.grid @grid2d(shape = ?x?)
+
+// CHECK-LABEL: func.func @multi_index_2d_grid
+func.func @multi_index_2d_grid() -> (index, index) {
+ // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[GRID_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[GRID_SHAPE]]#0, %[[GRID_SHAPE]]#1) : index, index
+ %0:2 = shard.process_multi_index on @grid2d : index, index
+ // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_grid_single_inner_axis
+func.func @multi_index_2d_grid_single_inner_axis() -> index {
+ // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[GRID_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[GRID_SHAPE]]#0, %[[GRID_SHAPE]]#1) : index, index
+ %0 = shard.process_multi_index on @grid2d axes = [0] : index
+ // CHECK: return %[[MULTI_IDX]]#0 : index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
new file mode 100644
index 0000000..ff9e840
--- /dev/null
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -test-grid-resharding-partition %s | FileCheck %s
+
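+// Lowering of reshards (pairs of shard.shard ops with different shardings)
+// into collectives; builtin.unrealized_conversion_cast marks the boundary
+// between the global tensor type and the local shard type.
+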
+shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_dynamic(shape = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+ %1 = shard.shard %0 to %s0 annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+ %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+ // CHECK: %[[ALL_SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 1
+ // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<3x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
+ // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+ return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+ %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+ // CHECK: %[[RESULT:.*]] = shard.all_slice %[[ARG]] on @grid_1d_dynamic grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[], [], []] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x3x?xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
+ return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_grid
+func.func @move_split_axis_dynamic_grid(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[ARG]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[RES]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_last_axis
+func.func @unshard_static_last_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
+func.func @unshard_static_axis_on_dynamic_grid_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
index b5eb98d..b5eb98d 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir
+++ b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
diff --git a/mlir/test/Dialect/Shard/sharding-propagation.mlir b/mlir/test/Dialect/Shard/sharding-propagation.mlir
new file mode 100644
index 0000000..34aaf05
--- /dev/null
+++ b/mlir/test/Dialect/Shard/sharding-propagation.mlir
@@ -0,0 +1,301 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
+
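+// The sharding-propagation pass spreads existing shard.sharding annotations
+// through the ops of a function, inserting shard.shard pairs on both the
+// definition side and the use side (annotate_for_users) of each value.
+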
+shard.grid @grid_2(shape = 2)
+shard.grid @grid_1d(shape = ?)
+shard.grid @grid_2d(shape = 2x4)
+shard.grid @grid_3d(shape = ?x?x?)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: tosa.sigmoid
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: return
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
+ %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = shard.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
+ // CHECK-NEXT: %[[V6:.*]] = shard.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
+ %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP1:.*]] = shard.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[ZP2:.*]] = shard.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
+ // CHECK-NEXT: %[[V8:.*]] = shard.shard %[[V7]] to %[[S1]] : tensor<8x16xf32>
+ %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+ %s3 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V6]], %[[V8]]
+ return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+ %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [], [1]] : !shard.sharding
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32>
+ // CHECK-NEXT: return [[vsharded_5]]
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+ %s0 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
+ %arg0_s = shard.shard %arg0 to %s0 : tensor<2x16x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_1:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+ // CHECK: [[vsharded_2:%.*]] = shard.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_3:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_4:%.*]] = shard.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ // CHECK: [[vsharding_5:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
+ %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK: return [[vsharded_6]]
+ return %0 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicated_k
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_duplicated_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1], [0]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[], [1], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+ %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %2 : tensor<2x16x32xf32>
+}
+
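+// Conflicting annotations are not merged: the def-side sharding of %arg0 is
+// kept, and a separate annotate_for_users shard.shard re-annotates the value
+// with the sharding its consumers expect.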
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+ // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+ %arg0: tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+ %arg1: tensor<3x2xf32>,
+ // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+ // CHECK: %[[SIN1_SHARDED1:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = shard.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
+ // CHECK: %[[SIN2_SHARDED:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = shard.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
+ // CHECK-NEXT: %[[IN2_SHARDED:.*]] = shard.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = shard.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
+ // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+ // CHECK-NEXT: %[[RES:.*]] = shard.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
+ %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+ %res_sharded = shard.shard %res to %sres_sharded : tensor<2x2xf32>
+ // CHECK: return %[[RES]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// The sharding propagation results in unnecessary reshards,
+// an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ %sharded0 = shard.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ %sharded1 = shard.shard %arg1 to %s0 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = shard.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
+ %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ %sharding = shard.sharding @grid_1d split_axes = [[], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_9:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
+ %sharded2 = shard.shard %arg2 to %sharding : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_10:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_12:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %s4 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_12]]
+ return %4 : tensor<2x4x8xf32>
+}
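+
+// Note: pairs such as [[vsharded_6]]/[[vsharded_7]] above annotate the same
+// sharding on the producer and user sides; these are the redundant reshards
+// that a later optimization pass could fold into a single annotation.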
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// The sharding propagation results in unnecessary reshards;
+// an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+ %s0 = shard.sharding @grid_3d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ %arg0_s = shard.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [0], [1, 2]] : !shard.sharding
+ %s1 = shard.sharding @grid_3d split_axes = [[], [0], [1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
+ %arg1_s = shard.shard %arg1 to %s1 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_4:%.*]] = shard.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
+ %2 = shard.shard %1 to %s0 : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid
+ // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
+ %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharding_9:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [1, 2], [0]] : !shard.sharding
+ %s2 = shard.sharding @grid_3d split_axes = [[], [1, 2], [0]] : !shard.sharding
+ // CHECK: [[vsharded_10:%.*]] = shard.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
+ %arg2_s = shard.shard %arg2 to %s2 : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_12:%.*]] = shard.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_13:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %5 = shard.shard %4 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_14:%.*]] = shard.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ %6 = shard.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_14]]
+ return %6 : tensor<2x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = shard.shard %[[V4]] to %[[S0]] : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[]] : !shard.sharding
+ %2 = shard.shard %1 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V5]]
+ return %2 : tensor<8x16xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Shard/simplifications.mlir
index e955f4c..33cd490 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Shard/simplifications.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
-mesh.mesh @mesh0(shape = 4x2)
-mesh.mesh @mesh1(shape = 4)
+shard.grid @grid0(shape = 4x2)
+shard.grid @grid1(shape = 4)
// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
@@ -11,13 +11,13 @@ func.func @all_reduce_arith_addf_endomorphism(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xf32>
}
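+
+// Note: the fold is sound because elementwise addition commutes with the
+// sum all-reduce: sum_d(x_d) + sum_d(y_d) == sum_d(x_d + y_d). A minimal
+// sketch of the folded form the pattern is expected to produce:
+//   %sum = arith.addf %arg0, %arg1 : tensor<5xf32>
+//   %red = shard.all_reduce %sum on @grid0 grid_axes = [0] : tensor<5xf32> -> tensor<5xf32>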
@@ -28,13 +28,13 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
// CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
return %2, %2 : tensor<5xf32>, tensor<5xf32>
}
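+
+// Note: reusing the arith.addf result keeps the fold legal, since both uses
+// see the reduced value. Reusing an all_reduce result, as in the next test,
+// blocks the fold: that all_reduce must remain for its other use.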
@@ -46,11 +46,11 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
- // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -58,17 +58,17 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
return %0, %2 : tensor<5xf32>, tensor<5xf32>
}
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
- %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1
+ %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -76,17 +76,17 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
return %2 : tensor<5xf32>
}
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -100,11 +100,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -118,11 +118,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_elemen
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf64> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf64>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf64>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf64>
@@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
%2 = arith.minimumf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xf32>
}
@@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism(
%arg0: tensor<5xi32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
%arg1: tensor<5xi32>) -> tensor<5xi32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
%2 = arith.minsi %0, %1 : tensor<5xi32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xi32>
}
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73..9da2dea 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,92 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x28x10xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+ return %extracted_slice : tensor<28x28x10xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x28x10xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+// The foldable slice sizes for dim 1 lie in [17, 32], because
+// CeilDiv(%d1, 16) == 2 and a slice of at least 17 elements still uses both
+// 16-element tiles.
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x17x15xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+ return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x17x15xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+// The foldable slice sizes for dim 1 lie in [17, 32], because
+// CeilDiv(%d1, 16) == 2; a slice of 16 elements needs only one tile, so the
+// fold below would introduce artificial padding and is rejected.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x16x15xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+ return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
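+// Note: with a dynamic slice size the pattern cannot statically rule out
+// artificial padding, so the unpack/extract_slice pair is left intact.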
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK: func @fold_unpack_slice(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME: into %[[INIT]]
-// CHECK: return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
// -----
@@ -59,48 +129,62 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
// -----
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
+func.func @fold_pad_pack(%src: tensor<9x16xf32>) -> tensor<2x1x8x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ } : tensor<9x16xf32> to tensor<16x16xf32>
+ %empty = tensor.empty() : tensor<2x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
+ : tensor<16x16xf32> -> tensor<2x1x8x32xf32>
+ return %pack : tensor<2x1x8x32xf32>
}
-// CHECK-LABEL: func.func @pad_pack
+// CHECK-LABEL: func.func @fold_pad_pack
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x1x8x32xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
// -----
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<9x16xf32>) -> tensor<3x1x8x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[8, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ } : tensor<9x16xf32> to tensor<17x16xf32>
+ %empty = tensor.empty() : tensor<3x1x8x32xf32>
+ %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<17x16xf32> -> tensor<3x1x8x32xf32>
+ return %pack : tensor<3x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK: tensor.pad
+// CHECK: linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_with_nofold_attribute(
// CHECK: tensor.pad
// CHECK: linalg.pack
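+// Note: the source is padded by high[7, 0] to 16656 = 2082 * 8, exactly as
+// in the foldable case above; only the explicit `nofold` attribute on
+// tensor.pad keeps the pad/pack pair from being merged.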
// -----
func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
%cst0 = arith.constant 0.000000e+00 : f32
%cst1 = arith.constant 1.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[15, 0] {
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
deleted file mode 100644
index 8598d81..0000000
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
-func.func @tensor_empty_static_sharded_dims_offsets() -> () {
- %b = tensor.empty() : tensor<8x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
- // CHECK-SAME: ] : index, index
- // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
-// CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
- %b = tensor.empty(%arg0) : tensor<8x?xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
- // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
- // CHECK-SAME: ] : index, index
- // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
-func.func @tensor_empty_same_static_dims_sizes() -> () {
- %b = tensor.empty() : tensor<16x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
- // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_0d
-func.func @tensor_empty_0d() -> () {
- tensor.empty() : tensor<f32>
- // CHECK-NEXT: tensor.empty() : tensor<f32>
- return
-}
diff --git a/mlir/test/Dialect/Tensor/shard-partition.mlir b/mlir/test/Dialect/Tensor/shard-partition.mlir
new file mode 100644
index 0000000..5918ee1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/shard-partition.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid_1d_4(shape = 4)
+
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
+ %b = tensor.empty() : tensor<8x16xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<8x16xf32>
+ // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
+// CHECK-SAME: %[[A0:.*]]: index
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
+ %b = tensor.empty(%arg0) : tensor<8x?xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<8x?xf32>
+ // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, %[[A0]]
+ // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
+func.func @tensor_empty_same_static_dims_sizes() -> () {
+ %b = tensor.empty() : tensor<16x16xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<16x16xf32>
+ // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+ tensor.empty() : tensor<f32>
+ // CHECK-NEXT: tensor.empty() : tensor<f32>
+ return
+}
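+
+// Note: sharded_dims_offsets = [0, 1, 4, 7, 8] splits the 8-element dim 0
+// across the 4 devices into local sizes 1, 3, 3, 1 (consecutive offset
+// differences), so the partitioned shape is dynamic; the uniform offsets
+// [0, 4, 8, 12, 16] above yield a static tensor<4x16xf32>.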
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 0176fc2..6398161 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// CHECK: tosa.cond_if profiles: [ ]
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54..5150ee3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// -----
+// CHECK-LABEL: @clamp_boolean_is_noop
+func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
+ return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
+func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
+ // CHECK: return %arg0
+ // CHECK-NOT: tosa.clamp
+ %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
+ return %0 : tensor<?xi1>
+}
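+
+// Note: min_val = false and max_val = true cover the entire i1 range, so the
+// clamp cannot change any element and folds away, mirroring the int8 noop
+// test below.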
+
+// -----
+
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: return %arg0
@@ -1349,3 +1369,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+ %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+ return %1 : tensor<i1>
+}
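+
+// Note: casting an integer to i1 folds with nonzero-maps-to-true semantics,
+// so dense<10> : tensor<i32> becomes dense<true> : tensor<i1>.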
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000..06312c7
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ // CHECK: } else {
+ // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>):
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
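+
+// Note: the two tests above exercise the two printed forms of tosa.cond_if:
+// the condition-type-only form (`%cond : tensor<i1> -> <results>`) and the
+// explicit input-list form, which binds operands to the block arguments of
+// both regions.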
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index eb25011..fad1bec 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1:
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ed74714..3bccb32 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
return
}
@@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() {
func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
%cst = "tosa.const"() { values = dense<-0.0> : tensor<f8E4M3FN> } : () -> tensor<f8E4M3FN>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<f8E4M3FN>) -> tensor<13x21x3xf8E5M2>
return %0 : tensor<13x21x3xf8E5M2>
}
@@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
// CHECK-LABEL: test_mul_non_scalar_shift_2d
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
- // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_scalar_shift_1d
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
- // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+ // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+ return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+ %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+ return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 5630c33..3154f54 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
// -----
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26..bf9ed8a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
// -----
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
- return %0 : tensor<1x1x1x1x13x21x3xf32>
+ %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+ return %0 : tensor<1x1x1x1x13x21x3xi32>
}
// -----
@@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1:
// -----
func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %2 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %3 = tosa.cond_if %arg3 -> (tensor<f32>) {
- %4 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+  %1 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+ %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
+ %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
+ %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}}
- %5 = tosa.cond_if %arg3 -> (tensor<f32>) {
+ %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> {
%res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %res : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ef51197e..30361a8 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index 38ac8d8..e957bdd 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
- // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+ // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d0f4027..7b8fc24 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -344,6 +344,30 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
// -----
+// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
+func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
+ // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
+ %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE:.*]] = tosa.const_shape
+ %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
+ %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_scalar_i8_tensor
+func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> {
+ // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8>
+ %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8>
+ // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
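+// Note: shape inference refines the unranked tosa.cast results above to
+// ranked 1-element tensors, which then satisfy the scalar-tensor operand
+// constraints of tosa.pad and tosa.mul.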
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
@@ -1153,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
%b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32>
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
tosa.yield %a : tensor<f32>
} else {
tosa.yield %b : tensor<f32>
@@ -1167,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens
// CHECK-LABEL: @if_test_dynamic
func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<?xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) {
+ // CHECK: -> tensor<?xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> {
tosa.yield %arg0 : tensor<2xf32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1181,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_unranked
func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<*xf32>)
- %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) {
+ // CHECK: -> tensor<*xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> {
tosa.yield %arg0 : tensor<f32>
} else {
tosa.yield %arg1 : tensor<3xf32>
@@ -1195,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 :
// CHECK-LABEL: @if_test_propagate
func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
- // CHECK: -> (tensor<f32>)
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ // CHECK: -> tensor<f32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index b305236..2a937b0 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar
// -----
+func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a
func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}}
- %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
} else {
@@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg
func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
// expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}}
- %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) {
+ %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1, %2 : tensor<f32>, tensor<f32>
@@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso
// -----
+// CHECK-LABEL: cond_if_cond_type
+func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected ':'}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}}
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error@+2 {{expected non-function type}}
+ // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}}
+ %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
%0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
// expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30..56996b5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -2562,118 +2592,6 @@ func.func @insert_2d_splat_constant()
// -----
-// CHECK-LABEL: func @insert_element_fold
-// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
-// CHECK: return %[[V]]
-func.func @insert_element_fold() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_element_invalid_fold
-func.func @insert_element_invalid_fold() -> vector<1xf32> {
- // Out-of-bound index here.
- %c26 = arith.constant 26 : index
- %cst_2 = arith.constant 1.60215309E+9 : f32
- %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32>
-// CHECK: vector.insertelement
- %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32>
- return %46 : vector<1xf32>
-}
-
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold1
-// CHECK: vector.insertelement
-func.func @insert_poison_fold1() -> vector<4xi32> {
- %v = ub.poison : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold2
-// CHECK: vector.insertelement
-func.func @insert_poison_fold2() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = ub.poison : i32
- %i = arith.constant 2 : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @insert_poison_fold3
-// CHECK: vector.insertelement
-func.func @insert_poison_fold3() -> vector<4xi32> {
- %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
- %s = arith.constant 7 : i32
- %i = ub.poison : i32
- %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
- return %1 : vector<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_element_fold
-// CHECK: %[[C:.+]] = arith.constant 5 : i32
-// CHECK: return %[[C]]
-func.func @extract_element_fold() -> i32 {
- %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// CHECK-LABEL: func @extract_element_splat_fold
-// CHECK-SAME: (%[[ARG:.+]]: i32)
-// CHECK: return %[[ARG]]
-func.func @extract_element_splat_fold(%a : i32) -> i32 {
- %v = vector.splat %a : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @extract_element_poison_fold1
-// CHECK: vector.extractelement
-func.func @extract_element_poison_fold1() -> i32 {
- %v = ub.poison : vector<4xi32>
- %i = arith.constant 2 : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
-// Do not crash on poison
-// CHECK-LABEL: func @extract_element_poison_fold2
-// CHECK: vector.extractelement
-func.func @extract_element_poison_fold2() -> i32 {
- %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
- %i = ub.poison : i32
- %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
- return %1 : i32
-}
-
-// -----
-
// CHECK-LABEL: func @reduce_one_element_vector_extract
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
@@ -2933,18 +2851,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
// -----
-// CHECK-LABEL: func.func @fold_extractelement_of_broadcast(
-// CHECK-SAME: %[[f:.*]]: f32
-// CHECK: return %[[f]]
-func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
- %0 = vector.broadcast %f : f32 to vector<15xf32>
- %c5 = arith.constant 5 : index
- %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
- return %1 : f32
-}
-
-// -----
-
// CHECK-LABEL: func.func @fold_0d_vector_reduction
func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
// CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 0263193..b2f16bb 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> {
func.return %2 : vector<4x4xindex>
}
+// CHECK-LABEL: func @vector_transpose
+// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index}
+func.func @vector_transpose() -> vector<2x4xindex> {
+ %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex>
+ %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex>
+ %2 = test.reflect_bounds %1 : vector<2x4xindex>
+ func.return %2 : vector<2x4xindex>
+}
+
// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
@@ -60,16 +69,6 @@ func.func @vector_extract() -> index {
func.return %2 : index
}
-// CHECK-LABEL: func @vector_extractelement
-// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
-func.func @vector_extractelement() -> index {
- %c0 = arith.constant 0 : index
- %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
- %1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
- %2 = test.reflect_bounds %1 : index
- func.return %2 : index
-}
-
// CHECK-LABEL: func @vector_add
// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
func.func @vector_add() -> vector<4xindex> {
@@ -90,17 +89,6 @@ func.func @vector_insert() -> vector<4xindex> {
func.return %3 : vector<4xindex>
}
-// CHECK-LABEL: func @vector_insertelement
-// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
-func.func @vector_insertelement() -> vector<4xindex> {
- %c0 = arith.constant 0 : index
- %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
- %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
- %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
- %3 = test.reflect_bounds %2 : vector<4xindex>
- func.return %3 : vector<4xindex>
-}
-
// CHECK-LABEL: func @test_loaded_vector_extract
// No bounds
// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
@@ -120,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
%2 = test.reflect_bounds %1 : vector<2xi32>
func.return %2 : vector<2xi32>
}
+
+// CHECK-LABEL: func @vector_step
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @vector_step() -> vector<8xindex> {
+ %0 = vector.step : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ca837d3..c21de56 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// -----
-func.func @extract_element(%arg0: vector<f32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position to be empty with 0-D vector}}
- %1 = vector.extractelement %arg0[%c : i32] : vector<f32>
-}
-
-// -----
-
-func.func @extract_element(%arg0: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position for 1-D vector}}
- %1 = vector.extractelement %arg0[] : vector<4xf32>
-}
-
-// -----
-
-func.func @extract_element(%arg0: vector<4x4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{unexpected >1 vector rank}}
- %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
-}
-
-// -----
-
func.func @extract_vector_type(%arg0: index) {
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}}
%1 = vector.extract %arg0[] : index from index
@@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
-func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position to be empty with 0-D vector}}
- %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
-}
-
-// -----
-
-func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{expected position for 1-D vector}}
- %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
-}
-
-// -----
-
-func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{unexpected >1 vector rank}}
- %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
-}
-
-// -----
-
-func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
- %c = arith.constant 3 : i32
- // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}}
- %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>)
-}
-
-// -----
-
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6a56116..625ffc1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4
return %1 : vector<4xf32>
}
-// CHECK-LABEL: @extract_element_0d
-func.func @extract_element_0d(%a: vector<f32>) -> f32 {
- // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
- %1 = vector.extractelement %a[] : vector<f32>
- return %1 : f32
-}
-
-// CHECK-LABEL: @extract_element
-func.func @extract_element(%a: vector<16xf32>) -> f32 {
- // CHECK: %[[C15:.*]] = arith.constant 15 : i32
- %c = arith.constant 15 : i32
- // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32>
- %1 = vector.extractelement %a[%c : i32] : vector<16xf32>
- return %1 : f32
-}
-
// CHECK-LABEL: @extract_const_idx
func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
-> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
@@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
return %0 : f32
}
-// CHECK-LABEL: @insert_element_0d
-func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
- // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
- %1 = vector.insertelement %a, %b[] : vector<f32>
- return %1 : vector<f32>
-}
-
-// CHECK-LABEL: @insert_element
-func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
- // CHECK: %[[C15:.*]] = arith.constant 15 : i32
- %c = arith.constant 15 : i32
- // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32>
- %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32>
- return %1 : vector<16xf32>
-}
-
// CHECK-LABEL: @insert_const_idx
func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
%res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a5..d5e3443 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955..5a8125e 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdc..ef881ba 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
// -----
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result of arith.cmpf have different element types
+// (f32 vs i1); the broadcast is still sunk past the comparison.
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
@@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
return %1 : vector<1xf32>
}
+// -----
+
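+// The following tests sink a broadcast past an elementwise op whose other
+// operand is a splat constant: the constant is shrunk to the pre-broadcast
+// shape, the op is applied there, and a single broadcast is emitted last,
+// e.g. addi(broadcast(x), splat(2)) -> broadcast(addi(x, 2)).
+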
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.subi %cst, %0 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+ %2 = arith.mulf %0, %cst : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// -----
+
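+// Negative case: a non-splat constant cannot be shrunk to the pre-broadcast
+// shape, so the op is left untouched.
+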
+// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
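+// Casts are sunk the same way: extf(broadcast(x)) -> broadcast(extf(x)).
+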
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+ %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+ %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+ %cst = arith.constant dense<3> : vector<1x4xi32>
+ %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+ return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<3> : vector<3x4xi32>
+ %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+ return %2 : vector<3x4xf32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 5dd65ea..44601a4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -68,6 +68,24 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>)
// -----
+// Ensure that cases with mismatched target and source shape ranks
+// do not lead to a crash.
+// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns`
+// is currently hard-coded to [2, 2].
+
+// CHECK-LABEL: func @negative_transfer_write
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<6x34x62xi8>
+ vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @transfer_readwrite_unroll
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfe..dff3ffa 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) {
}
// -----
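+// Verifier diagnostics for gather/scatter-style ops that take an explicit
+// per-lane offset vector.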
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
+ %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<8xi1>
+ // expected-error@+1 {{Mask should match value except the chunk size dim}}
+ %2 = xegpu.load %src[%offsets], %mask
+ : memref<?xf16>, vector<4xindex>, vector<8xi1>
+ -> vector<4x2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{value elements must match chunk size}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_2(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
+ return
+}
+
+// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = arith.constant dense<2.9>: vector<4x2xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969a..6be2371 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
gpu.return
}
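+// Gather load with an explicit offset vector: with chunk_size = 2, each of
+// the 4 lanes loads 2 consecutive f16 values, yielding vector<4x2xf16>.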
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
gpu.return
}
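+// Scatter store with an explicit offset vector: the chunk_size = 2 layout
+// mirrors the gather-load case above.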
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4x2xf16>
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
gpu.return
}
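+// Prefetch through an explicit offset vector, with per-level cache hints.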
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ gpu.return
+}
+
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb4..628a485 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -2,122 +2,117 @@
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
+ // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  // CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @update_nd(%src: memref<24x32xf32>){
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
- // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
- gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+ gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+ -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+ -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x256xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: broadcast
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
+ gpu.func @broadcast(%src: memref<128x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+ -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK-COUNT-3: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<128x1xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<128x1xf32> to vector<128x64xf32>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d511224..d4b0037 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,201 +4,181 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
// CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C12:.*]] = arith.constant 12 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
// CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
- // CHECK: %[[C24:.*]] = arith.constant 24 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
- // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
+ // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+ // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
// CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
- // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-gpu.func @update_nd(%src: memref<24x32xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+gpu.func @update_nd(%src: memref<256x128xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: dpas_no_sg_data
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: xegpu.prefetch_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas_with_no_create_nd_desc
- gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
- // CHECK-NOT: vector<12x12xf32>
+ gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+ // CHECK-NOT: vector<32x32xf32>
%dpas = xegpu.dpas %a, %b
{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim1
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
- %broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<256x1xf32> to vector<256x32xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim0
- // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
- gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
- -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
- -> vector<1x32xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
- : vector<1x32xf32> to vector<12x32xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<1x128xf32> to vector<32x128xf32>
gpu.return
}
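
A rough sketch of the arithmetic behind the updated layouts: in the load/store, update, prefetch, and broadcast tests above, the workgroup tile factors as sg_layout * sg_data per dimension, and each subgroup keeps an sg_data-shaped tile (the dpas tests instead share data cooperatively, so the product can exceed the shape there). For the 256x128 descriptors:

  // wg tile 256x128 with sg_layout = [8, 4], sg_data = [32, 32]:
  //   rows: 8 * 32 = 256, cols: 4 * 32 = 128
  // so each subgroup owns a 32x32 tile, which is exactly the distributed type
  // the CHECK lines expect once the sg_layout/sg_data fields are dropped:
  !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>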
diff --git a/mlir/test/Examples/transform/Ch3/ops.mlir b/mlir/test/Examples/transform/Ch3/ops.mlir
index b2d47cc..707a09f 100644
--- a/mlir/test/Examples/transform/Ch3/ops.mlir
+++ b/mlir/test/Examples/transform/Ch3/ops.mlir
@@ -30,9 +30,30 @@ module attributes {transform.with_named_sequence} {
// -----
func.func private @orig()
+func.func private @updated()
// CHECK-LABEL: func @test2
func.func @test2() {
+ // CHECK: call @updated
+ call @orig() : () -> ()
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.my.call_op_interface
+ // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.my.call_op_interface
+ transform.my.change_call_target %call, "updated" : !transform.my.call_op_interface
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @orig()
+
+// CHECK-LABEL: func @test3
+func.func @test3() {
// CHECK: "my.mm4"
call @orig() : () -> ()
return
diff --git a/mlir/test/Examples/transform/Ch3/sequence.mlir b/mlir/test/Examples/transform/Ch3/sequence.mlir
index 4d28518..877b006 100644
--- a/mlir/test/Examples/transform/Ch3/sequence.mlir
+++ b/mlir/test/Examples/transform/Ch3/sequence.mlir
@@ -101,11 +101,12 @@ module attributes {transform.with_named_sequence} {
%_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
- : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)
-
- // Rewrite the call target.
- transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call">
-
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ // Cast to our new type.
+ %casted = transform.cast %call : !transform.any_op to !transform.my.call_op_interface
+ // Use our new operation to rewrite the call target.
+ transform.my.change_call_target %casted, "microkernel" : !transform.my.call_op_interface
+
transform.yield
}
}
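
A minimal sketch of the pattern this tutorial change introduces, with illustrative handle names: an untyped handle is narrowed via transform.cast before an interface-constrained op consumes it.

  %h = transform.structured.match ops{["func.call"]} in %root : (!transform.any_op) -> !transform.any_op
  %typed = transform.cast %h : !transform.any_op to !transform.my.call_op_interface
  transform.my.change_call_target %typed, "updated" : !transform.my.call_op_interface

Note that ops.mlir above shows the alternative of matching directly into the constrained type, while sequence.mlir casts after the fact; both verify because the payload op implements the interface.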
diff --git a/mlir/test/IR/diagnostic-nosplit.mlir b/mlir/test/IR/diagnostic-nosplit.mlir
new file mode 100644
index 0000000..ecfb9c6
--- /dev/null
+++ b/mlir/test/IR/diagnostic-nosplit.mlir
@@ -0,0 +1,13 @@
+// RUN: not mlir-opt %s -o - --split-input-file 2>&1 | FileCheck %s
+// This test verifies that the diagnostic handler reports locations relative
+// to the original file rather than to the split chunks.
+
+
+// -----
+
+
+
+func.func @constant_out_of_range() {
+ // CHECK: mlir:11:8: error: 'arith.constant'
+ %x = "arith.constant"() {value = 100} : () -> i1
+ return
+}
diff --git a/mlir/test/IR/test-pattern-logging-listener.mlir b/mlir/test/IR/test-pattern-logging-listener.mlir
index c521110..d3d42e3 100644
--- a/mlir/test/IR/test-pattern-logging-listener.mlir
+++ b/mlir/test/IR/test-pattern-logging-listener.mlir
@@ -8,15 +8,15 @@
// {anonymous_namespace} vs `anonymous_namespace` (and maybe others?) on the
// various platforms.
-// CHECK: [pattern-logging-listener]
+// CHECK: [pattern-logging-listener:1]
// CHECK-SAME: ::ReplaceWithNewOp | notifyOperationInserted | test.new_op
-// CHECK: [pattern-logging-listener]
+// CHECK: [pattern-logging-listener:1]
// CHECK-SAME: ::ReplaceWithNewOp | notifyOperationReplaced (with values) | test.replace_with_new_op
-// CHECK: [pattern-logging-listener]
+// CHECK: [pattern-logging-listener:1]
// CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi
-// CHECK: [pattern-logging-listener]
+// CHECK: [pattern-logging-listener:1]
// CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi
-// CHECK: [pattern-logging-listener]
+// CHECK: [pattern-logging-listener:1]
// CHECK-SAME: ::ReplaceWithNewOp | notifyOperationErased | test.replace_with_new_op
func.func @replace_with_new_op() -> i32 {
%a = "test.replace_with_new_op"() : () -> (i32)
diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir
index b571d94..5389691 100644
--- a/mlir/test/IR/top-level.mlir
+++ b/mlir/test/IR/top-level.mlir
@@ -6,10 +6,10 @@ func.func private @foo()
// -----
-// expected-error@-3 {{source must contain a single top-level operation, found: 2}}
+// expected-error@-2 {{source must contain a single top-level operation, found: 2}}
func.func private @bar()
func.func private @baz()
// -----
-// expected-error@-3 {{source must contain a single top-level operation, found: 0}}
+// expected-error@-2 {{source must contain a single top-level operation, found: 0}}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
index 05e6782..a7bb039 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
@@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso
func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
%zero = arith.constant 0 : i32
- %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+ %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
%B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
- %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+ %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32>
// Pack matrices
- %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+ %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
%B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
- %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
+ %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32>
// MMT4D
- %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>
+ %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32>
// Unpack output
%C_out_empty = tensor.empty() : tensor<7x13xi32>
- %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32>
+ %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32>
return %C_out_unpack : tensor<7x13xi32>
}
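
The corrected shapes fall out of the pack arithmetic: each packed outer dim is ceil(extent / inner_tile). For A (7x16, inner_tiles = [8, 1]) that gives [ceil(7/8), ceil(16/1)] = [1, 16], hence tensor<1x16x8x1xi32>; for C (7x13, inner_tiles = [8, 8]) it gives [ceil(7/8), ceil(13/8)] = [1, 2], hence tensor<1x2x8x8xi32>. The previous leading dimension of 2 held an extra, fully padded tile.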
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
index 6e2a82b..6ec1031 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir
@@ -4,14 +4,14 @@
// RUN: FileCheck %s
func.func @extract_element_0d(%a: vector<f32>) {
- %1 = vector.extractelement %a[] : vector<f32>
+ %1 = vector.extract %a[] : f32 from vector<f32>
// CHECK: 42
vector.print %1: f32
return
}
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
- %1 = vector.insertelement %a, %b[] : vector<f32>
+ %1 = vector.insert %a, %b[] : f32 into vector<f32>
return %1: vector<f32>
}
@@ -58,9 +58,9 @@ func.func @broadcast_0d(%a: f32) {
func.func @bitcast_0d() {
%0 = arith.constant 42 : i32
%1 = arith.constant dense<0> : vector<i32>
- %2 = vector.insertelement %0, %1[] : vector<i32>
+ %2 = vector.insert %0, %1[] : i32 into vector<i32>
%3 = vector.bitcast %2 : vector<i32> to vector<f32>
- %4 = vector.extractelement %3[] : vector<f32>
+ %4 = vector.extract %3[] : f32 from vector<f32>
%5 = arith.bitcast %4 : f32 to i32
// CHECK: 42
vector.print %5: i32
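
The vector changes in this and the following integration tests are one mechanical rewrite: the deprecated insertelement/extractelement ops give way to vector.insert/vector.extract. A minimal before/after sketch (%v, %s, %i are placeholder values):

  %old_e = vector.extractelement %v[%i : index] : vector<4xf32>
  %new_e = vector.extract %v[%i] : f32 from vector<4xf32>
  %old_i = vector.insertelement %s, %v[%i : index] : vector<4xf32>
  %new_i = vector.insert %s, %v[%i] : f32 into vector<4xf32>

The new forms spell out the element type and take the index without an index-type annotation, which is also why several loops below drop their arith.index_cast scaffolding.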
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index b69a200..eb99886 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -72,7 +72,7 @@ func.func @za0_d_f64() -> i32 {
%row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
%inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
- %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
+ %t = vector.extract %row[%offset] : f64 from vector<[2]xf64>
%inner_add_reduce_next = arith.addf %inner_iter, %t : f64
scf.yield %inner_add_reduce_next : f64
}
@@ -102,7 +102,7 @@ func.func @za0_d_f64() -> i32 {
%cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>
%inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
- %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+ %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1>
%t_i64 = arith.extui %t : i1 to i64
%inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
scf.yield %inner_mul_reduce_next : i64
@@ -125,7 +125,7 @@ func.func @za0_d_f64() -> i32 {
%cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>
%inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
- %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+ %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1>
%t_i64 = arith.extui %t : i1 to i64
%inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
scf.yield %inner_mul_reduce_next : i64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 697fb90..ad8e321 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -36,7 +36,7 @@ func.func @entry() -> i32 {
%row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
%inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
- %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t = vector.extract %row[%offset] : i8 from vector<[16]xi8>
%t_i64 = arith.extui %t : i8 to i64
%inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
scf.yield %inner_mul_reduce_next : i64
@@ -64,7 +64,7 @@ func.func @entry() -> i32 {
%row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8>
%inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
- %t = vector.extractelement %row[%offset : index] : vector<[16]xi8>
+ %t = vector.extract %row[%offset] : i8 from vector<[16]xi8>
%t_i64 = arith.extui %t : i8 to i64
%inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
scf.yield %inner_mul_reduce_next : i64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir
index 53a7282..aff272c2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir
@@ -11,8 +11,8 @@ func.func @entry() -> i32 {
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
%r = x86vector.avx.intr.dot %a, %b : vector<8xf32>
- %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32>
- %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32>
+ %1 = vector.extract %r[%i0] : f32 from vector<8xf32>
+ %2 = vector.extract %r[%i4] : f32 from vector<8xf32>
%d = arith.addf %1, %2 : f32
// CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir
index bf1caaa..1c56990 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir
@@ -196,13 +196,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
%v_A = vector.transfer_read %m_A[%a], %index_padding
: memref<?xi64>, vector<8xi64>
- %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
+ %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
%r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
%v_C = vector.transfer_read %m_C[%b], %index_padding
: memref<?xi64>, vector<8xi64>
- %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+ %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
%r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
@@ -273,10 +273,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
%v_C = vector.transfer_read %m_C[%b1], %index_padding
: memref<?xi64>, vector<8xi64>
- %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
- %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
- %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
- %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+ %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
+ %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
+ %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64>
+ %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
%r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
@@ -370,8 +370,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
-> f64
%r2 = arith.addf %r1, %subresult : f64
- %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
- %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+ %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
+ %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
%cond_a_i64 = arith.extui %cond_a : i1 to i64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir
index e9a66cc..1683fa5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir
@@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) {
%mem = scf.for %i = %c0 to %c16 step %c1
iter_args(%m_iter = %m) -> (vector<16xf32>) {
%c = memref.load %A[%i] : memref<?xf32>
- %i32 = arith.index_cast %i : index to i32
- %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+ %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32>
scf.yield %m_new : vector<16xf32>
}
vector.print %mem : vector<16xf32>
@@ -49,7 +48,7 @@ func.func @entry() {
memref.store %z, %A[%i] : memref<?xf32>
%i32 = arith.index_cast %i : index to i32
%fi = arith.sitofp %i32 : i32 to f32
- %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+ %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32>
scf.yield %v_new : vector<16xf32>
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir
index 2dc00df..826da53 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir
@@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) {
%mem = scf.for %i = %c0 to %c16 step %c1
iter_args(%m_iter = %m) -> (vector<16xf32>) {
%c = memref.load %A[%i] : memref<?xf32>
- %i32 = arith.index_cast %i : index to i32
- %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+ %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32>
scf.yield %m_new : vector<16xf32>
}
vector.print %mem : vector<16xf32>
@@ -53,7 +52,7 @@ func.func @entry() {
iter_args(%v_iter = %v) -> (vector<16xf32>) {
%i32 = arith.index_cast %i : index to i32
%fi = arith.sitofp %i32 : i32 to f32
- %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+ %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32>
scf.yield %v_new : vector<16xf32>
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir
index 54b6e69..22b5eef 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir
@@ -21,8 +21,7 @@ func.func @printmem8(%A: memref<?xf32>) {
%mem = scf.for %i = %c0 to %c8 step %c1
iter_args(%m_iter = %m) -> (vector<8xf32>) {
%c = memref.load %A[%i] : memref<?xf32>
- %i32 = arith.index_cast %i : index to i32
- %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+ %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<8xf32>
scf.yield %m_new : vector<8xf32>
}
vector.print %mem : vector<8xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
index 2393bd1..639eed4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
@@ -200,7 +200,7 @@ func.func @entry() {
// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
// 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
- // Generates a loop with vector.insertelement.
+ // Generates a loop with vector.insert.
call @transfer_read_1d_broadcast(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
diff --git a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir
index e665653..731bd5a 100644
--- a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir
+++ b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir
@@ -26,17 +26,17 @@ module attributes {
%val2 = memref.load %arg1[%idx0] : memref<2xi32>
%val3 = memref.load %arg1[%idx1] : memref<2xi32>
- %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32>
- %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32>
- %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32>
- %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32>
+ %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32>
+ %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32>
+ %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32>
+ %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32>
%interleave = vector.interleave %lhs1, %rhs1 : vector<2xi32> -> vector<4xi32>
- %res0 = vector.extractelement %interleave[%idx0 : index] : vector<4xi32>
- %res1 = vector.extractelement %interleave[%idx1 : index] : vector<4xi32>
- %res2 = vector.extractelement %interleave[%idx2 : index] : vector<4xi32>
- %res3 = vector.extractelement %interleave[%idx3 : index] : vector<4xi32>
+ %res0 = vector.extract %interleave[%idx0] : i32 from vector<4xi32>
+ %res1 = vector.extract %interleave[%idx1] : i32 from vector<4xi32>
+ %res2 = vector.extract %interleave[%idx2] : i32 from vector<4xi32>
+ %res3 = vector.extract %interleave[%idx3] : i32 from vector<4xi32>
memref.store %res0, %arg2[%idx0]: memref<4xi32>
memref.store %res1, %arg2[%idx1]: memref<4xi32>
diff --git a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir
index dc53fe3..c1b7dba 100644
--- a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir
+++ b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir
@@ -26,17 +26,17 @@ module attributes {
%val2 = memref.load %arg1[%idx0] : memref<2xi32>
%val3 = memref.load %arg1[%idx1] : memref<2xi32>
- %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32>
- %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32>
- %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32>
- %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32>
+ %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32>
+ %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32>
+ %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32>
+ %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32>
%shuffle = vector.shuffle %lhs1, %rhs1[2, 1, 3, 3] : vector<2xi32>, vector<2xi32>
- %res0 = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
- %res1 = vector.extractelement %shuffle[%idx1 : index] : vector<4xi32>
- %res2 = vector.extractelement %shuffle[%idx2 : index] : vector<4xi32>
- %res3 = vector.extractelement %shuffle[%idx3 : index] : vector<4xi32>
+ %res0 = vector.extract %shuffle[%idx0] : i32 from vector<4xi32>
+ %res1 = vector.extract %shuffle[%idx1] : i32 from vector<4xi32>
+ %res2 = vector.extract %shuffle[%idx2] : i32 from vector<4xi32>
+ %res3 = vector.extract %shuffle[%idx3] : i32 from vector<4xi32>
memref.store %res0, %arg2[%idx0]: memref<4xi32>
memref.store %res1, %arg2[%idx1]: memref<4xi32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index cdbca72..7888462 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -595,16 +595,17 @@ module attributes {transform.with_named_sequence} {
// -----
-// It is valid to fuse the pack op with padding semantics if the tiled
-// dimensions do not need padding.
+// It is valid to fuse the pack op with padding semantics if it is a perfect
+// tiling case.
func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> {
- %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2)
+ %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+ %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32>
+ %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32>
scf.forall.in_parallel {
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32>
}
}
%1 = tensor.empty() : tensor<22x2x3x16xf32>
@@ -621,109 +622,39 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32>
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]])
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
-// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
-// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME: into %[[TILED_PACK_DEST]]
-// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
-
-// -----
-
-// It is valid to fuse the pack if the dimension is not tiled even when it needs
-// extra padding.
-
-func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
- %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
- }
- }
- %1 = tensor.empty() : tensor<33x2x3x16xf32>
- %cst = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
- return %pack : tensor<33x2x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
-// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM:.*]] = linalg.exp
-// CHECK-SAME: ins(%[[ELEM_SRC]]
-// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]])
+// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]])
+// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK-SAME: into %[[PACK_INIT]]
// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-
-// -----
-
-// If the dimension is tiled and it needs extra padding, do not fuse the pack
-// op.
-
-func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
- %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
- scf.forall.in_parallel {
- // expected-error @below {{failed to fuse consumer of slice}}
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
- }
- }
- %1 = tensor.empty() : tensor<23x32x3x16xf32>
- %cst = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
- return %pack : tensor<23x32x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]]
+// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1]
// -----
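
The new CHECK maps encode the perfect-tiling relation between the elementwise tile and the packed layout: with inner tile 3 on the first dim, a tile at offset %I of size %SIZE maps to packed offset %I floordiv 3 (MAP1) and packed size %SIZE ceildiv 3 (MAP2). Worked through for %I = 15: %SIZE = min(64 - 15, 15) = 15, so the pack slice starts at packed row 5 and spans 5 packed rows. The second dim steps by its inner tile of 16, so d0 floordiv 16 (MAP3) maps it one-to-one.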
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 24380b5..a419d75 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -570,10 +570,10 @@ define void @trap_intrinsics() {
; CHECK-LABEL: llvm.func @memcpy_test
define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
- call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+ call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false)
ret void
@@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
; CHECK-LABEL: llvm.func @memmove_test
define void @memmove_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false)
ret void
}
; CHECK-LABEL: llvm.func @memset_test
define void @memset_test(i32 %0, ptr %1, i8 %2) {
- ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
- call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
+ call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false)
ret void
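
The round-trip rule these checks pin down: entry i of arg_attrs holds the parameter attributes of call operand i, so `ptr align N` imports as {llvm.align = N : i64} in that slot, with {} marking unannotated operands.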
diff --git a/mlir/test/Target/LLVMIR/Import/module-asm.ll b/mlir/test/Target/LLVMIR/Import/module-asm.ll
new file mode 100644
index 0000000..38f6ea4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/module-asm.ll
@@ -0,0 +1,5 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+; CHECK: llvm.module_asm = ["foo", "bar"]
+
+module asm "foo"
+module asm "bar"
diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir
index 7fd5f26..5ed6244 100644
--- a/mlir/test/Target/LLVMIR/invalid-module.mlir
+++ b/mlir/test/Target/LLVMIR/invalid-module.mlir
@@ -1,6 +1,16 @@
-// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s
+// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module -split-input-file %s
// expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}}
llvm.func @foo() {
llvm.return
}
+
+// -----
+
+// expected-error@below {{expected an array attribute for a module level asm}}
+module attributes {llvm.module_asm = "foo"} {}
+
+// -----
+
+// expected-error@below {{expected a string attribute for each entry of a module level asm}}
+module attributes {llvm.module_asm = [42]} {}
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 44074ce..eb3510c 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() {
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true
- "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true
+ "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true
"llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+
+ // Verify that trailing empty argument attribute dictionaries can be omitted.
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memmove_test
llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memset_test
llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) {
%i1 = llvm.mlir.constant(false) : i1
- // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true
- "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true
+ "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
// CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true
"llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
llvm.return
diff --git a/mlir/test/Target/LLVMIR/module-asm.mlir b/mlir/test/Target/LLVMIR/module-asm.mlir
new file mode 100644
index 0000000..2afb37c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/module-asm.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {llvm.module_asm = ["foo", "bar"]} {}
+
+// CHECK: module asm "foo"
+// CHECK: module asm "bar"
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8c4f0aa..85478cc 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
llvm.return
}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
+ nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
+ llvm.return
+}
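
Together with the positive tests in nvvmir.mlir below, these verifier errors pin down the legal combinations: 1, 2, or 4 registers; shape m8n8 only with element type b16; shape m16n8 only with element type b8 and col layout. For reference, a form the verifier accepts:

  nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32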
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a041..5c2cfa4 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @st_matrix
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
+ llvm.return
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {
diff --git a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir b/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir
deleted file mode 100644
index d889ef4..0000000
--- a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir
+++ /dev/null
@@ -1,121 +0,0 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-
-module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
- omp.private {type = private} @_QFEi_private_i32 : i32 loc(#loc1)
- omp.declare_reduction @add_reduction_i32 : i32 init {
- ^bb0(%arg0: i32 loc("test.f90":8:7)):
- %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc2)
- omp.yield(%0 : i32) loc(#loc2)
- } combiner {
- ^bb0(%arg0: i32 loc("test.f90":8:7), %arg1: i32 loc("test.f90":8:7)):
- %0 = llvm.add %arg0, %arg1 : i32 loc(#loc2)
- omp.yield(%0 : i32) loc(#loc2)
- } loc(#loc2)
- llvm.func @_QQmain() {
- %0 = llvm.mlir.constant(1 : i64) : i64 loc(#loc4)
- %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr<5> loc(#loc4)
- %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr loc(#loc4)
- %3 = llvm.mlir.constant(1 : i64) : i64 loc(#loc1)
- %4 = llvm.alloca %3 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr<5> loc(#loc1)
- %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr loc(#loc1)
- %6 = llvm.mlir.constant(8191 : index) : i64 loc(#loc5)
- %7 = llvm.mlir.constant(0 : index) : i64 loc(#loc5)
- %8 = llvm.mlir.constant(1 : index) : i64 loc(#loc5)
- %9 = llvm.mlir.constant(0 : i32) : i32 loc(#loc5)
- %10 = llvm.mlir.constant(8192 : index) : i64 loc(#loc5)
- %11 = llvm.mlir.addressof @_QFEarr : !llvm.ptr<1> loc(#loc6)
- %12 = llvm.addrspacecast %11 : !llvm.ptr<1> to !llvm.ptr loc(#loc6)
- llvm.store %9, %2 : i32, !llvm.ptr loc(#loc7)
- %15 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"} loc(#loc4)
- %16 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"} loc(#loc7)
- %17 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) extent(%10 : i64) stride(%8 : i64) start_idx(%8 : i64) loc(#loc7)
- %18 = omp.map.info var_ptr(%12 : !llvm.ptr, !llvm.array<8192 x i32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%17) -> !llvm.ptr {name = "arr"} loc(#loc7)
- omp.target map_entries(%15 -> %arg0, %16 -> %arg1, %18 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
- %19 = llvm.mlir.constant(8192 : i32) : i32 loc(#loc5)
- %20 = llvm.mlir.constant(1 : i32) : i32 loc(#loc5)
- %21 = llvm.mlir.constant(8192 : index) : i64 loc(#loc6)
- omp.teams reduction(@add_reduction_i32 %arg0 -> %arg3 : !llvm.ptr) {
- omp.parallel private(@_QFEi_private_i32 %arg1 -> %arg4 : !llvm.ptr) {
- omp.distribute {
- omp.wsloop reduction(@add_reduction_i32 %arg3 -> %arg5 : !llvm.ptr) {
- omp.loop_nest (%arg6) : i32 = (%20) to (%19) inclusive step (%20) {
- llvm.store %arg6, %arg4 : i32, !llvm.ptr loc(#loc2)
- %22 = llvm.load %arg5 : !llvm.ptr -> i32 loc(#loc8)
- %23 = llvm.load %arg4 : !llvm.ptr -> i32 loc(#loc8)
- %34 = llvm.add %22, %23 : i32 loc(#loc8)
- llvm.store %34, %arg5 : i32, !llvm.ptr loc(#loc8)
- omp.yield loc(#loc2)
- } loc(#loc2)
- } {omp.composite} loc(#loc2)
- } {omp.composite} loc(#loc2)
- omp.terminator loc(#loc2)
- } {omp.composite} loc(#loc2)
- omp.terminator loc(#loc2)
- } loc(#loc2)
- omp.terminator loc(#loc2)
- } loc(#loc13)
- llvm.return loc(#loc9)
- } loc(#loc12)
- llvm.mlir.global internal @_QFEarr() {addr_space = 1 : i32} : !llvm.array<8192 x i32> {
- %0 = llvm.mlir.zero : !llvm.array<8192 x i32> loc(#loc6)
- llvm.return %0 : !llvm.array<8192 x i32> loc(#loc6)
- } loc(#loc6)
-} loc(#loc)
-
-#loc = loc("test.f90":4:18)
-#loc1 = loc("test.f90":4:18)
-#loc2 = loc("test.f90":8:7)
-#loc3 = loc("test.f90":1:7)
-#loc4 = loc("test.f90":3:18)
-#loc5 = loc(unknown)
-#loc6 = loc("test.f90":5:18)
-#loc7 = loc("test.f90":6:7)
-#loc8 = loc("test.f90":10:7)
-#loc9 = loc("test.f90":16:7)
-
-#di_file = #llvm.di_file<"target7.f90" in "">
-#di_null_type = #llvm.di_null_type
-#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>,
- sourceLanguage = DW_LANG_Fortran95, file = #di_file, producer = "flang",
- isOptimized = false, emissionKind = LineTablesOnly>
-#di_subroutine_type = #llvm.di_subroutine_type<
- callingConvention = DW_CC_program, types = #di_null_type>
-#di_subprogram = #llvm.di_subprogram<id = distinct[1]<>,
- compileUnit = #di_compile_unit, scope = #di_file, name = "main",
- file = #di_file, subprogramFlags = "Definition|MainSubprogram",
- type = #di_subroutine_type>
-#di_subprogram1 = #llvm.di_subprogram<compileUnit = #di_compile_unit,
- name = "target", file = #di_file, subprogramFlags = "Definition",
- type = #di_subroutine_type>
-
-
-#loc12 = loc(fused<#di_subprogram>[#loc3])
-#loc13 = loc(fused<#di_subprogram1>[#loc2])
-
-// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @"__omp_offloading_{{.*}}__QQmain_l8_omp$reduction$reduction_func.1"
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func.2
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func.3
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_list_to_global_copy_func
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_list_to_global_reduce_func
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_global_to_list_copy_func
-// CHECK-NOT: !dbg
-// CHECK: }
-// CHECK-DAG: define internal void @_omp_reduction_global_to_list_reduce_func
-// CHECK-NOT: !dbg
-// CHECK: }
diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir
new file mode 100644
index 0000000..a3dd0b6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/xevm.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s
+
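+// Each xevm.DecorationCacheControl entry on the call is expected to lower to
+// one !spirv.DecorationCacheControlINTEL metadata operand, as checked below.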
+module {
+ llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64)
+ llvm.func @prefetch(%arg0: !llvm.ptr<1>) {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm
+ // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+ llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0)
+ {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
+ no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64,
+ xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]}
+ : (!llvm.ptr<1>, i64) -> ()
+ llvm.return
+ }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 76d34c2..1695d2a 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,6 +1,9 @@
// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ mlir-translate -no-implicit-module --split-input-file -serialize-spirv %s | spirv-val %}
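+// The module uses the Vulkan memory model and declares the capabilities needed
+// by the constants below so that the serialized binary also passes spirv-val.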
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, CooperativeMatrixKHR], [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix]> {
// CHECK-LABEL: @bool_const
spirv.func @bool_const() -> () "None" {
// CHECK: spirv.Constant true
@@ -305,6 +306,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
}
+
+ // CHECK-LABEL: @arm_tensor_of_i32
+ spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_i32
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @arm_tensor_of_f32
+ spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_f32
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
+ spirv.EntryPoint "GLCompute" @bool_const
}
// -----
diff --git a/mlir/test/Target/SPIRV/lit.local.cfg b/mlir/test/Target/SPIRV/lit.local.cfg
new file mode 100644
index 0000000..6d44394
--- /dev/null
+++ b/mlir/test/Target/SPIRV/lit.local.cfg
@@ -0,0 +1,6 @@
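+# Make the SPIRV-Tools binaries (spirv-as, spirv-val) from the LLVM tools
+# directory available to tests gated on the "spirv-tools" feature.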
+if config.spirv_tools_tests:
+ config.available_features.add("spirv-tools")
+ config.substitutions.append(("spirv-as", os.path.join(config.llvm_tools_dir, "spirv-as")))
+ config.substitutions.append(("spirv-val", os.path.join(config.llvm_tools_dir, "spirv-val")))
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index b200871..05cbddc 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%15 = spirv.IsNan %arg0 : f32
// CHECK: spirv.IsInf
%16 = spirv.IsInf %arg1 : f32
+ // CHECK: spirv.IsFinite
+ %17 = spirv.IsFinite %arg0 : f32
spirv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c39..786d07a2 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : f32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : f32
spirv.Return
}
- spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : i32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : i32
spirv.Return
}
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0b..4984ee7 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
- spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
- spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
- spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
- spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
// CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
+ // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+ spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
- spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+ spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
// CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe..8889b80 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
%5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
%6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
- // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
- %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+ // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+ %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
%8 = spirv.Constant 0 : i32
- %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Return
}
}
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
index 53fbb8a..d6fa442 100644
--- a/mlir/test/Transforms/compose-subview.mlir
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
+ // CHECK: {{.*}} = memref.subview %[[input]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
%0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
%1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
@@ -12,9 +12,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
+ // CHECK: {{.*}} = memref.subview %[[input]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
%0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
%1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
%2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
@@ -24,12 +24,12 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strid
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
%cst_1 = arith.constant 1 : index
%cst_2 = arith.constant 2 : index
- // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+ // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
%1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
@@ -38,13 +38,13 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
%cst_2 = arith.constant 2 : index
- // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
+ // CHECK: %[[C384:.*]] = arith.constant 384 : index
%cst_128 = arith.constant 128 : index
- // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+ // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], %[[C384]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
%1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
@@ -53,9 +53,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+// CHECK-SAME: %[[input:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+ // CHECK: {{.*}} = memref.subview %[[input]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
%0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
%1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
@@ -64,9 +64,9 @@ func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strid
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+// CHECK-SAME: %[[input:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+ // CHECK: {{.*}} = memref.subview %[[input]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
%0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
%1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
%2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
@@ -76,13 +76,13 @@ func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
%cst_2 = arith.constant 2 : index
- // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
+ // CHECK: %[[C384:.*]] = arith.constant 384 : index
%cst_64 = arith.constant 64 : index
- // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], %[[C384]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
%1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
@@ -91,13 +91,39 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid
// -----
// CHECK-LABEL: func.func @subview_strided(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
%cst_1 = arith.constant 1 : index
%cst_2 = arith.constant 2 : index
- // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
%1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func.func @single_dynamic_size_subview(
+// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>,
+// CHECK-SAME: %{{.*}}: index,
+// CHECK-SAME: %[[SIZE_1:.*]]: index) -> memref<8x?xf32> {
+func.func @single_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<8x?xf32> {
+ %subview = memref.subview %input[0, 0][8, %size0][1, 1] : memref<256x?xf32> to memref<8x?xf32>
+ %subview_1 = memref.subview %subview[0, 0][8, %size1][1, 1] : memref<8x?xf32> to memref<8x?xf32>
+ // CHECK: %{{.*}} = memref.subview %[[input]][0, 0] [8, %[[SIZE_1]]] [1, 1] : memref<256x?xf32> to memref<8x?xf32>
+ return %subview_1 : memref<8x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @all_dynamic_size_subview(
+// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>,
+// CHECK-SAME: %{{.*}}: index,
+// CHECK-SAME: %[[SIZE1:.*]]: index) -> memref<?x?xf32> {
+func.func @all_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<?x?xf32> {
+ %subview = memref.subview %input[0, 0][%size0, %size0][1, 1] : memref<256x?xf32> to memref<?x?xf32>
+ %subview_1 = memref.subview %subview[0, 0][%size1, %size1][1, 1] : memref<?x?xf32> to memref<?x?xf32>
+ // CHECK: {{.*}} = memref.subview %[[input]][0, 0] {{\[}}%[[SIZE1]], %[[SIZE1]]] [1, 1] : memref<256x?xf32> to memref<?x?xf32>
+ return %subview_1 : memref<?x?xf32>
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 3af95db..9ded6a3 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -548,3 +548,26 @@ func.func @test_atomic_yield(%I: memref<10xf32>, %idx : index) {
func.return
}
+// -----
+
+// CHECK-LABEL: module @return_void_with_unused_argument
+module @return_void_with_unused_argument {
+ // CHECK-LABEL: func.func private @fn_return_void_with_unused_argument
+ // CHECK-SAME: (%[[ARG0_FN:.*]]: i32)
+ func.func private @fn_return_void_with_unused_argument(%arg0: i32, %arg1: memref<4xi32>) -> () {
+ %sum = arith.addi %arg0, %arg0 : i32
+ %c0 = arith.constant 0 : index
+ %buf = memref.alloc() : memref<1xi32>
+ memref.store %sum, %buf[%c0] : memref<1xi32>
+ return
+ }
+ // CHECK-LABEL: func.func @main
+ // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+ // CHECK: call @fn_return_void_with_unused_argument(%[[ARG0_MAIN]]) : (i32) -> ()
+ func.func @main(%arg0: i32) -> memref<4xi32> {
+ %unused = memref.alloc() : memref<4xi32>
+ call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
+ return %unused : memref<4xi32>
+ }
+}
+
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f..9bffe92 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
^bb0(%arg0: f32):
- "test.type_consumer"(%arg0) : (f32) -> ()
// expected-note@below{{see existing live user here}}
+ "test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
return
diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir
index 19a1310..5b07055 100644
--- a/mlir/test/Transforms/test-legalizer-analysis.mlir
+++ b/mlir/test/Transforms/test-legalizer-analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s
// expected-remark@-2 {{op 'builtin.module' is legalizable}}
// expected-remark@+1 {{op 'func.func' is legalizable}}
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 5f1148c..dcd0172 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func @multi_level_mapping
func.func @multi_level_mapping() {
diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
index 8a01a0a..016052c 100644
--- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
+++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
@@ -69,25 +69,25 @@ struct MathCosToVCIX final : OpRewritePattern<math::CosOp> {
if (legalType.isScalable())
// Use arbitrary runtime vector length when vector type is scalable.
// Proper conversion pass should take it from the IR.
- rvl = rewriter.create<arith::ConstantOp>(loc,
- rewriter.getI64IntegerAttr(9));
+ rvl = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI64IntegerAttr(9));
Value res;
if (n == 1) {
- res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec,
- immAttr, rvl);
+ res = vcix::BinaryImmOp::create(rewriter, loc, legalType, opcodeAttr, vec,
+ immAttr, rvl);
} else {
const unsigned eltCount = legalType.getShape()[0];
Type eltTy = legalType.getElementType();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, eltTy, rewriter.getZeroAttr(eltTy));
- res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
+ Value zero = arith::ConstantOp::create(rewriter, loc, eltTy,
+ rewriter.getZeroAttr(eltTy));
+ res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/);
for (unsigned i = 0; i < n; ++i) {
- Value extracted = rewriter.create<vector::ScalableExtractOp>(
- loc, legalType, vec, i * eltCount);
- Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr,
- extracted, immAttr, rvl);
- res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
- i * eltCount);
+ Value extracted = vector::ScalableExtractOp::create(
+ rewriter, loc, legalType, vec, i * eltCount);
+ Value v = vcix::BinaryImmOp::create(
+ rewriter, loc, legalType, opcodeAttr, extracted, immAttr, rvl);
+ res = vector::ScalableInsertOp::create(rewriter, loc, v, res,
+ i * eltCount);
}
}
rewriter.replaceOp(op, res);
@@ -112,25 +112,25 @@ struct MathSinToVCIX final : OpRewritePattern<math::SinOp> {
if (legalType.isScalable())
// Use arbitrary runtime vector length when vector type is scalable.
// Proper conversion pass should take it from the IR.
- rvl = rewriter.create<arith::ConstantOp>(loc,
- rewriter.getI64IntegerAttr(9));
+ rvl = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI64IntegerAttr(9));
Value res;
if (n == 1) {
- res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
- vec, rvl);
+ res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec,
+ vec, rvl);
} else {
const unsigned eltCount = legalType.getShape()[0];
Type eltTy = legalType.getElementType();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, eltTy, rewriter.getZeroAttr(eltTy));
- res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
+ Value zero = arith::ConstantOp::create(rewriter, loc, eltTy,
+ rewriter.getZeroAttr(eltTy));
+ res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/);
for (unsigned i = 0; i < n; ++i) {
- Value extracted = rewriter.create<vector::ScalableExtractOp>(
- loc, legalType, vec, i * eltCount);
- Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
- extracted, extracted, rvl);
- res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
- i * eltCount);
+ Value extracted = vector::ScalableExtractOp::create(
+ rewriter, loc, legalType, vec, i * eltCount);
+ Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr,
+ extracted, extracted, rvl);
+ res = vector::ScalableInsertOp::create(rewriter, loc, v, res,
+ i * eltCount);
}
}
rewriter.replaceOp(op, res);
@@ -152,28 +152,28 @@ struct MathTanToVCIX final : OpRewritePattern<math::TanOp> {
Location loc = op.getLoc();
Value vec = op.getOperand();
Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, eltTy, rewriter.getZeroAttr(eltTy));
+ Value zero = arith::ConstantOp::create(rewriter, loc, eltTy,
+ rewriter.getZeroAttr(eltTy));
Value rvl = nullptr;
if (legalType.isScalable())
// Use arbitrary runtime vector length when vector type is scalable.
// Proper conversion pass should take it from the IR.
- rvl = rewriter.create<arith::ConstantOp>(loc,
- rewriter.getI64IntegerAttr(9));
+ rvl = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI64IntegerAttr(9));
Value res;
if (n == 1) {
- res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
- zero, rvl);
+ res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec,
+ zero, rvl);
} else {
const unsigned eltCount = legalType.getShape()[0];
- res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
+ res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/);
for (unsigned i = 0; i < n; ++i) {
- Value extracted = rewriter.create<vector::ScalableExtractOp>(
- loc, legalType, vec, i * eltCount);
- Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
- extracted, zero, rvl);
- res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
- i * eltCount);
+ Value extracted = vector::ScalableExtractOp::create(
+ rewriter, loc, legalType, vec, i * eltCount);
+ Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr,
+ extracted, zero, rvl);
+ res = vector::ScalableInsertOp::create(rewriter, loc, v, res,
+ i * eltCount);
}
}
rewriter.replaceOp(op, res);
@@ -195,30 +195,30 @@ struct MathLogToVCIX final : OpRewritePattern<math::LogOp> {
Value vec = op.getOperand();
Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
Value rvl = nullptr;
- Value zeroInt = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+ Value zeroInt = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
if (legalType.isScalable())
// Use arbitrary runtime vector length when vector type is scalable.
// Proper conversion pass should take it from the IR.
- rvl = rewriter.create<arith::ConstantOp>(loc,
- rewriter.getI64IntegerAttr(9));
+ rvl = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getI64IntegerAttr(9));
Value res;
if (n == 1) {
- res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
- zeroInt, rvl);
+ res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec,
+ zeroInt, rvl);
} else {
const unsigned eltCount = legalType.getShape()[0];
Type eltTy = legalType.getElementType();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, eltTy, rewriter.getZeroAttr(eltTy));
- res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
+ Value zero = arith::ConstantOp::create(rewriter, loc, eltTy,
+ rewriter.getZeroAttr(eltTy));
+ res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/);
for (unsigned i = 0; i < n; ++i) {
- Value extracted = rewriter.create<vector::ScalableExtractOp>(
- loc, legalType, vec, i * eltCount);
- Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
- extracted, zeroInt, rvl);
- res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
- i * eltCount);
+ Value extracted = vector::ScalableExtractOp::create(
+ rewriter, loc, legalType, vec, i * eltCount);
+ Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr,
+ extracted, zeroInt, rvl);
+ res = vector::ScalableInsertOp::create(rewriter, loc, v, res,
+ i * eltCount);
}
}
rewriter.replaceOp(op, res);
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index ed5d06d..3569a73 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -145,7 +145,7 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
if (reifiedScalable->map.getNumInputs() == 1) {
// The only possible input to the bound is vscale.
vscaleOperand.push_back(std::make_pair(
- rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
+ vector::VectorScaleOp::create(rewriter, loc), std::nullopt));
}
reified = affine::materializeComputedBound(
rewriter, loc, reifiedScalable->map, vscaleOperand);
@@ -169,8 +169,9 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
rewriter.replaceOp(op, val);
return WalkResult::skip();
}
- Value constOp = rewriter.create<arith::ConstantIndexOp>(
- op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt());
+ Value constOp = arith::ConstantIndexOp::create(
+ rewriter, op->getLoc(),
+ cast<IntegerAttr>(cast<Attribute>(*reified)).getInt());
rewriter.replaceOp(op, constOp);
return WalkResult::skip();
});
diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
index 738d4ee59..a792d08 100644
--- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
+++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
@@ -60,7 +60,7 @@ struct TestEmulateWideIntPass
// casts (and vice versa) and using it instead of `llvm.bitcast`.
auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
- auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
+ auto cast = LLVM::BitcastOp::create(builder, loc, type, inputs);
return cast->getResult(0);
};
typeConverter.addSourceMaterialization(addBitcast);
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
index 226e0bb..2ee3222 100644
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRBufferizationTestPasses
+ TestOneShotModuleBufferize.cpp
TestTensorCopyInsertion.cpp
TestTensorLikeAndBufferLike.cpp
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
new file mode 100644
index 0000000..1e2d4a7
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -0,0 +1,57 @@
+//===- TestOneShotModuleBufferize.cpp - Bufferization Test ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestOneShotModuleBufferizePass
+ : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
+
+ TestOneShotModuleBufferizePass() = default;
+ TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<bufferization::BufferizationDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-one-shot-module-bufferize";
+ }
+ StringRef getDescription() const final {
+ return "Pass to test One Shot Module Bufferization";
+ }
+
+ void runOnOperation() override {
+ llvm::errs() << "Running TestOneShotModuleBufferize on: "
+ << getOperation()->getName() << "\n";
+ bufferization::OneShotBufferizationOptions opt;
+
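+ // Bufferize across function boundaries so that function signatures and call
+ // sites are rewritten to memref-based types as well.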
+ opt.bufferizeFunctionBoundaries = true;
+ bufferization::BufferizationState bufferizationState;
+
+ if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
+ bufferizationState)))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestOneShotModuleBufferizePass() {
+ PassRegistration<TestOneShotModuleBufferizePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index eb2f74e..3b7bd9b 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -10,7 +10,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVM)
add_subdirectory(Math)
add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
add_subdirectory(NVGPU)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index d0b62e7..c67bcd9 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -48,8 +48,8 @@ static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
}
for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
Type elementType = tupleType.getType(i);
- Value element = builder.create<test::GetTupleElementOp>(
- loc, elementType, tuple, builder.getI32IntegerAttr(i));
+ Value element = test::GetTupleElementOp::create(
+ builder, loc, elementType, tuple, builder.getI32IntegerAttr(i));
decompose(element);
}
};
@@ -94,7 +94,7 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
}
// Assemble the tuple from the elements.
- return builder.create<test::MakeTupleOp>(loc, resultType, elements);
+ return test::MakeTupleOp::create(builder, loc, resultType, elements);
}
/// A pass for testing call graph type decomposition.
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 9eade75..9a394d2 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -56,7 +56,7 @@ struct TestSCFForUtilsPass
SmallVector<Value> newYieldValues;
for (auto yieldVal : oldYieldValues) {
newYieldValues.push_back(
- b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
+ arith::AddFOp::create(b, loc, yieldVal, yieldVal));
}
return newYieldValues;
};
@@ -160,13 +160,13 @@ struct TestSCFPipeliningPass
Value pred) {
Location loc = op->getLoc();
auto ifOp =
- rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
+ scf::IfOp::create(rewriter, loc, op->getResultTypes(), pred, true);
// True branch.
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
- rewriter.create<scf::YieldOp>(loc, op->getResults());
+ scf::YieldOp::create(rewriter, loc, op->getResults());
// False branch.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
SmallVector<Value> elseYieldOperands;
@@ -181,12 +181,12 @@ struct TestSCFPipeliningPass
} else {
// Default to assuming constant numeric values.
for (Type type : op->getResultTypes()) {
- elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(type)));
+ elseYieldOperands.push_back(arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(type)));
}
}
if (op->getNumResults() > 0)
- rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
+ scf::YieldOp::create(rewriter, loc, elseYieldOperands);
return ifOp.getOperation();
}
diff --git a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
index d3113c0..d3f7f0e6 100644
--- a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
@@ -50,23 +50,23 @@ struct TestSCFWhileOpBuilderPass
// Create a WhileOp with the same operands and result types.
TypeRange resultTypes = whileOp->getResultTypes();
ValueRange operands = whileOp->getOperands();
- builder.create<WhileOp>(
- loc, resultTypes, operands, /*beforeBuilder=*/
+ WhileOp::create(
+ builder, loc, resultTypes, operands, /*beforeBuilder=*/
[&](OpBuilder &b, Location loc, ValueRange args) {
// Just cast the before args into the right types for condition.
ImplicitLocOpBuilder builder(loc, b);
auto castOp =
- builder.create<UnrealizedConversionCastOp>(resultTypes, args);
- auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
- builder.create<ConditionOp>(cmp, castOp->getResults());
+ UnrealizedConversionCastOp::create(builder, resultTypes, args);
+ auto cmp = ConstantIntOp::create(builder, /*value=*/1, /*width=*/1);
+ ConditionOp::create(builder, cmp, castOp->getResults());
},
/*afterBuilder=*/
[&](OpBuilder &b, Location loc, ValueRange args) {
// Just cast the after args into the right types for yield.
ImplicitLocOpBuilder builder(loc, b);
- auto castOp = builder.create<UnrealizedConversionCastOp>(
- operands.getTypes(), args);
- builder.create<YieldOp>(castOp->getResults());
+ auto castOp = UnrealizedConversionCastOp::create(
+ builder, operands.getTypes(), args);
+ YieldOp::create(builder, castOp->getResults());
});
});
}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
index 7bd0493..f91c547 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
@@ -1,14 +1,14 @@
# Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTest
+add_mlir_library(MLIRShardTest
TestOpLowering.cpp
- TestReshardingSpmdization.cpp
+ TestReshardingPartition.cpp
TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
)
-mlir_target_link_libraries(MLIRMeshTest PUBLIC
- MLIRMeshDialect
- MLIRMeshTransforms
+mlir_target_link_libraries(MLIRShardTest PUBLIC
+ MLIRShardDialect
+ MLIRShardTransforms
MLIRPass
MLIRRewrite
MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
index dbae93b..43f3b3f 100644
--- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -24,17 +24,17 @@ struct TestAllSliceOpLoweringPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
+ shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
- mesh::registerAllSliceOpLoweringDialects(registry);
+ shard::registerAllSliceOpLoweringDialects(registry);
}
StringRef getArgument() const final {
- return "test-mesh-all-slice-op-lowering";
+ return "test-grid-all-slice-op-lowering";
}
StringRef getDescription() const final {
return "Test lowering of all-slice.";
@@ -48,21 +48,21 @@ struct TestMultiIndexOpLoweringPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
- symbolTableCollection);
+ shard::populateProcessMultiIndexOpLoweringPatterns(patterns,
+ symbolTableCollection);
LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
- mesh::registerProcessMultiIndexOpLoweringDialects(registry);
+ shard::registerProcessMultiIndexOpLoweringDialects(registry);
}
StringRef getArgument() const final {
- return "test-mesh-process-multi-index-op-lowering";
+ return "test-grid-process-multi-index-op-lowering";
}
StringRef getDescription() const final {
- return "Test lowering of mesh.process_multi_index op.";
+ return "Test lowering of shard.process_multi_index op.";
}
};
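
Both lowering passes above follow the same greedy-rewrite skeleton; a compressed sketch of that flow (names taken from the patch, with error handling simplified relative to the asserts above):

    RewritePatternSet patterns(&getContext());
    SymbolTableCollection symbolTableCollection;
    shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
    // Runs the patterns to a fixed point; fails if rewriting did not converge.
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
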
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
index 102e64d..23fdad1 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -22,11 +22,11 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
-struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> {
using OpRewritePattern<ShardOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShardOp op,
@@ -36,18 +36,18 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
}
SymbolTableCollection symbolTable;
- mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
- op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
+ shard::GridOp grid = symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
+ op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr());
bool foundUser = false;
for (auto user : op->getUsers()) {
if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
if (targetShardOp.getAnnotateForUsers() &&
- mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ grid == symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
targetShardOp,
cast<ShardingOp>(
targetShardOp.getSharding().getDefiningOp())
- .getMeshAttr())) {
+ .getGridAttr())) {
foundUser = true;
break;
}
@@ -61,26 +61,25 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
for (auto user : op->getUsers()) {
auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
- symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
targetShardOp,
cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
- .getMeshAttr()) != mesh) {
+ .getGridAttr()) != grid) {
continue;
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
- shardShapedType(op.getResult().getType(), mesh, op.getSharding());
+ shardShapedType(op.getResult().getType(), grid, op.getSharding());
TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
- builder
- .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
+ UnrealizedConversionCastOp::create(builder, sourceShardShape,
+ op.getSrc())
->getResult(0));
TypedValue<ShapedType> targetShard =
- reshard(builder, mesh, op, targetShardOp, sourceShard);
+ reshard(builder, grid, op, targetShardOp, sourceShard);
Value newTargetUnsharded =
- builder
- .create<UnrealizedConversionCastOp>(
- targetShardOp.getResult().getType(), targetShard)
+ UnrealizedConversionCastOp::create(
+ builder, targetShardOp.getResult().getType(), targetShard)
->getResult(0);
rewriter.replaceAllUsesWith(targetShardOp.getResult(),
newTargetUnsharded);
@@ -90,13 +89,13 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
}
};
-struct TestMeshReshardingPass
- : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+struct TestReshardingPass
+ : public PassWrapper<TestReshardingPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReshardingPass)
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+ patterns.insert<TestReshardingRewritePattern>(&getContext());
if (failed(applyPatternsGreedily(getOperation().getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -107,18 +106,18 @@ struct TestMeshReshardingPass
registry.insert<BuiltinDialect>();
}
StringRef getArgument() const final {
- return "test-mesh-resharding-spmdization";
+ return "test-grid-resharding-partition";
}
StringRef getDescription() const final {
- return "Test Mesh dialect resharding spmdization.";
+ return "Test Shard dialect resharding partition.";
}
};
} // namespace
namespace mlir {
namespace test {
-void registerTestMeshReshardingSpmdizationPass() {
- PassRegistration<TestMeshReshardingPass>();
+void registerTestReshardingPartitionPass() {
+ PassRegistration<TestReshardingPass>();
}
} // namespace test
} // namespace mlir
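
The unrealized casts in this pattern are the usual bridge between the full tensor type and its per-process shard type: the value is cast into the shard type, resharded, and cast back, with the casts expected to fold away later. A minimal sketch of the idiom, where `fullType` and `shardType` are hypothetical stand-ins for the two types involved:

    // Cast into the shard-local type without materializing a real conversion.
    Value asShard =
        UnrealizedConversionCastOp::create(builder, shardType, fullValue)
            ->getResult(0);
    // ... rewrite at shard granularity (here: reshard) ...
    Value asFull =
        UnrealizedConversionCastOp::create(builder, fullType, asShard)
            ->getResult(0);
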
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
index 01e196d..2885215 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,23 +16,23 @@
using namespace mlir;
namespace {
-struct TestMeshSimplificationsPass
- : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+struct TestShardSimplificationsPass
+ : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+ registry.insert<arith::ArithDialect, shard::ShardDialect>();
}
- StringRef getArgument() const final { return "test-mesh-simplifications"; }
- StringRef getDescription() const final { return "Test mesh simplifications"; }
+ StringRef getArgument() const final { return "test-grid-simplifications"; }
+ StringRef getDescription() const final { return "Test grid simplifications"; }
};
} // namespace
-void TestMeshSimplificationsPass::runOnOperation() {
+void TestShardSimplificationsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
+ shard::populateSimplificationPatterns(patterns, symbolTableCollection);
[[maybe_unused]] LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
  assert(succeeded(status) && "Rewrite pattern application did not converge.");
@@ -40,8 +40,8 @@ void TestMeshSimplificationsPass::runOnOperation() {
namespace mlir {
namespace test {
-void registerTestMeshSimplificationsPass() {
- PassRegistration<TestMeshSimplificationsPass>();
+void registerTestShardSimplificationsPass() {
+ PassRegistration<TestShardSimplificationsPass>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 0e191c3..687473e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -192,8 +192,8 @@ struct RewriteExtractSliceFromCollapseShapeBase
// Create the destination tensor using the above values.
Type elementType = op.getSourceType().getElementType();
SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
- Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
- elementType);
+ Value dest = tensor::EmptyOp::create(rewriter, op->getLoc(), outputShape,
+ elementType);
// Calculate the parameters for the tile loop nest.
FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
@@ -215,8 +215,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
SmallVector<Value> lbs(numTiledDims, zero);
SmallVector<Value> steps(numTiledDims, one);
@@ -228,8 +228,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
// Insert the slice into the destination.
- return {nestedBuilder.create<tensor::InsertSliceOp>(
- loc, tile, iterArgs[0], insertParams)};
+ return {tensor::InsertSliceOp::create(nestedBuilder, loc, tile,
+ iterArgs[0], insertParams)};
});
rewriter.replaceOp(op, nest.results);
@@ -245,8 +245,9 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
tensor::ExtractSliceFromCollapseHelper &helper,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto forallOp = rewriter.create<scf::ForallOp>(
- loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
+ auto forallOp = scf::ForallOp::create(
+ rewriter, loc,
+ /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
/*outputs=*/dest,
/*mapping=*/std::nullopt,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
@@ -261,10 +262,10 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
auto [tile, insertParams] =
helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
// Insert the slice into the destination.
- auto term = nestedBuilder.create<scf::InParallelOp>(loc);
+ auto term = scf::InParallelOp::create(nestedBuilder, loc);
nestedBuilder.setInsertionPointToStart(term.getBody());
- nestedBuilder.create<tensor::ParallelInsertSliceOp>(
- loc, tile, outputArgs[0], insertParams);
+ tensor::ParallelInsertSliceOp::create(nestedBuilder, loc, tile,
+ outputArgs[0], insertParams);
});
rewriter.replaceOp(op, forallOp->getResult(0));
return success();
@@ -355,8 +356,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
MLIRContext *context = rootOp->getContext();
OpBuilder builder(context);
OwningOpRef<transform::NamedSequenceOp> transformOp =
- builder.create<transform::NamedSequenceOp>(
- rootOp->getLoc(),
+ transform::NamedSequenceOp::create(
+ builder, rootOp->getLoc(),
/*sym_name=*/"test_sequence",
/*function_type=*/
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 382da59..5685004 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
let mnemonic = "copy_count";
let parameters = (ins TestParamCopyCount:$copy_count);
let assemblyFormat = "`<` $copy_count `>`";
+ let genVerifyDecl = 1;
}
def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index b31e90f..5890913 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -214,6 +214,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
}
//===----------------------------------------------------------------------===//
+// TestCopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestCopyCountAttr::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/,
+ CopyCount /*copy_count*/) {
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// CopyCountAttr Implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 1bbf2cc..a4c615b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -346,7 +346,7 @@ TestDialect::~TestDialect() {
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
- return builder.create<TestOpConstant>(loc, type, value);
+ return TestOpConstant::create(builder, loc, type, value);
}
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 01ae245..1235a5f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -354,7 +354,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
!(input.getType().isSignlessInteger(16) ||
input.getType().isSignlessInteger(32)))
return nullptr;
- return builder.create<TestCastOp>(conversionLoc, resultType, input);
+ return TestCastOp::create(builder, conversionLoc, resultType, input);
}
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
@@ -362,16 +362,16 @@ struct TestInlinerInterface : public DialectInlinerInterface {
DictionaryAttr argumentAttrs) const final {
if (!argumentAttrs.contains("test.handle_argument"))
return argument;
- return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
- argument);
+ return TestTypeChangerOp::create(builder, call->getLoc(),
+ argument.getType(), argument);
}
Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
Value result, DictionaryAttr resultAttrs) const final {
if (!resultAttrs.contains("test.handle_result"))
return result;
- return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
- result);
+ return TestTypeChangerOp::create(builder, call->getLoc(), result.getType(),
+ result);
}
void processInlinedCallBlocks(
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
index dc6413b..b98f6ce 100644
--- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -43,11 +43,11 @@ static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
if (failed(addr))
return failure();
// Create the LoadOp
- Value loadOp = builder.create<LLVM::LoadOp>(
- moduleImport.translateLoc(inst->getDebugLoc()),
+ Value loadOp = LLVM::LoadOp::create(
+ builder, moduleImport.translateLoc(inst->getDebugLoc()),
moduleImport.convertType(inst->getType()), *addr);
- moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
- loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
+ moduleImport.mapValue(inst) = SameOperandElementTypeOp::create(
+ builder, loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 3ab4ef2..53055fe 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -18,6 +18,32 @@ using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
+// OverriddenSymbolVisibilityOp
+//===----------------------------------------------------------------------===//
+
+SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() {
+ return SymbolTable::Visibility::Private;
+}
+
+static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) {
+ switch (visibility) {
+ case SymbolTable::Visibility::Private:
+ return "private";
+ case SymbolTable::Visibility::Nested:
+ return "nested";
+ case SymbolTable::Visibility::Public:
+ return "public";
+ }
+}
+
+void OverriddenSymbolVisibilityOp::setVisibility(
+ SymbolTable::Visibility visibility) {
+
+ emitOpError("cannot change visibility of symbol to ")
+ << getVisibilityString(visibility);
+}
+
+//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
@@ -286,9 +312,9 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
}));
- shapes.push_back(builder.create<tensor::FromElementsOp>(
- getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
- currShape));
+ shapes.push_back(tensor::FromElementsOp::create(
+ builder, getLoc(),
+ RankedTensorType::get({rank}, builder.getIndexType()), currShape));
}
return success();
}
@@ -1302,8 +1328,8 @@ llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
- return builder.create<TestOpConstant>(getLoc(), slot.elemType,
- builder.getI32IntegerAttr(42));
+ return TestOpConstant::create(builder, getLoc(), slot.elemType,
+ builder.getI32IntegerAttr(42));
}
void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
@@ -1335,7 +1361,7 @@ createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(oldOp);
auto replacement =
- builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
+ TestMultiSlotAlloca::create(builder, oldOp->getLoc(), newTypes);
for (auto [oldResult, newResult] :
llvm::zip_equal(remainingValues, replacement.getResults()))
oldResult.replaceAllUsesWith(newResult);
@@ -1384,7 +1410,7 @@ DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
for (Attribute usedIndex : usedIndices) {
Type elemType = slot.subelementTypes.lookup(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
- auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+ auto subAlloca = TestMultiSlotAlloca::create(builder, getLoc(), elemPtr);
newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(0), elemType});
@@ -1412,8 +1438,8 @@ TestMultiSlotAlloca::handleDestructuringComplete(
const auto bufferizedOutType = test::TestMemrefType::get(
getContext(), outType.getShape(), outType.getElementType(), nullptr);
// replace op with memref analogy
- auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
- getLoc(), bufferizedOutType, *buffer);
+ auto dummyMemrefOp = test::TestDummyMemrefOp::create(
+ rewriter, getLoc(), bufferizedOutType, *buffer);
mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
dummyMemrefOp.getResult());
@@ -1434,7 +1460,7 @@ TestMultiSlotAlloca::handleDestructuringComplete(
// replace op with memref analogy
auto createMemrefOp =
- rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
+ test::TestCreateMemrefOp::create(rewriter, getLoc(), *bufferizedOutType);
mlir::bufferization::replaceOpWithBufferizedValues(
rewriter, getOperation(), createMemrefOp.getResult());
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ab3f847..2eaad55 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -119,12 +119,28 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> {
OptionalAttr<StrAttr>:$sym_visibility);
}
+def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [
+ DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>,
+]> {
+ let summary = "operation with overridden symbol visibility accessors";
+ let arguments = (ins StrAttr:$sym_name);
+}
+
def SymbolScopeOp : TEST_Op<"symbol_scope",
[SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "operation which defines a new symbol table";
let regions = (region SizedRegion<1>:$region);
}
+def SymbolScopeIsolatedOp
+ : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable,
+ SingleBlockImplicitTerminator<
+ "TerminatorOp">]> {
+ let summary =
+ "operation which defines a new symbol table that is IsolatedFromAbove";
+ let regions = (region SizedRegion<1>:$region);
+}
+
def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
let summary = "operation which defines a new symbol table without a "
"restriction on a terminator";
@@ -2035,7 +2051,7 @@ def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> {
OpBuilder::InsertionGuard g($_builder);
Block *body = $_builder.createBlock(bodyRegion);
$_builder.setInsertionPointToEnd(body);
- $_builder.create<IllegalOpTerminator>($_state.location);
+ IllegalOpTerminator::create($_builder, $_state.location);
}]>];
}
def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">;
@@ -2738,7 +2754,7 @@ def TestLinalgConvOp :
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute> attrs,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
- b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+ mlir::linalg::YieldOp::create(b, block.getArguments().back());
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
@@ -2801,7 +2817,7 @@ def TestLinalgFillOp :
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute> attrs,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
- b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+ mlir::linalg::YieldOp::create(b, block.getArguments().back());
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
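
The new test op exists to exercise the Symbol interface's overridable accessors: its getVisibility always answers private, and setVisibility refuses the change with an op error (see the C++ implementations added in TestOpDefs.cpp above). A sketch of how interface-generic code reaches those overrides, assuming `op` is an instance of the new op:

    auto symbol = cast<SymbolOpInterface>(op);
    SymbolTable::Visibility vis = symbol.getVisibility(); // always Private
    // Emits "cannot change visibility of symbol to public" via the override.
    symbol.setVisibility(SymbolTable::Visibility::Public);
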
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
index 6d4e5e3..cc131ad 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -313,7 +313,7 @@ ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
OpBuilder builder(parser.getContext());
builder.setInsertionPointToEnd(&block);
- builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
+ TestReturnOp::create(builder, wrappedOp->getLoc(), returnOperands);
// Get the results type for the wrapping op from the terminator operands.
Operation &returnOp = body.back().back();
@@ -397,7 +397,7 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
// Insert a return statement in the block returning the inner-op's result.
- builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
+ TestReturnOp::create(builder, innerOp->getLoc(), innerOp->getResults());
// Populate the op operation-state with result-type and location.
result.addTypes(opFntype.getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b4aeccf..eda618f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -33,14 +33,14 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
}
static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
- rewriter.create<OpI>(loc, input);
+ OpI::create(rewriter, loc, input);
}
static void handleNoResultOp(PatternRewriter &rewriter,
OpSymbolBindingNoResult op) {
// Turn the no result op to a one-result op.
- rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
- op.getOperand());
+ OpSymbolBindingB::create(rewriter, op.getLoc(), op.getOperand().getType(),
+ op.getOperand());
}
static bool getFirstI32Result(Operation *op, Value &value) {
@@ -120,8 +120,8 @@ public:
return failure();
rewriter.setInsertionPointToStart(op->getBlock());
- auto constOp = rewriter.create<arith::ConstantOp>(
- op.getLoc(), rewriter.getBoolAttr(true));
+ auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getBoolAttr(true));
rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
Value(constOp));
return success();
@@ -139,8 +139,7 @@ public:
LogicalResult matchAndRewrite(TestCommutative2Op op,
PatternRewriter &rewriter) const override {
- auto operand =
- dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
+ auto operand = op->getOperand(0).getDefiningOp<TestCommutative2Op>();
if (!operand)
return failure();
Attribute constInput;
@@ -845,8 +844,8 @@ struct TestRegionRewriteUndo : public RewritePattern {
rewriter.getUnknownLoc());
// Add an explicitly illegal operation to ensure the conversion fails.
- rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
- rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
+ ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getIntegerType(32));
+ TestValidOp::create(rewriter, op->getLoc(), ArrayRef<Value>());
// Drop this operation.
rewriter.eraseOp(op);
@@ -865,7 +864,7 @@ struct TestCreateBlock : public RewritePattern {
Type i32Type = rewriter.getIntegerType(32);
Location loc = op->getLoc();
rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
- rewriter.create<TerminatorOp>(loc);
+ TerminatorOp::create(rewriter, loc);
rewriter.eraseOp(op);
return success();
}
@@ -884,8 +883,8 @@ struct TestCreateIllegalBlock : public RewritePattern {
Location loc = op->getLoc();
rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
// Create an illegal op to ensure the conversion fails.
- rewriter.create<ILLegalOpF>(loc, i32Type);
- rewriter.create<TerminatorOp>(loc);
+ ILLegalOpF::create(rewriter, loc, i32Type);
+ TerminatorOp::create(rewriter, loc);
rewriter.eraseOp(op);
return success();
}
@@ -940,7 +939,7 @@ struct TestUndoBlockErase : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
Block *secondBlock = &*std::next(op->getRegion(0).begin());
rewriter.setInsertionPointToStart(secondBlock);
- rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type());
rewriter.eraseBlock(secondBlock);
rewriter.modifyOpInPlace(op, [] {});
return success();
@@ -1008,9 +1007,8 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
// This is a 1:N replacement. Insert a test.cast op. (That's what the
// argument materialization used to do.)
flattened.push_back(
- rewriter
- .create<TestCastOp>(op->getLoc(),
- op->getOperand(it.index()).getType(), range)
+ TestCastOp::create(rewriter, op->getLoc(),
+ op->getOperand(it.index()).getType(), range)
.getResult());
}
rewriter.replaceOpWithNewOp<TestValidOp>(op, TypeRange(), flattened,
@@ -1115,8 +1113,8 @@ struct TestNonRootReplacement : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
auto resultType = *op->result_type_begin();
- auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
- auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
+ auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType);
+ auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType);
rewriter.replaceOp(illegalOp, legalOp);
rewriter.replaceOp(op, illegalOp);
@@ -1182,7 +1180,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
LogicalResult matchAndRewrite(ILLegalOpG op,
PatternRewriter &rewriter) const final {
IntegerAttr attr = rewriter.getI32IntegerAttr(0);
- Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
+ Value val = arith::ConstantOp::create(rewriter, op->getLoc(), attr);
rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
return success();
};
@@ -1355,7 +1353,7 @@ struct TestTypeConverter : public TypeConverter {
/// 1->N type mappings.
static Value materializeCast(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+ return TestCastOp::create(builder, loc, resultType, inputs).getResult();
}
};
@@ -1363,6 +1361,10 @@ struct TestLegalizePatternDriver
: public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
+ TestLegalizePatternDriver() = default;
+ TestLegalizePatternDriver(const TestLegalizePatternDriver &other)
+ : PassWrapper(other) {}
+
StringRef getArgument() const final { return "test-legalize-patterns"; }
StringRef getDescription() const final {
return "Run test dialect legalization patterns";
@@ -1370,8 +1372,6 @@ struct TestLegalizePatternDriver
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };
- TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
-
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, test::TestDialect>();
}
@@ -1500,24 +1500,19 @@ struct TestLegalizePatternDriver
op->emitRemark() << "op '" << op->getName() << "' is legalizable";
}
- /// The mode of conversion to use.
- ConversionMode mode;
+ Option<ConversionMode> mode{
+ *this, "test-legalize-mode",
+ llvm::cl::desc("The legalization mode to use with the test driver"),
+ llvm::cl::init(ConversionMode::Partial),
+ llvm::cl::values(
+ clEnumValN(ConversionMode::Analysis, "analysis",
+ "Perform an analysis conversion"),
+ clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
+ clEnumValN(ConversionMode::Partial, "partial",
+ "Perform a partial conversion"))};
};
} // namespace
-static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
- legalizerConversionMode(
- "test-legalize-mode",
- llvm::cl::desc("The legalization mode to use with the test driver"),
- llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
- llvm::cl::values(
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
- "analysis", "Perform an analysis conversion"),
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
- "Perform a full conversion"),
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
- "partial", "Perform a partial conversion")));
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
@@ -1917,15 +1912,15 @@ struct TestTypeConversionDriver
// Allow casting from F64 back to F32.
if (!resultType.isF16() && inputs.size() == 1 &&
inputs[0].getType().isF64())
- return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+ return TestCastOp::create(builder, loc, resultType, inputs).getResult();
// Allow producing an i32 or i64 from nothing.
if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
inputs.empty())
- return builder.create<TestTypeProducerOp>(loc, resultType);
+ return TestTypeProducerOp::create(builder, loc, resultType);
// Allow producing an i64 from an integer.
if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
isa<IntegerType>(inputs[0].getType()))
- return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+ return TestCastOp::create(builder, loc, resultType, inputs).getResult();
// Otherwise, fail.
return nullptr;
});
@@ -2008,7 +2003,7 @@ struct TestTargetMaterializationWithNoUses
});
converter.addTargetMaterialization(
[](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
- return builder.create<TestCastOp>(loc, type, inputs).getResult();
+ return TestCastOp::create(builder, loc, type, inputs).getResult();
});
ConversionTarget target(getContext());
@@ -2059,7 +2054,7 @@ struct TestUndoBlocksMerge : public ConversionPattern {
Operation *branchOp = firstBlock.getTerminator();
Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
rewriter.setInsertionPointToStart(secondBlock);
- rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type());
auto succOperands = branchOp->getOperands();
SmallVector<Value, 2> replacements(succOperands);
rewriter.eraseOp(branchOp);
@@ -2203,9 +2198,7 @@ void registerPatternsTestPass() {
PassRegistration<TestStrictPatternDriver>();
PassRegistration<TestWalkPatternDriver>();
- PassRegistration<TestLegalizePatternDriver>([] {
- return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
- });
+ PassRegistration<TestLegalizePatternDriver>();
PassRegistration<TestRemappedValue>();
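
The mode flag moves from a file-scope llvm::cl::opt into a pass Option<>, which is also why the explicit copy constructor appears above: pass options register themselves with their owning pass, so copies must run through PassWrapper's copy constructor rather than a memberwise copy. A minimal sketch of the pattern on a hypothetical pass:

    struct MyTestPass : PassWrapper<MyTestPass, OperationPass<>> {
      MyTestPass() = default;
      // Options are not trivially copyable; rebuild them on the copy.
      MyTestPass(const MyTestPass &other) : PassWrapper(other) {}

      Option<bool> verbose{*this, "verbose",
                           llvm::cl::desc("Hypothetical example flag"),
                           llvm::cl::init(false)};
    };
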
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 103817d..7831b27 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -68,8 +68,8 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
if (createSymbol) {
OpBuilder builder(op->getRegion(0));
- builder.create<test::SymbolOp>(
- op->getLoc(),
+ test::SymbolOp::create(
+ builder, op->getLoc(),
StringAttr::get(op->getContext(), "sym_from_attr"),
/*sym_visibility=*/nullptr);
}
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
index bda614a..9550e4c 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
@@ -47,9 +47,9 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> {
op, op->getResultTypes().front());
rewriter.setInsertionPointAfter(bar);
- rewriter.create<test_irdl_to_cpp::HashOp>(
- bar.getLoc(), rewriter.getIntegerType(32), adaptor.getLhs(),
- adaptor.getRhs());
+ test_irdl_to_cpp::HashOp::create(rewriter, bar.getLoc(),
+ rewriter.getIntegerType(32),
+ adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 3389a1c..6457487 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -87,9 +87,9 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op,
return failure();
auto newConstOp =
- rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
- auto newNegateOp = rewriter.create<tosa::NegateOp>(
- op->getLoc(), dstQConstType, newConstOp.getResult());
+ tosa::ConstOp::create(rewriter, op->getLoc(), dstQConstType, inputElems);
+ auto newNegateOp = tosa::NegateOp::create(
+ rewriter, op->getLoc(), dstQConstType, newConstOp.getResult());
rewriter.replaceOp(op, {newNegateOp.getResult()});
return success();
@@ -145,8 +145,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
auto newTosaConv2DOpType =
RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
- auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
- op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
+ auto newTosaConv2DOp = tosa::Conv2DOp::create(
+ rewriter, op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(),
tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(),
tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr());
@@ -178,8 +178,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
newTosaConv2DOp.getResult().getType().isUnsignedInteger();
bool outputUnsigned = outputType.isUnsignedInteger();
- auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
- op->getLoc(), outputType, newTosaConv2DOp.getResult(),
+ auto newTosaRescaleOp = tosa::RescaleOp::create(
+ rewriter, op->getLoc(), outputType, newTosaConv2DOp.getResult(),
getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
getConstTensorInt<int8_t>(rewriter, op->getLoc(),
{static_cast<int8_t>(shift)}),
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index cdf44c2..97fc699 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -796,8 +796,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
transform::TransformState &state) {
// Provide some IR that does not verify.
rewriter.setInsertionPointToStart(&target->getRegion(0).front());
- rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
- ValueRange(), /*failToVerify=*/true);
+ TestDummyPayloadOp::create(rewriter, target->getLoc(), TypeRange(),
+ ValueRange(), /*failToVerify=*/true);
return DiagnosedSilenceableFailure::success();
}
@@ -877,7 +877,8 @@ public:
Location loc) -> Value {
if (inputs.size() != 1)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType,
+ inputs)
.getResult(0);
};
addSourceMaterialization(unrealizedCastConverter);
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a7285ab..f89c944 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -546,8 +546,8 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
auto ip = builder.saveInsertionPoint();
builder.setInsertionPoint(moduleOp);
- auto global = builder.create<memref::GlobalOp>(
- loc,
+ auto global = memref::GlobalOp::create(
+ builder, loc,
/*sym_name=*/symbolName,
/*sym_visibility=*/builder.getStringAttr("private"),
/*type=*/memrefType,
@@ -560,19 +560,18 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
global->moveBefore(&moduleOp.front());
builder.restoreInsertionPoint(ip);
- return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
+ return memref::GetGlobalOp::create(builder, loc, memrefType, symbolName);
}
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled = builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
@@ -647,12 +646,11 @@ struct TestVectorDistribution
"unsupported shuffle type");
Type i32Type = builder.getIntegerType(32);
Value srcIdxI32 =
- builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
- Value warpSzI32 = builder.create<arith::ConstantOp>(
- loc, builder.getIntegerAttr(i32Type, warpSz));
- Value result = builder
- .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
- gpu::ShuffleMode::IDX)
+ arith::IndexCastOp::create(builder, loc, i32Type, srcIdx);
+ Value warpSzI32 = arith::ConstantOp::create(
+ builder, loc, builder.getIntegerAttr(i32Type, warpSz));
+ Value result = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32,
+ warpSzI32, gpu::ShuffleMode::IDX)
.getResult(0);
return result;
};
@@ -680,7 +678,7 @@ struct TestVectorDistribution
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
gpu::WarpExecuteOnLane0Op warpOp) {
- builder.create<gpu::BarrierOp>(loc);
+ gpu::BarrierOp::create(builder, loc);
};
// Test on one pattern in isolation.
if (warpOpToSCF) {
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index f71fcf7..c6245b6 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -20,8 +20,6 @@ using namespace mlir::xegpu;
namespace {
#define DEBUG_TYPE "test-xegpu-unroll"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
struct TestXeGPUUnrollingPatterns
: public PassWrapper<TestXeGPUUnrollingPatterns,
diff --git a/mlir/test/lib/IR/TestPrintInvalid.cpp b/mlir/test/lib/IR/TestPrintInvalid.cpp
index 8697918..25d1b19 100644
--- a/mlir/test/lib/IR/TestPrintInvalid.cpp
+++ b/mlir/test/lib/IR/TestPrintInvalid.cpp
@@ -34,13 +34,14 @@ struct TestPrintInvalidPass
void runOnOperation() override {
Location loc = getOperation().getLoc();
OpBuilder builder(getOperation().getBodyRegion());
- auto funcOp = builder.create<func::FuncOp>(
- loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
+ auto funcOp = func::FuncOp::create(
+ builder, loc, "test",
+ FunctionType::get(getOperation().getContext(), {}, {}));
funcOp.addEntryBlock();
// The created function is invalid because there is no return op.
llvm::outs() << "Invalid operation:\n" << funcOp << "\n";
builder.setInsertionPointToEnd(&funcOp.getBody().front());
- builder.create<func::ReturnOp>(loc);
+ func::ReturnOp::create(builder, loc);
// Now this function is valid.
llvm::outs() << "Valid operation:\n" << funcOp << "\n";
funcOp.erase();
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index 92fd6de..5a5ac45 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -30,8 +30,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
OpBuilder builder(parentFuncOp);
Location loc = op->getLoc();
std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
- func::FuncOp clonedFuncOp = builder.create<func::FuncOp>(
- loc, clonedFuncOpName, parentFuncOp.getFunctionType());
+ func::FuncOp clonedFuncOp = func::FuncOp::create(
+ builder, loc, clonedFuncOpName, parentFuncOp.getFunctionType());
IRMapping mapper;
builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
for (const auto &arg : enumerate(parentFuncOp.getArguments()))
@@ -46,7 +46,7 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
(void)result;
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
- builder.create<func::ReturnOp>(loc);
+ func::ReturnOp::create(builder, loc);
return success();
}
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 7afe210..25c8e53 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -217,8 +217,8 @@ struct TestInvalidParentPass
void runOnOperation() final {
FunctionOpInterface op = getOperation();
OpBuilder b(op.getFunctionBody());
- b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
- ValueRange());
+ test::TestCallOp::create(b, op.getLoc(), TypeRange(), "some_unknown_func",
+ ValueRange());
}
};
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 8278937..dc0538e 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -45,7 +45,7 @@ struct PDLLTypeConverter : public TypeConverter {
/// Hook for materializing a conversion.
static Value materializeCast(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
}
};
diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp
index c518f3f..2888c3c 100644
--- a/mlir/test/lib/Transforms/TestInliningCallback.cpp
+++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp
@@ -53,8 +53,8 @@ struct InlinerCallback
mlir::Operation &call = inlineBlock->back();
builder.setInsertionPointAfter(&call);
- auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>(
- call.getLoc(), call.getResultTypes());
+ auto executeRegionOp = mlir::scf::ExecuteRegionOp::create(
+ builder, call.getLoc(), call.getResultTypes());
mlir::Region &region = executeRegionOp.getRegion();
// Move the inlined blocks into the region
@@ -70,8 +70,8 @@ struct InlinerCallback
if (test::TestReturnOp returnOp =
llvm::dyn_cast<test::TestReturnOp>(&op)) {
mlir::OpBuilder returnBuilder(returnOp);
- returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(),
- returnOp.getOperands());
+ mlir::scf::YieldOp::create(returnBuilder, returnOp.getLoc(),
+ returnOp.getOperands());
returnOp.erase();
}
}
@@ -79,8 +79,8 @@ struct InlinerCallback
// Add test.return after scf.execute_region
builder.setInsertionPointAfter(executeRegionOp);
- builder.create<test::TestReturnOp>(executeRegionOp.getLoc(),
- executeRegionOp.getResults());
+ test::TestReturnOp::create(builder, executeRegionOp.getLoc(),
+ executeRegionOp.getResults());
}
void runOnOperation() override {
diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
index 4e0213c..c1fb706 100644
--- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -28,7 +28,7 @@ makeIsolatedFromAboveImpl(RewriterBase &rewriter,
SmallVector<Value> operands = regionOp.getOperands();
operands.append(capturedValues);
auto isolatedRegionOp =
- rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
+ test::IsolatedOneRegionOp::create(rewriter, regionOp.getLoc(), operands);
rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
isolatedRegionOp.getRegion().begin());
rewriter.eraseOp(regionOp);
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index 9a5632b..ff5838d 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -74,8 +74,8 @@ transform::TestMakeComposedFoldedAffineApply::applyToOne(
if (auto v = dyn_cast<Value>(ofr)) {
result = v;
} else {
- result = rewriter.create<arith::ConstantIndexOp>(
- loc, getConstantIntValue(ofr).value());
+ result = arith::ConstantIndexOp::create(rewriter, loc,
+ getConstantIntValue(ofr).value());
}
results.push_back(result.getDefiningOp());
rewriter.replaceOp(affineApplyOp, result);
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 233fef8..feaf5fb 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -343,7 +343,6 @@ if config.enable_assertions:
else:
config.available_features.add("noasserts")
-
def have_host_jit_feature_support(feature_name):
mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir)
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 132aabe..b1185e1 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -5,6 +5,7 @@ import sys
config.target_triple = "@LLVM_TARGET_TRIPLE@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@")
+config.spirv_tools_tests = @LLVM_INCLUDE_SPIRV_TOOLS_TESTS@
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@"))
config.python_executable = "@Python3_EXECUTABLE@"
@@ -41,7 +42,7 @@ config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@
config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
# This is a workaround for the fact that LIT's:
# %if <cond>
-# requires <cond> to be in the set of available features.
+# requires <cond> to be in the set of available features.
# TODO: Update LIT's TestRunner so that this is not required.
if config.mlir_run_arm_sve_tests:
config.available_features.add("mlir_arm_sve_tests")
diff --git a/mlir/test/mlir-runner/simple.mlir b/mlir/test/mlir-runner/simple.mlir
index 1a03b99..21dabdd 100644
--- a/mlir/test/mlir-runner/simple.mlir
+++ b/mlir/test/mlir-runner/simple.mlir
@@ -15,10 +15,10 @@
// RUN: ls %t.o
// RUN: rm %t.o
-// RUN: mlir-runner %s -dump-object-file -object-filename=%T/test.o \
+// RUN: mlir-runner %s -dump-object-file -object-filename=%t.o \
// RUN: %if target={{s390x-.*}} %{ -argext-abi-check=false %} | FileCheck %s
-// RUN: ls %T/test.o
-// RUN: rm %T/test.o
+// RUN: ls %t.o
+// RUN: rm %t.o
// Declarations of C library functions.
llvm.func @logbf(f32) -> f32
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index d47411d..a809611 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
+// DEF: CompoundAAttr CompoundAAttr::getChecked(
+// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner
+// DEF-SAME: )
+// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
+
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
// DEF-NEXT: return getImpl()->inner;
}
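
The new DEF lines pin down the generated getChecked, which forwards to Base::getChecked so a failed attribute verifier reports through the supplied callback instead of asserting. A hedged usage sketch (argument names here are illustrative, mirroring the CHECK lines above):

    auto attr = CompoundAAttr::getChecked(
        [&] { return emitError(loc); }, context,
        widthOfSomething, exampleTdType, apFloat, dims, inner);
    if (!attr) // verification failed; a diagnostic was already emitted
      return failure();
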
diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td
index 7cd24aa..af09ee7 100644
--- a/mlir/test/mlir-tblgen/op-properties-predicates.td
+++ b/mlir/test/mlir-tblgen/op-properties-predicates.td
@@ -70,6 +70,12 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> {
// CHECK-NEXT: if (!(((!prop.has_value())) || ((::llvm::all_of((*(prop)), [](const int64_t& baseStore) -> bool { return [](int64_t baseIface) -> bool { return ((baseIface >= 0)); }(baseStore); })) && (!(((*(prop)).empty()))))))
// CHECK: failed to satisfy constraint: optional non-empty array of non-negative int64_
+// CHECK-LABEL: ::llvm::LogicalResult OpWithPredicatesAdaptor::verify
// Note: comprehensive emission of verifiers is tested in verifyInvariantsImpl() below
+// CHECK: int64_t tblgen_scalar = this->getScalar();
+// CHECK: if (!((tblgen_scalar >= 0)))
+// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t");
+
// CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl()
// Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused
// CHECK: [[maybe_unused:\[\[maybe_unused\]\]]] int64_t tblgen_scalar = this->getScalar();
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
index 40af548..23ab24e 100644
--- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -44,7 +44,7 @@ def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
// CHECK: test::AOp::Properties tblgen_props;
// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
-// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
+// CHECK: tblgen_AOp_0 = test::AOp::create(rewriter, odsLoc, tblgen_types, tblgen_values, tblgen_props);
// Note: These use strings to pick up a non-trivial storage/interface type
// difference.
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index 0a94746..9bb6103 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -55,7 +55,7 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
// We expect ODSOperand 0 here, the attribute before the operand in BOp
// definition shouldn't shift the counter.
// CHECK: op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp();
-// CHECK: rewriter.create<test::BOp>((*a.getODSResults(0).begin()).getLoc()
+// CHECK: test::BOp::create(rewriter, (*a.getODSResults(0).begin()).getLoc()
def test3 : Pat<(BOp $attr, (AOp:$a $input)),
(BOp $attr, (AOp $input), (location $a))>;
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index ef1d835..66f7ec8 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -31,6 +31,7 @@ def testGetDenseElementsUnsupported():
# CHECK: unimplemented array format conversion from format:
print(e)
+
# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
@run
def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
@@ -41,8 +42,9 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
# realistic example would be a NumPy extension type like the bfloat16
# type from the ml_dtypes package, which isn't a dependency of this
# test.
- attr = DenseElementsAttr.get(array.view(np.datetime64),
- type=IntegerType.get_signless(64))
+ attr = DenseElementsAttr.get(
+ array.view(np.datetime64), type=IntegerType.get_signless(64)
+ )
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
print(attr)
# CHECK: {{\[}}[1 2 3]
@@ -135,6 +137,7 @@ def testGetDenseElementsFromListMixedTypes():
# Splats.
################################################################################
+
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
@run
def testGetDenseElementsSplatInt():
@@ -617,3 +620,18 @@ def testGetDenseResourceElementsAttr():
# CHECK: BACKING MEMORY DELETED
# CHECK: EXIT FUNCTION
print("EXIT FUNCTION")
+
+
+# CHECK-LABEL: TEST: testDanglingResource
+print("TEST: testDanglingResource")
+# see https://github.com/llvm/llvm-project/pull/149414, https://github.com/llvm/llvm-project/pull/150137, https://github.com/llvm/llvm-project/pull/150561
+# This error occurs only when a context holding a DenseResourceElementsAttr is
+# still alive at the end of the program, so we put it here without an
+# encapsulating function.
+ctx = Context()
+
+with ctx, Location.unknown():
+ DenseResourceElementsAttr.get_from_buffer(
+ memoryview(np.array([1, 2, 3])),
+ "some_resource",
+ RankedTensorType.get((3,), IntegerType.get_signed(32)),
+ )