diff options
author | Boian Petkantchin <boian.petkantchin@amd.com> | 2024-01-26 07:03:29 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-26 07:03:29 -0800 |
commit | 9a8437f50470e2658ca0b26bbc9f3da654c20dba (patch) | |
tree | 5f94883cc2c4ad96e00325351f7852fb244f25a2 | |
parent | d9245e8b471c6b3f61e3810faa9788b4994e295a (diff) | |
download | llvm-9a8437f50470e2658ca0b26bbc9f3da654c20dba.zip llvm-9a8437f50470e2658ca0b26bbc9f3da654c20dba.tar.gz llvm-9a8437f50470e2658ca0b26bbc9f3da654c20dba.tar.bz2 |
[mlir][mesh] Rename cluster to mesh (#79484)
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.
The name `mesh` is more specific to what it really represents. It is a
mesh of devices.
The name `cluster` implies a broader possibility of device
configurations. When just the word `mesh` is used the meaning can often
be inferred from the context whether it refers to the mesh dialect or a
device mesh. The full name can be used when needed.
18 files changed, 238 insertions, 243 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td index 07f9544..e835361 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td @@ -79,7 +79,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { let mnemonic = "shard"; let parameters = (ins - AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster, + AttrParameter<"::mlir::FlatSymbolRefAttr", + "The mesh on which tensors are sharded.">:$mesh, ArrayRefParameter<"MeshAxesAttr">:$split_axes, OptionalArrayRefParameter<"MeshAxis">:$partial_axes, OptionalParameter<"::mlir::mesh::Partial">:$partial_type @@ -91,9 +92,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { The MeshSharding attribute could be used in the encoding of a `RankedTensorType` or the mesh.shard op. it contains three sub-attributes: - 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh - cluster where the distributed tensor is placed. The symbol must resolve to a - `mesh.cluster` operation. + 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device + mesh where the distributed tensor is placed. The symbol must resolve to a + `mesh.mesh` operation. 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's maximum size is the `rank` of the related tensor. For the i-th sub-array, if @@ -117,7 +118,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { Example: ``` - mesh.cluster @mesh0(shape = 2x2x4) + mesh.mesh @mesh0(shape = 2x2x4) // The tensor is fully replicated on @mesh0. // Currently, there must be at least one sub-array present in axes, even @@ -140,12 +141,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { ``` }]; let assemblyFormat = [{ - `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[` + `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[` $partial_axes^ `]`)? 
`>` }]; let builders = [ - AttrBuilder<(ins "FlatSymbolRefAttr":$cluster, + AttrBuilder<(ins "FlatSymbolRefAttr":$mesh, "ArrayRef<SmallVector<MeshAxis>>":$split_axes, "ArrayRef<MeshAxis>": $partial_axes, "mesh::Partial": $partial_type), [{ @@ -153,12 +154,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { split_axes, [&](ArrayRef<MeshAxis> array) { return MeshAxesAttr::get($_ctxt, array); }); - return $_get($_ctxt, cluster, splitAxesAttr, partial_axes, + return $_get($_ctxt, mesh, splitAxesAttr, partial_axes, partial_type); }]>, - AttrBuilder<(ins "FlatSymbolRefAttr":$cluster, + AttrBuilder<(ins "FlatSymbolRefAttr":$mesh, "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{ - return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum); + return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum); }]> ]; diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 78ff8bd..7b30102 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -26,17 +26,17 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> : Op<Mesh_Dialect, mnemonic, traits> { } -def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { - let summary = "representing a mesh cluster"; +def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> { + let summary = "Description of a device/process mesh."; let description = [{ - The mesh.cluster operation is a symbol operation that identifies a specific - mesh cluster. The operation has three attributes: + The mesh.mesh operation is a symbol operation that identifies a specific + mesh. The operation has three attributes: - 1. `sym_name`: This attribute uniquely identifies the name of the mesh - cluster. This name serves as a symbolic reference to the cluster throughout + 1. `sym_name`: This attribute uniquely identifies the name of the mesh. 
+ This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging. - 2. `shape`: This attribute represents the shape of the device cluster. + 2. `shape`: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape. Also allowing for dynamic dimensions. This flexibility allows for dynamic device assignment or configurations @@ -46,21 +46,21 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { Example: ``` - // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12 + // A device mesh with 3 axes, the total device number is 4 * 8 * 12 // The dimension sizes are 4, 8, 12 - mesh.cluster @mesh0(shape = 4x8x12) + mesh.mesh @mesh0(shape = 4x8x12) - // A device mesh cluster with 2 axes, the total device number is unknown + // A device mesh with 2 axes, the total device number is unknown // The first dimension size is 4 and the second is unknown - mesh.cluster @mesh1(shape = 4x?) + mesh.mesh @mesh1(shape = 4x?) - // A device mesh cluster with 2 axes, the total device number is unknown + // A device mesh with 2 axes, the total device number is unknown // The first dimension size is unknown and the second is 4 - mesh.cluster @mesh2(shape = ?x4) + mesh.mesh @mesh2(shape = ?x4) - // A device mesh cluster with 2 axes, the number of devices along both axes + // A device mesh with 2 axes, the number of devices along both axes // is unknown - mesh.cluster @mesh3(shape = ?x?) + mesh.mesh @mesh3(shape = ?x?) 
// Used in the mesh sharding attribute to extend the standard tensor to // distributed @@ -81,9 +81,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { let hasVerifier = 1; } -def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [ +def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [ Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> { - let summary = "Get the shape of the cluster."; + let summary = "Get the shape of the mesh."; let arguments = (ins FlatSymbolRefAttr:$mesh, DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes @@ -99,13 +99,13 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [ }]; let builders = [ - OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>, + OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>, OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)> ]; } def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> { - let summary = "Annotate on how a tensor is sharded across a mesh cluster."; + let summary = "Annotate on how a tensor is sharded across a mesh."; let description = [{ The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has one @@ -115,7 +115,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> { annotated for sharding. 2. `shard`: This attribute is type of `MeshSharding`, which is the core data - structure to represent distributed tensor in mesh cluster. + structure to represent distribution of a tensor on a mesh. 3. 
`annotate_for_users`: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as @@ -217,7 +217,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [ attr-dict `:` type($result) }]; let builders = [ - OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>, + OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>, OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)> ]; } @@ -239,7 +239,7 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [ let results = (outs Index:$result); let assemblyFormat = "`on` $mesh attr-dict `:` type($result)"; let builders = [ - OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)> + OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)> ]; } @@ -268,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [ Example: ```mlir - mesh.cluster @mesh0(shape = 2x2) + mesh.mesh @mesh0(shape = 2x2) ... %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8> @@ -353,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ Example: ``` - mesh.cluster @mesh0(shape = 3) + mesh.mesh @mesh0(shape = 3) ... %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0] split_axis = 0 concat_axis = 0 @@ -410,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ Example: ``` - mesh.cluster @mesh0(shape = 2x2) + mesh.mesh @mesh0(shape = 2x2) %1 = mesh.broadcast %0 on @mesh0 mesh_axes = [0] @@ -466,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ Example: ```mlir - mesh.cluster @mesh0(shape = 2x2) + mesh.mesh @mesh0(shape = 2x2) ... %1 = mesh.gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [1] @@ -589,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", across the device group. Example: ``` - mesh.cluster @mesh0(shape = 2x2) + mesh.mesh @mesh0(shape = 2x2) ... 
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] reduction = <max> scatter_axis = 0 @@ -652,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ Example: ``` - mesh.cluster @mesh0(shape = 2x2) + mesh.mesh @mesh0(shape = 2x2) %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0] scatter_axis = 0 root = [1] @@ -748,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ Example: ``` - mesh.cluster @mesh0(shape = 2x4) + mesh.mesh @mesh0(shape = 2x4) %1 = mesh.shift on @mesh0 mesh_axes = [1] shift_axis = 1 offset = 2 rotate : tensor<2xi8> -> tensor<2xi8> diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h index a32274d..3bef7e6 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h +++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h @@ -25,14 +25,14 @@ struct ShardingOption { // An array of int array. The sub-array at the i-th position signifies the // mesh axes the i-th loop will be sharded on. ShardingArray shardingArray = {}; - FlatSymbolRefAttr cluster = nullptr; + FlatSymbolRefAttr mesh = nullptr; // `empty` being true indicates that no sharding information can be inferred // at present. Note that it is different from the case where an operation is // not sharded. 
bool empty = false; ShardingOption() = default; - ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster) - : shardingArray(std::move(shardingArray)), cluster(cluster) {} + ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) + : shardingArray(std::move(shardingArray)), mesh(mesh) {} }; // This method retrieves the 'MeshShardingAttr' attribute from a given operation diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h index f71bb9b..7cb992aa 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h @@ -17,14 +17,14 @@ namespace mlir { namespace mesh { // Return the sharded shape `shape` acording ot sharding `sharding`. -ShapedType shardShapedType(ShapedType shape, ClusterOp mesh, +ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding); // Insert resharding spmdization of the value `sourceShardValue` // from sharding `source` to sharding `target`. // `sourceShardValue` is the already sharded value according to `source`. 
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh, - ShardOp source, ShardOp target, +TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, + ShardOp target, TypedValue<ShapedType> sourceShardValue); void reshardingRegisterDependentDialects(DialectRegistry &registry); diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index f6b6b7c..994a017 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -114,10 +114,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, // Mesh utilities //===----------------------------------------------------------------------===// -static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, - SymbolTableCollection &symbolTable) { - mesh::ClusterOp mesh = - symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol); +static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, + SymbolTableCollection &symbolTable) { + mesh::MeshOp mesh = + symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol); if (!mesh) { return op->emitError() << "Undefined required mesh symbol \"" << meshSymbol.getValue() << "\"."; @@ -144,7 +144,7 @@ bool isUnique(It begin, It end) { } static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, - ClusterOp mesh) { + MeshOp mesh) { SmallVector<MeshAxis> sorted = llvm::to_vector(axes); llvm::sort(sorted); if (!isUnique(sorted.begin(), sorted.end())) { @@ -192,22 +192,22 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) { } //===----------------------------------------------------------------------===// -// mesh.cluster op +// mesh.mesh op //===----------------------------------------------------------------------===// -LogicalResult ClusterOp::verify() { +LogicalResult MeshOp::verify() { int64_t rank = getRank(); if (rank <= 0) - return emitOpError("rank of cluster is expected to be 
a positive integer"); + return emitOpError("rank of mesh is expected to be a positive integer"); if (getShape().size() > size_t(rank)) return emitOpError( - "rank of shape is not expected to be larger than rank of cluster"); + "rank of shape is not expected to be larger than rank of mesh"); for (int64_t dimSize : getShape()) { if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) - return emitOpError("dimension size of a mesh cluster is expected to be " + return emitOpError("dimension size of a mesh is expected to be " "non-negative or dynamic"); } @@ -215,11 +215,11 @@ LogicalResult ClusterOp::verify() { } //===----------------------------------------------------------------------===// -// mesh.cluster_shape op +// mesh.mesh_shape op //===----------------------------------------------------------------------===// LogicalResult -ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { +MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable); if (failed(mesh)) { return failure(); @@ -238,16 +238,16 @@ ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - ClusterOp mesh) { +void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + MeshOp mesh) { build(odsBuilder, odsState, SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>())); } -void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef mesh, ArrayRef<MeshAxis> axes) { +void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef mesh, ArrayRef<MeshAxis> axes) { build(odsBuilder, odsState, SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, MeshAxesAttr::get(odsBuilder.getContext(), axes)); @@ -261,7 +261,7 @@ LogicalResult 
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError, FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes, ArrayRef<MeshAxis> partialAxes, Partial) { - // TODO: At present cluster symbol ref is not verified. This is due to the + // TODO: At present mesh symbol ref is not verified. This is due to the // difficulty in fetching the corresponding symbol op based on an attribute. llvm::SmallSet<MeshAxis, 4> visitedAxes; @@ -292,8 +292,7 @@ bool MeshShardingAttr::operator==(Attribute rhs) const { } bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const { - if (getCluster() != rhs.getCluster() || - getPartialAxes() != rhs.getPartialAxes()) { + if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) { return false; } @@ -342,7 +341,7 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, - ClusterOp mesh) { + MeshOp mesh) { build(odsBuilder, odsState, SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), mesh.getSymName(), ArrayRef<MeshAxis>()); @@ -369,7 +368,7 @@ ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } void ProcessLinearIndexOp::build(OpBuilder &odsBuilder, - OperationState &odsState, ClusterOp mesh) { + OperationState &odsState, MeshOp mesh) { build(odsBuilder, odsState, mesh.getSymName()); } @@ -427,7 +426,7 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, } template <typename Op> -static FailureOr<ClusterOp> +static FailureOr<MeshOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable); if (failed(mesh)) { diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index dca7e86..5dc91ff 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ 
b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -215,11 +215,10 @@ namespace { // Update the given `shardingOption` according to `meshAxes` and `loopIdx` static LogicalResult fillShardingOption(Operation *op, ShardingOption &shardingOption, - FlatSymbolRefAttr cluster, + FlatSymbolRefAttr mesh, ArrayRef<MeshAxis> meshAxes, unsigned loopIdx) { - if ((shardingOption.cluster && cluster && - shardingOption.cluster != cluster) || + if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || (!shardingOption.shardingArray[loopIdx].empty() && shardingOption.shardingArray[loopIdx] != meshAxes)) { LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " @@ -238,8 +237,8 @@ static LogicalResult fillShardingOption(Operation *op, } } } - if (cluster) - shardingOption.cluster = cluster; + if (mesh) + shardingOption.mesh = mesh; if (shardingOption.shardingArray[loopIdx].empty()) shardingOption.shardingArray[loopIdx].append(meshAxes.begin(), meshAxes.end()); @@ -281,7 +280,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption( auto dim = cast<AffineDimExpr>(expr); unsigned index = dim.getPosition(); visitedLoopIndices.insert(index); - if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(), + if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(), axes, index))) return failure(); } @@ -333,8 +332,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption( if (loopIndices->size() == 1) { unsigned loopIdx = *loopIndices->begin(); visitedLoopIndices.insert(loopIdx); - if (failed(fillShardingOption(op, shardingOption, - shardAttr.getCluster(), axes, loopIdx))) + if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(), + axes, loopIdx))) return failure(); } // If multiple loop indices correspond to a dimension of an operand, it is @@ -437,9 +436,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result, } removeTrailingEmptySubArray(splitAxes); - MeshShardingAttr shardAttr = 
- MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes, - partialAxes, partialType); + MeshShardingAttr shardAttr = MeshShardingAttr::get( + b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType); OpBuilder::InsertionGuard guard(b); b.setInsertionPointAfterValue(result); auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result, @@ -485,7 +483,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, removeTrailingEmptySubArray(splitAxes); MeshShardingAttr shardAttr = - MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes); + MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes); OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(opOperand.getOwner()); auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand, diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index 429e684..c0273cd 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -55,20 +55,19 @@ namespace { // DialectFoldInterface, because it needs a SymbolTableCollection to cache the // symbol tables. // We can't use DialectFoldInterface since the cache may be invalidated by some -// pass changing the referenced ClusterOp ops. -struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> { +// pass changing the referenced MeshOp ops. +struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> { template <typename... 
OpRewritePatternArgs> - ClusterShapeFolder(SymbolTableCollection &symbolTableCollection, - OpRewritePatternArgs &&...opRewritePatternArgs) + MeshShapeFolder(SymbolTableCollection &symbolTableCollection, + OpRewritePatternArgs &&...opRewritePatternArgs) : OpRewritePattern( std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...), symbolTableCollection(symbolTableCollection) {} - LogicalResult matchAndRewrite(ClusterShapeOp op, + LogicalResult matchAndRewrite(MeshShapeOp op, PatternRewriter &rewriter) const override { ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - ClusterOp mesh = - symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>( - op.getOperation(), op.getMeshAttr()); + MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( + op.getOperation(), op.getMeshAttr()); if (!mesh) { return failure(); } @@ -104,8 +103,8 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> { // Leave only the dynamic mesh axes to be queried. if (!newShapeOpMeshAxes.empty()) { - ClusterShapeOp newShapeOp = - builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes); + MeshShapeOp newShapeOp = + builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes); for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; } @@ -123,8 +122,7 @@ private: void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { - patterns.add<ClusterShapeFolder>(symbolTableCollection, - patterns.getContext()); + patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext()); } } // namespace mesh diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 9478b2e..593158d 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -84,7 +84,7 @@ static void shardShape(const InShape &inShape, 
const MeshShape &meshShape, } } -ShapedType shardShapedType(ShapedType shape, ClusterOp mesh, +ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding) { using Dim = std::decay_t<decltype(shape.getDimSize(0))>; SmallVector<Dim> resShapeArr(shape.getShape().size()); @@ -141,7 +141,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder, TypedValue<ShapedType> resultValue = builder .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(), - sourceSharding.getCluster().getLeafReference(), + sourceSharding.getMesh().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) .getResult() @@ -154,7 +154,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder, return targetShardingPartialAxesSet.contains(a); }); MeshShardingAttr resultSharding = - MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(), + MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(), sourceSharding.getSplitAxes(), remainingPartialAxes, sourceSharding.getPartialType()); return {resultValue, resultSharding}; @@ -175,7 +175,7 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, targetShardingSplitAxes[splitTensorAxis] = MeshAxesAttr::get(ctx, targetSplitAxes); return MeshShardingAttr::get( - ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + ctx, sourceSharding.getMesh(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } @@ -197,7 +197,7 @@ static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape, static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, - TypedValue<ShapedType> sourceShard, ClusterOp mesh, + TypedValue<ShapedType> sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); @@ -217,8 +217,8 @@ 
splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Value meshAxisSize = builder - .create<ClusterShapeOp>(mesh.getSymName(), - SmallVector<MeshAxis>({splitMeshAxis})) + .create<MeshShapeOp>(mesh.getSymName(), + SmallVector<MeshAxis>({splitMeshAxis})) .getResult()[0]; Value sourceAxisSize = @@ -305,7 +305,7 @@ detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, } static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> -trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, +trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue<ShapedType> sourceShard) { @@ -366,7 +366,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx, targetShardingSplitAxes[splitTensorAxis] = MeshAxesAttr::get(ctx, targetSplitAxes); return MeshShardingAttr::get( - ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + ctx, sourceSharding.getMesh(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } @@ -382,7 +382,7 @@ static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, - TypedValue<ShapedType> sourceShard, ClusterOp mesh, + TypedValue<ShapedType> sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); @@ -406,7 +406,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, } static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> -tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, +tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, @@ -495,7 +495,7 @@ 
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, MeshAxesAttr::get(ctx, targetSplitAxes); return MeshShardingAttr::get( - ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + ctx, sourceSharding.getMesh(), targetShardingSplitAxes, sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); } @@ -512,7 +512,7 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, } static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> -moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, +moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard, @@ -541,7 +541,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, } static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> -tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, +tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, @@ -561,7 +561,7 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, // Currently the sharded tensor axes must be exactly divisible by the single // mesh axis size. 
static TypedValue<ShapedType> -reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh, +reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue<ShapedType> sourceUnshardedValue, @@ -604,7 +604,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh, return targetShard; } -TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh, +TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue<ShapedType> sourceUnshardedValue, @@ -616,8 +616,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh, sourceUnshardedValue, sourceShard); } -TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh, - ShardOp source, ShardOp target, +TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, + ShardOp target, TypedValue<ShapedType> sourceShardValue) { assert(!source.getAnnotateForUsers()); assert(target.getAnnotateForUsers()); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index c27e173..5c23446 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -24,7 +24,7 @@ namespace mlir::mesh { namespace { /// Lower `mesh.process_multi_index` into expression using -/// `mesh.process_linear_index` and `mesh.cluster_shape`. +/// `mesh.process_linear_index` and `mesh.mesh_shape`. struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> { template <typename... 
OpRewritePatternArgs> ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection, @@ -35,9 +35,8 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> { LogicalResult matchAndRewrite(ProcessMultiIndexOp op, PatternRewriter &rewriter) const override { - ClusterOp mesh = - symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>( - op.getOperation(), op.getMeshAttr()); + MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( + op.getOperation(), op.getMeshAttr()); if (!mesh) { return failure(); } @@ -45,7 +44,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> { ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh); - ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults(); + ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults(); SmallVector<Value> completeMultiIndex = builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape) .getMultiIndex(); diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir index 4cc009e..23c5b25 100644 --- a/mlir/test/Dialect/Mesh/canonicalization.mlir +++ b/mlir/test/Dialect/Mesh/canonicalization.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt --canonicalize %s | FileCheck %s -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) // CHECK-LABEL: func @all_reduce_empty_mesh_axes func.func @all_reduce_empty_mesh_axes( diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir index 9162dc5..369f316d 100644 --- a/mlir/test/Dialect/Mesh/folding.mlir +++ b/mlir/test/Dialect/Mesh/folding.mlir @@ -1,22 +1,22 @@ // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s -mesh.cluster @mesh0(shape = 4x?x2) -mesh.cluster @mesh1(shape = 2x3) +mesh.mesh @mesh0(shape = 4x?x2) +mesh.mesh @mesh1(shape = 2x3) -// CHECK-LABEL: 
func.func @cluster_shape_op_folding -func.func @cluster_shape_op_folding() -> (index, index) { +// CHECK-LABEL: func.func @mesh_shape_op_folding +func.func @mesh_shape_op_folding() -> (index, index) { // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index - // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index - %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index + // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index + %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]] return %0#0, %0#1 : index, index } -// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh -func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) { +// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh +func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) { // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index - %0:2 = mesh.cluster_shape @mesh1 : index, index + %0:2 = mesh.mesh_shape @mesh1 : index, index // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]] return %0#0, %0#1 : index, index } diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir index 8a1fb80..259e4eb 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Mesh/invalid.mlir @@ -1,16 +1,16 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s -// expected-error@+1 {{rank of cluster is expected to be a positive integer}} -mesh.cluster @mesh0(shape = []) +// expected-error@+1 {{rank of mesh is expected to be a positive integer}} +mesh.mesh @mesh0(shape = []) // ----- -// expected-error@+1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}} -mesh.cluster @mesh0(shape = -1) +// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. 
Did you mean an empty list? It must be denoted by "[]".}} +mesh.mesh @mesh0(shape = -1) // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @mesh_axis_duplicated_different_subarray( // expected-error@+1 {{mesh axis duplicated}} @@ -21,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @mesh_axis_duplicated_same_subarray( // expected-error@+1 {{mesh axis duplicated}} @@ -32,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @mesh_axis_duplicated_bewteen_split_and_partial( // expected-error@+1 {{mesh axis duplicated}} @@ -43,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @mesh_axis_negtive_in_split_part( // expected-error@+1 {{mesh axis is expected to be non-negative}} @@ -54,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @mesh_axis_negtive_in_partial( // expected-error@+1 {{mesh axis is expected to be non-negative}} @@ -67,61 +67,61 @@ func.func @mesh_axis_negtive_in_partial( func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { // expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}} - // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}} + // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'mesh' which is to be a `::mlir::FlatSymbolRefAttr`}} %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32> } // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) -func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) { +func.func 
@mesh_shape_mesh_axis_out_of_bounds() -> (index, index) { // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} - %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index + %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.cluster @mesh0(shape = 1x2x3) +mesh.mesh @mesh0(shape = 1x2x3) -func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) { +func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) { // expected-error@+1 {{Mesh axes contains duplicate elements.}} - %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index + %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) -func.func @cluster_shape_wrong_number_of_results() -> (index, index) { +func.func @mesh_shape_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} - %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index + %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index return %0#0, %0#1 : index, index } // ----- -mesh.cluster @mesh0(shape = 1x2x3) +mesh.mesh @mesh0(shape = 1x2x3) -func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) { +func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. 
Expected 3.}} - %0:2 = mesh.cluster_shape @mesh0 : index, index + %0:2 = mesh.mesh_shape @mesh0 : index, index return %0#0, %0#1 : index, index } // ----- -func.func @cluster_shape_invalid_mesh_name() -> (index) { +func.func @mesh_shape_invalid_mesh_name() -> (index) { // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}} - %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index + %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index return %0#0 : index } // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) { // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}} @@ -131,7 +131,7 @@ func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) { // ----- -mesh.cluster @mesh0(shape = 1x2x3) +mesh.mesh @mesh0(shape = 1x2x3) func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) { // expected-error@+1 {{Mesh axes contains duplicate elements.}} @@ -141,7 +141,7 @@ func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) { // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @process_multi_index_wrong_number_of_results() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. Expected 1.}} @@ -151,7 +151,7 @@ func.func @process_multi_index_wrong_number_of_results() -> (index, index) { // ----- -mesh.cluster @mesh0(shape = 1x2x3) +mesh.mesh @mesh0(shape = 1x2x3) func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) { // expected-error@+1 {{Unexpected number of results 2. 
Expected 3.}} @@ -187,7 +187,7 @@ func.func @all_reduce_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @all_reduce_invalid_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { @@ -199,7 +199,7 @@ func.func @all_reduce_invalid_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @all_reduce_duplicate_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf64> { @@ -211,7 +211,7 @@ func.func @all_reduce_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @all_reduce_invalid_tensor_dimension_size( %arg0 : tensor<4xf32>) -> tensor<5xf64> { @@ -232,7 +232,7 @@ func.func @all_gather_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @all_gather_invalid_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { @@ -244,7 +244,7 @@ func.func @all_gather_invalid_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @all_reduce_duplicate_mesh_axis( %arg0 : tensor<4xf32>) -> tensor<4xf32> { @@ -256,7 +256,7 @@ func.func @all_reduce_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @all_gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -268,7 +268,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(shape = 1x2) +mesh.mesh @mesh0(shape = 1x2) func.func @all_gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -280,7 +280,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @all_gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { @@ -292,7 +292,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension( // ----- -mesh.cluster @mesh0(shape = 1) 
+mesh.mesh @mesh0(shape = 1) func.func @all_gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -304,7 +304,7 @@ func.func @all_gather_invalid_gather_axis( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @all_gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -327,7 +327,7 @@ func.func @all_to_all_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @all_to_all_duplicate_mesh_axis( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { @@ -340,7 +340,7 @@ func.func @all_to_all_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = ?x1) +mesh.mesh @mesh0(shape = ?x1) func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group( %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { @@ -353,7 +353,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de // ----- -mesh.cluster @mesh0(shape = 1x1) +mesh.mesh @mesh0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> { @@ -366,7 +366,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna // ----- -mesh.cluster @mesh0(shape = 1x1) +mesh.mesh @mesh0(shape = 1x1) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension( %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> { @@ -379,7 +379,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> { @@ -392,7 +392,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( 
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> { @@ -405,7 +405,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @broadcast_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -418,7 +418,7 @@ func.func @broadcast_root_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @broadcast_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -431,7 +431,7 @@ func.func @broadcast_root_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @broadcast_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { @@ -444,7 +444,7 @@ func.func @broadcast_different_input_and_result_type( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @gather_wrong_return_element_type( %arg0 : tensor<1xf32>) -> tensor<1xi8> { @@ -456,7 +456,7 @@ func.func @gather_wrong_return_element_type( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @gather_invalid_non_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -468,7 +468,7 @@ func.func @gather_invalid_non_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(shape = 1x2) +mesh.mesh @mesh0(shape = 1x2) func.func @gather_invalid_gather_axis_dimension_size( %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> { @@ -480,7 +480,7 @@ func.func @gather_invalid_gather_axis_dimension_size( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @gather_invalid_gather_axis_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<3xf32> { @@ -492,7 +492,7 @@ func.func @gather_invalid_gather_axis_dynamic_dimension( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @gather_invalid_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -504,7 +504,7 @@ func.func 
@gather_invalid_gather_axis( // ----- -mesh.cluster @mesh0(shape = 1) +mesh.mesh @mesh0(shape = 1) func.func @gather_invalid_negative_gather_axis( %arg0 : tensor<3xf32>) -> tensor<3xf32> { @@ -516,7 +516,7 @@ func.func @gather_invalid_negative_gather_axis( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @gather_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<6xi8> { @@ -529,7 +529,7 @@ func.func @gather_root_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @gather_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -542,7 +542,7 @@ func.func @gather_root_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @receive_source_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -555,7 +555,7 @@ func.func @receive_source_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @receive_source_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -568,7 +568,7 @@ func.func @receive_source_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @receive_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { @@ -581,7 +581,7 @@ func.func @receive_different_input_and_result_type( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @reduce_root_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -594,7 +594,7 @@ func.func @reduce_root_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @reduce_root_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -607,7 +607,7 @@ func.func @reduce_root_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) 
func.func @reduce_different_input_and_result_shape( %arg0 : tensor<2xi8>) -> tensor<3xi16> { @@ -620,7 +620,7 @@ func.func @reduce_different_input_and_result_shape( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @reduce_scatter_duplicate_mesh_axis( %arg0 : tensor<?xf32>) -> tensor<?xf64> { @@ -632,7 +632,7 @@ func.func @reduce_scatter_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @reduce_scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf64> { @@ -644,7 +644,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @reduce_scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf64> { @@ -656,7 +656,7 @@ func.func @reduce_scatter_invalid_static_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf64> { @@ -668,7 +668,7 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @scatter_duplicate_mesh_axis( %arg0 : tensor<?xf32>) -> tensor<?xf32> { @@ -681,7 +681,7 @@ func.func @scatter_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @scatter_invalid_dynamic_dimension( %arg0 : tensor<?xf32>) -> tensor<2xf32> { @@ -694,7 +694,7 @@ func.func @scatter_invalid_dynamic_dimension( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { @@ -707,7 +707,7 @@ func.func @scatter_invalid_static_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3) +mesh.mesh @mesh0(shape = 3) func.func @scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor<?xf32> { @@ -720,7 +720,7 @@ 
func.func @scatter_invalid_operand_static_dimension_size( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @scatter_root_dimension_out_of_bounds( %arg0 : tensor<3xi8>) -> tensor<1xi8> { @@ -733,7 +733,7 @@ func.func @scatter_root_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @scatter_root_wrong_number_dimensions( %arg0 : tensor<3xi8>) -> tensor<1xi8> { @@ -746,7 +746,7 @@ func.func @scatter_root_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @send_destination_dimension_out_of_bounds( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -759,7 +759,7 @@ func.func @send_destination_dimension_out_of_bounds( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) func.func @send_destination_wrong_number_dimensions( %arg0 : tensor<2xi8>) -> tensor<2xi8> { @@ -772,7 +772,7 @@ func.func @send_destination_wrong_number_dimensions( // ----- -mesh.cluster @mesh0(shape = 3x?) +mesh.mesh @mesh0(shape = 3x?) 
func.func @send_different_input_and_result_type( %arg0 : tensor<2xi8>) -> tensor<2xi16> { @@ -796,7 +796,7 @@ func.func @shift_invalid_mesh_symbol( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @shift_invalid_mesh_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { @@ -809,7 +809,7 @@ func.func @shift_invalid_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @shift_duplicate_mesh_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { @@ -822,7 +822,7 @@ func.func @shift_duplicate_mesh_axis( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @shift_invalid_tensor_dimension_size( %arg0 : tensor<4xi8>) -> tensor<5xi8> { @@ -835,7 +835,7 @@ func.func @shift_invalid_tensor_dimension_size( // ----- -mesh.cluster @mesh0(shape = 2x4) +mesh.mesh @mesh0(shape = 2x4) func.func @shift_invalid_shift_axis( %arg0 : tensor<4xi8>) -> tensor<4xi8> { diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir index 0aaa4bd..dbaaff9 100644 --- a/mlir/test/Dialect/Mesh/ops.mlir +++ b/mlir/test/Dialect/Mesh/ops.mlir @@ -1,21 +1,21 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// CHECK: mesh.cluster @mesh0 -mesh.cluster @mesh0(shape = 2x2x4) +// CHECK: mesh.mesh @mesh0 +mesh.mesh @mesh0(shape = 2x2x4) -// CHECK: mesh.cluster @mesh1(shape = 4x?) -mesh.cluster @mesh1(shape = 4x?) +// CHECK: mesh.mesh @mesh1(shape = 4x?) +mesh.mesh @mesh1(shape = 4x?) -// CHECK: mesh.cluster @mesh2(shape = ?x4) -mesh.cluster @mesh2(shape = ?x4) +// CHECK: mesh.mesh @mesh2(shape = ?x4) +mesh.mesh @mesh2(shape = ?x4) -// CHECK: mesh.cluster @mesh3(shape = ?x?) -mesh.cluster @mesh3(shape = ?x?) +// CHECK: mesh.mesh @mesh3(shape = ?x?) +mesh.mesh @mesh3(shape = ?x?) -mesh.cluster @mesh4(shape = 3) +mesh.mesh @mesh4(shape = 3) -// CHECK: mesh.cluster @mesh5(shape = ?) -mesh.cluster @mesh5(shape = ?) +// CHECK: mesh.mesh @mesh5(shape = ?) +mesh.mesh @mesh5(shape = ?) 
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated func.func @mesh_shard_encoding_fully_replicated( @@ -132,26 +132,26 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32> } -// CHECK-LABEL: func @cluster_shape -func.func @cluster_shape() -> (index, index) { - // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index - %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index +// CHECK-LABEL: func @mesh_shape +func.func @mesh_shape() -> (index, index) { + // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index + %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } -// CHECK-LABEL: func @cluster_shape_default_axes -func.func @cluster_shape_default_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index - %0:3 = mesh.cluster_shape @mesh0 : index, index, index +// CHECK-LABEL: func @mesh_shape_default_axes +func.func @mesh_shape_default_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index + %0:3 = mesh.mesh_shape @mesh0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } -// CHECK-LABEL: func @cluster_shape_empty_axes -func.func @cluster_shape_empty_axes() -> (index, index, index) { - // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index - %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index +// CHECK-LABEL: func @mesh_shape_empty_axes +func.func @mesh_shape_empty_axes() -> (index, index, index) { + // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index + %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : 
index, index, index } diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir index aeeba4e..677a598 100644 --- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir +++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir @@ -1,11 +1,11 @@ // RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s -mesh.cluster @mesh2d(shape = ?x?) +mesh.mesh @mesh2d(shape = ?x?) // CHECK-LABEL: func.func @multi_index_2d_mesh func.func @multi_index_2d_mesh() -> (index, index) { // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index + // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index %0:2 = mesh.process_multi_index on @mesh2d : index, index // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index @@ -15,7 +15,7 @@ func.func @multi_index_2d_mesh() -> (index, index) { // CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis func.func @multi_index_2d_mesh_single_inner_axis() -> index { // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index - // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index + // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index %0 = mesh.process_multi_index on @mesh2d axes = [0] : index // CHECK: return %[[MULTI_IDX]]#0 : index diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir index 3f5c7d8..cb98d31 100644 --- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir +++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt 
-test-mesh-resharding-spmdization %s | FileCheck %s -mesh.cluster @mesh_1d(shape = 2) -mesh.cluster @mesh_1d_dynamic(shape = ?) +mesh.mesh @mesh_1d(shape = 2) +mesh.mesh @mesh_1d_dynamic(shape = ?) // CHECK-LABEL: func @same_source_and_target_sharding func.func @same_source_and_target_sharding( @@ -22,7 +22,7 @@ func.func @split_replicated_tensor_axis( // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index - // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index + // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]] @@ -44,7 +44,7 @@ func.func @split_replicated_tensor_axis_dynamic( // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index - // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index + // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32> // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir index 065ae9c..94f8d94 
100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt -sharding-propagation %s | FileCheck %s -mesh.cluster @mesh_1d(shape = ?) -mesh.cluster @mesh_2d(shape = 2x4) -mesh.cluster @mesh_3d(shape = ?x?x?) +mesh.mesh @mesh_1d(shape = ?) +mesh.mesh @mesh_2d(shape = 2x4) +mesh.mesh @mesh_3d(shape = ?x?x?) // CHECK-LABEL: func.func @element_wise_empty_sharding_info func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir index 63ae9d5..d748be8 100644 --- a/mlir/test/Dialect/Mesh/simplifications.mlir +++ b/mlir/test/Dialect/Mesh/simplifications.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s -mesh.cluster @mesh0(shape = 4x2) -mesh.cluster @mesh1(shape = 4) +mesh.mesh @mesh0(shape = 4x2) +mesh.mesh @mesh1(shape = 4) // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to // `all_reduce(x + y)`. 
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp index 6fecbd4..9b3082a 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp @@ -37,15 +37,15 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { } SymbolTableCollection symbolTable; - mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>( - op, op.getShard().getCluster()); + mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( + op, op.getShard().getMesh()); bool foundUser = false; for (auto user : op->getUsers()) { if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) { if (targetShardOp.getAnnotateForUsers() && - mesh == symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>( - targetShardOp, targetShardOp.getShard().getCluster())) { + mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( + targetShardOp, targetShardOp.getShard().getMesh())) { foundUser = true; break; } @@ -59,8 +59,8 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> { for (auto user : op->getUsers()) { auto targetShardOp = llvm::dyn_cast<ShardOp>(user); if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || - symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>( - targetShardOp, targetShardOp.getShard().getCluster()) != mesh) { + symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( + targetShardOp, targetShardOp.getShard().getMesh()) != mesh) { continue; } |