diff options
author | Boian Petkantchin <boian.petkantchin@amd.com> | 2023-12-08 11:34:44 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-08 11:34:44 -0800 |
commit | 944e031e36d9515b68b320f611edfc97d5460259 (patch) | |
tree | 6325b8e615744958739d7c2f3cf8da34940dc1a2 | |
parent | 02379d19147afda413a2bc757e8d2f5249d772d1 (diff) | |
download | llvm-944e031e36d9515b68b320f611edfc97d5460259.zip llvm-944e031e36d9515b68b320f611edfc97d5460259.tar.gz llvm-944e031e36d9515b68b320f611edfc97d5460259.tar.bz2 |
[mlir][mesh] Use tensor shape notation for the shape of a cluster (#73826)
Example:
substitute
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])
with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)
This is the same notation used for tensor/memref shapes. The only difference is for 0-rank shapes:
with tensors you would have something like `tensor<f32>`, but here, to avoid
matching an empty string, a 0-rank shape is denoted by `[]`.
-rw-r--r-- | mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 29 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OpImplementation.h | 13 | ||||
-rw-r--r-- | mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 13 | ||||
-rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 88 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/canonicalization.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/invalid.mlir | 56 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/ops.mlir | 15 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/sharding-propagation.mlir | 2 | ||||
-rw-r--r-- | mlir/test/IR/custom-print-parse.mlir | 17 | ||||
-rw-r--r-- | mlir/test/IR/invalid-custom-print-parse.mlir | 16 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 14 |
11 files changed, 194 insertions, 71 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 361e67f..e6cdba9 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -40,26 +40,27 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { determine the layout and the addressing space of the computation distributed across the mesh. - 3. `dim_sizes`: This attribute represents the device assignment along the - axes of the cluster. Each integer in the array corresponds to the number of - devices along a specific axis. If an integer value is 0, it implies that the - number of devices along that axis is unknown. This flexibility allows for - dynamic device assignment or configurations where the exact number of - devices might not be determined during compile time. + 3. `dim_sizes`: This attribute represents the shape of the device cluster. + It uses the same notation as a tensor shape. Also allowing for dynamic + dimensions. + This flexibility allows for dynamic device assignment or configurations + where the exact number of devices might not be determined during compile + time. + For example `2x?x4`. 
Example: ``` // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12 // The dimension sizes are 4, 8, 12 - mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12]) + mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12) // A device mesh cluster with 2 axes, the total device number is unknown // The first dimension size is 4 and the second is unknown - mesh.cluster @mesh1(rank = 2, dim_sizes = [4]) + mesh.cluster @mesh1(rank = 2, dim_sizes = 4) // A device mesh cluster with 2 axes, the total device number is unknown // The first dimension size is unknown and the second is 4 - mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4]) + mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4) // A device mesh cluster with 2 axes, the number of devices along both axes // is unknown @@ -76,7 +77,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes ); let assemblyFormat = [{ - $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` + $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)` attr-dict }]; let extraClassDeclaration = [{ @@ -88,7 +89,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { template <typename OutIt> void canonicalDimSizes(OutIt outIt) { std::copy(getDimSizes().begin(), getDimSizes().end(), outIt); - std::fill_n(outIt, getRank() - getDimSizes().size(), 0); + std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic); } }]; let hasVerifier = 1; @@ -210,7 +211,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [ Example: ```mlir - mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) + mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2) ... 
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8> @@ -295,7 +296,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ Example: ``` - mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) + mesh.cluster @mesh0(rank = 1, dim_sizes = 3) ... %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0] split_axis = 0 concat_axis = 0 @@ -527,7 +528,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", across the device group. Example: ``` - mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2]) + mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2) ... %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] reduction = <max> scatter_axis = 0 diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f1fabf9..86ed14e 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -226,6 +226,8 @@ public: printArrowTypeList(results); } + void printDimensionList(ArrayRef<int64_t> shape); + /// Class used to automatically end a cyclic region on destruction. class CyclicPrintReset { public: @@ -1762,6 +1764,17 @@ public: const SetVector<AsmDialectResourceHandle> &referencedResources, AsmResourceBuilder &builder) const {} }; + +//===--------------------------------------------------------------------===// +// Custom printers and parsers. +//===--------------------------------------------------------------------===// + +// Handles custom<DimensionList>(...) in TableGen. 
+void printDimensionList(OpAsmPrinter &printer, Operation *op, + ArrayRef<int64_t> dimensions); +ParseResult parseDimensionList(OpAsmParser &parser, + DenseI64ArrayAttr &dimensions); + } // namespace mlir //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 3b89860..1ba95f2 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -58,11 +58,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) { return vec; } -template <typename DimSize> -static bool isMeshDimensionDynamic(DimSize size) { - return size <= DimSize(0); -} - using MeshAxis = int16_t; namespace { @@ -161,9 +156,9 @@ LogicalResult ClusterOp::verify() { "rank of dim_sizes is not expected to be larger than rank of cluster"); for (int64_t dimSize : dimSizes) { - if (dimSize < 0) - return emitOpError( - "dimension size of a mesh cluster is expected to be non-negative"); + if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) + return emitOpError("dimension size of a mesh cluster is expected to be " + "non-negative or dynamic"); } return success(); @@ -316,7 +311,7 @@ static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes, int64_t res = 1; for (MeshAxis axis : meshAxes) { - if (isMeshDimensionDynamic(meshShape[axis])) { + if (ShapedType::isDynamic(meshShape[axis])) { return ShapedType::kDynamic; } assert(size_t(axis) < meshShape.size()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4b76dcf..1f7cbf3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -16,7 +16,9 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include 
"mlir/IR/DialectImplementation.h" @@ -27,6 +29,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Verifier.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" @@ -44,6 +47,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include <type_traits> #include <optional> #include <tuple> @@ -425,6 +429,8 @@ public: void popCyclicPrinting(); + void printDimensionList(ArrayRef<int64_t> shape); + protected: void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {}, @@ -1860,6 +1866,20 @@ private: // Allow direct access to the impl fields. friend AsmState; }; + +template <typename Range> +void printDimensionList(raw_ostream &stream, Range &&shape) { + llvm::interleave( + shape, stream, + [&stream](const auto &dimSize) { + if (ShapedType::isDynamic(dimSize)) + stream << "?"; + else + stream << dimSize; + }, + "x"); +} + } // namespace detail } // namespace mlir @@ -2576,13 +2596,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { }) .Case<RankedTensorType>([&](RankedTensorType tensorTy) { os << "tensor<"; - for (int64_t dim : tensorTy.getShape()) { - if (ShapedType::isDynamic(dim)) - os << '?'; - else - os << dim; + printDimensionList(tensorTy.getShape()); + if (!tensorTy.getShape().empty()) os << 'x'; - } printType(tensorTy.getElementType()); // Only print the encoding attribute value if set. 
if (tensorTy.getEncoding()) { @@ -2598,13 +2614,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { }) .Case<MemRefType>([&](MemRefType memrefTy) { os << "memref<"; - for (int64_t dim : memrefTy.getShape()) { - if (ShapedType::isDynamic(dim)) - os << '?'; - else - os << dim; + printDimensionList(memrefTy.getShape()); + if (!memrefTy.getShape().empty()) os << 'x'; - } printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { @@ -2735,6 +2747,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } +void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { + detail::printDimensionList(os, shape); +} + //===--------------------------------------------------------------------===// // AsmPrinter //===--------------------------------------------------------------------===// @@ -2800,6 +2816,10 @@ void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { impl->printResourceHandle(resource); } +void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { + detail::printDimensionList(getStream(), shape); +} + LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { return impl->pushCyclicPrinting(opaquePointer); } @@ -3911,3 +3931,47 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) { OperationPrinter printer(os, state.getImpl()); printer.printBlockName(this); } + +//===--------------------------------------------------------------------===// +// Custom printers +//===--------------------------------------------------------------------===// +namespace mlir { + +void printDimensionList(OpAsmPrinter &printer, Operation *op, + ArrayRef<int64_t> dimensions) { + if (dimensions.empty()) + printer << "["; + printer.printDimensionList(dimensions); + if (dimensions.empty()) + printer << "]"; +} + +ParseResult 
parseDimensionList(OpAsmParser &parser, + DenseI64ArrayAttr &dimensions) { + // Empty list case denoted by "[]". + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseRSquare())) { + return parser.emitError(parser.getCurrentLocation()) + << "Failed parsing dimension list."; + } + dimensions = + DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>()); + return success(); + } + + // Non-empty list case. + SmallVector<int64_t> shapeArr; + if (failed(parser.parseDimensionList(shapeArr, true, false))) { + return parser.emitError(parser.getCurrentLocation()) + << "Failed parsing dimension list."; + } + if (shapeArr.empty()) { + return parser.emitError(parser.getCurrentLocation()) + << "Failed parsing dimension list. Did you mean an empty list? It " + "must be denoted by \"[]\"."; + } + dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); + return success(); +} + +} // namespace mlir diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir index 5802d19..baee9fa 100644 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ b/mlir/test/Dialect/Mesh/canonicalization.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt --canonicalize %s | FileCheck %s -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) // CHECK-LABEL: func @all_reduce_empty_mesh_axes func.func @all_reduce_empty_mesh_axes( diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir index 2999668..a26e395 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Mesh/invalid.mlir @@ -6,16 +6,16 @@ mesh.cluster @mesh0(rank = 0) // ----- // expected-error@+1 {{rank of dim_sizes is not expected to be larger than rank of cluster}} -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4) // ----- -// expected-error@+1 {{dimension size of a mesh cluster is expected to be non-negative}} -mesh.cluster 
@mesh0(rank = 2, dim_sizes = [-1]) +// expected-error@+1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} +mesh.cluster @mesh0(rank = 2, dim_sizes = -1) // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @mesh_axis_duplicated_different_subarray( // expected-error@+1 {{mesh axis duplicated}} @@ -26,7 +26,7 @@ func.func @mesh_axis_duplicated_different_subarray( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @mesh_axis_duplicated_same_subarray( // expected-error@+1 {{mesh axis duplicated}} @@ -37,7 +37,7 @@ func.func @mesh_axis_duplicated_same_subarray( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @mesh_axis_duplicated_bewteen_split_and_partial( // expected-error@+1 {{mesh axis duplicated}} @@ -48,7 +48,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @mesh_axis_negtive_in_split_part( // expected-error@+1 {{mesh axis is expected to be non-negative}} @@ -59,7 +59,7 @@ func.func @mesh_axis_negtive_in_split_part( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @mesh_axis_negtive_in_partial( // expected-error@+1 {{mesh axis is expected to be non-negative}} @@ -80,7 +80,7 @@ func.func @all_reduce_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @all_reduce_invalid_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { @@ -92,7 +92,7 @@ func.func @all_reduce_invalid_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func 
@all_reduce_duplicate_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { @@ -104,7 +104,7 @@ func.func @all_reduce_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @all_reduce_invalid_tensor_dimension_size( %arg0 : tensor<4xf32>) -> tensor<5xf64> { @@ -125,7 +125,7 @@ func.func @all_gather_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @all_gather_invalid_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { @@ -137,7 +137,7 @@ func.func @all_gather_invalid_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @all_reduce_duplicate_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { @@ -149,7 +149,7 @@ func.func @all_reduce_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [1]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 1) func.func @all_gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -161,7 +161,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2) func.func @all_gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -173,7 +173,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [1]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 1) func.func @all_gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { @@ -185,7 +185,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [1]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 1) func.func @all_gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -197,7 +197,7 @@ func.func 
@all_gather_invalid_gather_axis( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [1]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 1) func.func @all_gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -220,7 +220,7 @@ func.func @all_to_all_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [1]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 1) func.func @all_to_all_duplicate_mesh_axis( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { @@ -233,7 +233,7 @@ func.func @all_to_all_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1]) +mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1) func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { @@ -246,7 +246,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> { @@ -259,7 +259,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna // ----- -mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1]) +mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> { @@ -272,7 +272,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 3) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> { @@ -285,7 +285,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 
1, dim_sizes = 3) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> { @@ -298,7 +298,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 3) func.func @reduce_scatter_duplicate_mesh_axis( %arg0 : tensor<?xf32>) -> tensor<?xf64> { @@ -310,7 +310,7 @@ func.func @reduce_scatter_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 3) func.func @reduce_scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf64> { @@ -322,7 +322,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 3) func.func @reduce_scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf64> { @@ -334,7 +334,7 @@ func.func @reduce_scatter_invalid_static_dimension_size( // ----- -mesh.cluster @mesh0(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh0(rank = 1, dim_sizes = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf64> { diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 5b264bc..78ce276 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -1,18 +1,21 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s // CHECK: mesh.cluster @mesh0 -mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4]) +mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4) -// CHECK: mesh.cluster @mesh1 -mesh.cluster @mesh1(rank = 2, dim_sizes = [4]) +// CHECK: mesh.cluster @mesh1(rank = 2, dim_sizes = 4) +mesh.cluster @mesh1(rank = 2, dim_sizes = 4) -// CHECK: mesh.cluster @mesh2 -mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4]) +// CHECK: mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4) +mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4) // 
CHECK: mesh.cluster @mesh3 mesh.cluster @mesh3(rank = 2) -mesh.cluster @mesh4(rank = 1, dim_sizes = [3]) +mesh.cluster @mesh4(rank = 1, dim_sizes = 3) + +// CHECK: mesh.cluster @mesh5(rank = 1) +mesh.cluster @mesh5(rank = 1, dim_sizes = []) // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated func.func @mesh_shard_encoding_fully_replicated( diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir index bda407b..30bbd5c 100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -sharding-propagation %s | FileCheck %s mesh.cluster @mesh_1d(rank = 1) -mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4]) +mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4) mesh.cluster @mesh_3d(rank = 3) // CHECK-LABEL: func.func @element_wise_empty_sharding_info diff --git a/mlir/test/IR/custom-print-parse.mlir b/mlir/test/IR/custom-print-parse.mlir new file mode 100644 index 0000000..b157fd1 --- /dev/null +++ b/mlir/test/IR/custom-print-parse.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: module @dimension_list +module @dimension_list { + // CHECK: test.custom_dimension_list_attr dimension_list = [] + test.custom_dimension_list_attr dimension_list = [] + // CHECK: test.custom_dimension_list_attr dimension_list = 3 + test.custom_dimension_list_attr dimension_list = 3 + // CHECK: test.custom_dimension_list_attr dimension_list = 0 + test.custom_dimension_list_attr dimension_list = 0 + // CHECK: test.custom_dimension_list_attr dimension_list = 1x2 + test.custom_dimension_list_attr dimension_list = 1x2 + // CHECK: test.custom_dimension_list_attr dimension_list = ? + test.custom_dimension_list_attr dimension_list = ? + // CHECK: test.custom_dimension_list_attr dimension_list = ?x? + test.custom_dimension_list_attr dimension_list = ?x? 
+} diff --git a/mlir/test/IR/invalid-custom-print-parse.mlir b/mlir/test/IR/invalid-custom-print-parse.mlir new file mode 100644 index 0000000..456b16c --- /dev/null +++ b/mlir/test/IR/invalid-custom-print-parse.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+2 {{invalid dimension}} +// expected-error@+1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}} +test.custom_dimension_list_attr dimension_list = 1x-1 + +// ----- + +// expected-error@+1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} +test.custom_dimension_list_attr dimension_list = -1 + +// ----- + +// expected-error@+2 {{expected ']'}} +// expected-error@+1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}} +test.custom_dimension_list_attr dimension_list = [2x3] diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 1add9bd..96f66c2 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2021,6 +2021,20 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> { } //===----------------------------------------------------------------------===// +// Custom printer/parser + +def CustomDimensionListAttrOp : TEST_Op<"custom_dimension_list_attr"> { + let description = [{ + Test printing/parsing of dimension list attribute. + }]; + let arguments = (ins DenseI64ArrayAttr:$dimension_list); + let assemblyFormat = [{ + `dimension_list` `=` custom<DimensionList>($dimension_list) + attr-dict + }]; +} + +//===----------------------------------------------------------------------===// // Test OpAsmInterface. def AsmInterfaceOp : TEST_Op<"asm_interface_op"> { |