Diffstat (limited to 'mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp')
-rw-r--r--  mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp  950
1 file changed, 0 insertions(+), 950 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
deleted file mode 100644
index b931284..0000000
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ /dev/null
@@ -1,950 +0,0 @@
-//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a translation of Mesh communication ops to MPI ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MPI/IR/MPI.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "mesh-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mesh;
-
-namespace {
-/// Converts a mixed list of static (int64_t) and dynamic values into a
-/// vector of Values of the provided type.
-static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
- llvm::ArrayRef<int64_t> statics,
- ValueRange dynamics,
- Type type = Type()) {
- SmallVector<Value> values;
- auto dyn = dynamics.begin();
- Type i64 = b.getI64Type();
- if (!type)
- type = i64;
-  assert((i64 == type || b.getIndexType() == type) &&
-         "expected an i64 or an index type");
- for (auto s : statics) {
- if (s == ShapedType::kDynamic) {
- values.emplace_back(*(dyn++));
- } else {
- TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
- values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
- }
- }
- return values;
-}
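-
-// Illustrative usage (hypothetical values): with statics = {2,
-// ShapedType::kDynamic, 4} and a single dynamic value %d, the result is
-// {constant 2, %d, constant 4}; the dynamic slot consumes %d.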
-
-/// Create operations converting a linear index to a multi-dimensional index.
-static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
- Value linearIndex,
- ValueRange dimensions) {
- int n = dimensions.size();
- SmallVector<Value> multiIndex(n);
-
- for (int i = n - 1; i >= 0; --i) {
- multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0)
- linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
-
- return multiIndex;
-}
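-
-// Worked example (illustrative): for linearIndex = 5 and dimensions = [2, 3],
-// the loop computes multiIndex[1] = 5 % 3 = 2, then linearIndex = 5 / 3 = 1
-// and multiIndex[0] = 1 % 2 = 1, yielding the multi-index [1, 2].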
-
-/// Create operations converting a multi-dimensional index to a linear index.
-Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
- ValueRange dimensions) {
-
- Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
- Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
-
- for (int i = multiIndex.size() - 1; i >= 0; --i) {
- Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
- linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
- stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
- }
-
- return linearIndex;
-}
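-
-// Worked example (illustrative, inverse of the above): for
-// multiIndex = [1, 2] and dimensions = [2, 3], the loop accumulates
-// 2 * 1 + 1 * 3 = 5, recovering the linear index 5.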
-
-/// Replace GetShardingOp with related/dependent ShardingOp.
-struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
- if (!shardOp)
- return failure();
- auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
- if (!shardingOp)
- return failure();
-
- rewriter.replaceOp(op, shardingOp.getResult());
- return success();
- }
-};
-
-/// Convert a sharding op to a tuple of tensors of its components
-/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
-/// as defined by type converter.
-struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto splitAxes = op.getSplitAxes().getAxes();
- int64_t maxNAxes = 0;
- for (auto axes : splitAxes)
- maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
-
- // To hold the split axes, create empty 2d tensor with shape
- // {splitAxes.size(), max-size-of-split-groups}.
- // Set trailing elements for smaller split-groups to -1.
- Location loc = op.getLoc();
- auto i16 = rewriter.getI16Type();
- auto i64 = rewriter.getI64Type();
- std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
- maxNAxes};
- Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
- auto attr = IntegerAttr::get(i16, -1);
- Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
- resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
- .getResult(0);
-
- // explicitly write values into tensor row by row
- std::array<int64_t, 2> strides = {1, 1};
- int64_t nSplits = 0;
- ValueRange empty = {};
- for (auto [i, axes] : llvm::enumerate(splitAxes)) {
- int64_t size = axes.size();
- if (size > 0)
- ++nSplits;
- std::array<int64_t, 2> offs = {(int64_t)i, 0};
- std::array<int64_t, 2> sizes = {1, size};
- auto tensorType = RankedTensorType::get({size}, i16);
- auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
- auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
- resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
- }
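-
-    // Illustrative example (assumed values): for split axes [[0], [1, 2]],
-    // maxNAxes is 2, so resSplitAxes becomes the 2x2 i16 tensor
-    // [[0, -1], [1, 2]]; the -1 pads the shorter first group.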
-
- // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
- // Store the halo sizes in the tensor.
- SmallVector<Value> haloSizes =
- getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
- adaptor.getDynamicHaloSizes());
- auto type = RankedTensorType::get({nSplits, 2}, i64);
- Value resHaloSizes =
- haloSizes.empty()
- ? rewriter
- .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
- i64)
- .getResult()
- : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
- .getResult();
-
- // To hold sharded dims offsets, create Tensor with shape {nSplits,
-    // maxSplitSize+1}. Store the offsets in the tensor and set trailing
-    // elements for smaller split-groups to kDynamic. Computing the max size
-    // of the split groups requires collectiveProcessGroupSize (which needs
-    // the MeshOp).
- Value resOffsets;
- if (adaptor.getStaticShardedDimsOffsets().empty()) {
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{0, 0}, i64);
- } else {
- SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
- int64_t maxSplitSize = 0;
- for (auto axes : splitAxes) {
- int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
- assert(splitSize != ShapedType::kDynamic);
- maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
- }
- assert(maxSplitSize);
- ++maxSplitSize; // add one for the total size
-
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
-      Value fillVal = rewriter.create<arith::ConstantOp>(
-          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
-      resOffsets =
-          rewriter.create<linalg::FillOp>(loc, fillVal, resOffsets)
-              .getResult(0);
- SmallVector<Value> offsets =
- getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
- adaptor.getDynamicShardedDimsOffsets());
- int64_t curr = 0;
- for (auto [i, axes] : llvm::enumerate(splitAxes)) {
- int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
- assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
- ++splitSize; // add one for the total size
- ArrayRef<Value> values(&offsets[curr], splitSize);
- Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
- std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
- std::array<int64_t, 2> sizes = {1, splitSize};
- resOffsets = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
- curr += splitSize;
- }
- }
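-
-    // Illustrative example (assumed values): a dimension sharded across a
-    // split group of 4 processes with local sizes [2, 3, 3, 2] contributes
-    // the offsets row [0, 2, 5, 8, 10]: splitSize + 1 = 5 entries, the last
-    // being the total size.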
-
- // return a tuple of tensors as defined by type converter
- SmallVector<Type> resTypes;
- if (failed(getTypeConverter()->convertType(op.getResult().getType(),
- resTypes)))
- return failure();
-
- resSplitAxes =
- rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
- resHaloSizes =
- rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
- resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
-
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, TupleType::get(op.getContext(), resTypes),
- ValueRange{resSplitAxes, resHaloSizes, resOffsets});
-
- return success();
- }
-};
-
-struct ConvertProcessMultiIndexOp
- : public OpConversionPattern<ProcessMultiIndexOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
- // Currently converts its linear index to a multi-dimensional index.
-
- SymbolTableCollection symbolTableCollection;
- Location loc = op.getLoc();
- auto meshOp = getMesh(op, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
- return failure();
-
- SmallVector<Value> dims;
- llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
- });
- Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
- auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
-
- // optionally extract subset of mesh axes
- auto axes = adaptor.getAxes();
- if (!axes.empty()) {
- SmallVector<Value> subIndex;
- for (auto axis : axes) {
- subIndex.emplace_back(mIdx[axis]);
- }
- mIdx = std::move(subIndex);
- }
-
- rewriter.replaceOp(op, mIdx);
- return success();
- }
-};
-
-class ConvertProcessLinearIndexOp
- : public OpConversionPattern<ProcessLinearIndexOp> {
-
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Create mpi::CommRankOp
- Location loc = op.getLoc();
- auto ctx = op.getContext();
- Value commWorld =
- rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
- auto rank =
- rewriter
- .create<mpi::CommRankOp>(
- loc,
- TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
- commWorld)
- .getRank();
- rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
- rank);
- return success();
- }
-};
-
-struct ConvertNeighborsLinearIndicesOp
- : public OpConversionPattern<NeighborsLinearIndicesOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
-    // Computes the neighbor indices along a split axis by simply
-    // adding/subtracting 1 to/from the current index in that dimension.
-    // Assigns -1 if the neighbor is out of bounds.
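-    // Worked example (illustrative): on a 2x3 mesh with split axis [1] and
-    // device [1, 2], the down neighbor is [1, 1] -> linear index 4
-    // (1 * 3 + 1), while the up neighbor [1, 3] is out of bounds -> -1.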
-
- auto axes = adaptor.getSplitAxes();
- // For now only single axis sharding is supported
- if (axes.size() != 1)
- return failure();
-
- Location loc = op.getLoc();
- SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
- auto mIdx = adaptor.getDevice();
- auto orgIdx = mIdx[axes[0]];
- SmallVector<Value> dims;
- llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
- });
- Value dimSz = dims[axes[0]];
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
- Value atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
- auto down = rewriter.create<scf::IfOp>(
- loc, atBorder,
- [&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
- },
- [&](OpBuilder &builder, Location loc) {
- SmallVector<Value> tmp = mIdx;
- tmp[axes[0]] =
- rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
- .getResult();
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
- });
- atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, orgIdx,
- rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
- auto up = rewriter.create<scf::IfOp>(
- loc, atBorder,
- [&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
- },
- [&](OpBuilder &builder, Location loc) {
- SmallVector<Value> tmp = mIdx;
- tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
- });
- rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return success();
- }
-};
-
-struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
- if (!sharding) {
-      return op->emitError()
-             << "Expected ShardingOp as defining op for sharding"
-             << " but found " << adaptor.getSharding()[0].getDefiningOp();
- }
-
- // Compute the sharded shape by applying the sharding to the input shape.
- // If shardedDimsOffsets is not defined in the sharding, the shard shape is
- // computed by dividing the dimension size by the number of shards in that
- // dimension (which is given by the size of the mesh axes provided in
- // split-axes). Odd elements get distributed to trailing shards. If a
- // shardedDimsOffsets is provided, the shard shape is computed by
- // subtracting the offset of the current shard from the offset of the next
- // shard.
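-    // Worked example (illustrative): dim = 10 split over numShards = 4
-    // gives 10 / 4 = 2 with remainder 2; shards with idx >= 4 - 2 = 2 get
-    // one extra element, so the local sizes are [2, 2, 3, 3] (sum 10).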
-
- Location loc = op.getLoc();
- Type index = rewriter.getIndexType();
-
- // This is a 1:N conversion because the sharding op is a 1:3 conversion.
-    // The operands in the adaptor are a vector<ValueRange>. For dims and
-    // device we have a 1:1 conversion.
- // For simpler access fill a vector with the dynamic dims.
- SmallVector<Value> dynDims, dynDevice;
- for (auto dim : adaptor.getDimsDynamic()) {
- // type conversion should be 1:1 for ints
- dynDims.emplace_back(llvm::getSingleElement(dim));
- }
- // same for device
- for (auto device : adaptor.getDeviceDynamic()) {
- dynDevice.emplace_back(llvm::getSingleElement(device));
- }
-
- // To keep the code simple, convert dims/device to values when they are
- // attributes. Count on canonicalization to fold static values.
- SmallVector<Value> shape =
- getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
- SmallVector<Value> multiIdx =
- getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
-
- // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
- SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(sharding, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
- return failure();
-
- auto splitAxes = sharding.getSplitAxes().getAxes();
- // shardedDimsOffsets are optional and might be Values (not attributes).
- // Also, the shardId might be dynamic which means the position in the
- // shardedDimsOffsets is not statically known. Create a tensor of the
- // shardedDimsOffsets and later extract the offsets for computing the
- // local shard-size.
- Value shardedDimsOffs;
- {
- SmallVector<Value> tmp = getMixedAsValues(
- rewriter, loc, sharding.getStaticShardedDimsOffsets(),
- sharding.getDynamicShardedDimsOffsets(), index);
- if (!tmp.empty())
- shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
- }
-
- // With static mesh shape the sizes of the split axes are known.
-    // Hence the start/pos for each split axis in shardDimsOffsets can be
- // computed statically.
- int64_t pos = 0;
- SmallVector<Value> shardShape;
- Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
- Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
-
- // Iterate over the dimensions of the tensor shape, get their split Axes,
- // and compute the sharded shape.
- for (auto [i, dim] : llvm::enumerate(shape)) {
- // Trailing dimensions might not be annotated.
- if (i < splitAxes.size() && !splitAxes[i].empty()) {
- auto axes = splitAxes[i];
- // The current dimension might not be sharded.
- // Create a value from the static position in shardDimsOffsets.
- Value posVal =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
- // Get the index of the local shard in the mesh axis.
- Value idx = multiIdx[axes[0]];
- auto numShards =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
- if (shardedDimsOffs) {
- // If sharded dims offsets are provided, use them to compute the
- // sharded shape.
- if (axes.size() > 1) {
- return op->emitError() << "Only single axis sharding is "
- << "supported for each dimension.";
- }
- idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
- // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
- Value off =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- idx = rewriter.create<arith::AddIOp>(loc, idx, one);
- Value nextOff =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
- shardShape.emplace_back(sz);
- } else {
- Value numShardsVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(numShards));
- // Compute shard dim size by distributing odd elements to trailing
- // shards:
- // sz = dim / numShards
- // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
- Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
- Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
- sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
- auto cond = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, idx, sz1);
- Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
- sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
- shardShape.emplace_back(sz);
- }
- pos += numShards + 1; // add one for the total size.
- } // else no sharding if split axis is empty or no split axis
- // If no size was added -> no sharding in this dimension.
- if (shardShape.size() <= i)
- shardShape.emplace_back(dim);
- }
- assert(shardShape.size() == shape.size());
- rewriter.replaceOp(op, shardShape);
- return success();
- }
-};
-
-static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
- auto ctx = kind.getContext();
- auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
- };
-
- switch (kind.getValue()) {
- case ReductionKind::Sum:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
- case ReductionKind::Product:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
- case ReductionKind::Min:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
- case ReductionKind::Max:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
- case ReductionKind::BitwiseAnd:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
- case ReductionKind::BitwiseOr:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
- case ReductionKind::BitwiseXor:
- return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
- default:
- llvm_unreachable("Unknown/unsupported reduction kind");
- }
-}
-
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SymbolTableCollection symbolTableCollection;
- auto mesh = adaptor.getMesh();
- mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
- if (!meshOp)
- return op->emitError() << "No mesh found for AllReduceOp";
- if (ShapedType::isDynamicShape(meshOp.getShape()))
- return op->emitError()
- << "Dynamic mesh shape not supported in AllReduceOp";
-
- ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
- Value input = adaptor.getInput();
- auto inputShape = cast<ShapedType>(input.getType()).getShape();
-
-    // If the source is a tensor, cast it to a memref buffer.
- if (isa<RankedTensorType>(input.getType())) {
- auto memrefType = MemRefType::get(
- inputShape, cast<ShapedType>(input.getType()).getElementType());
- input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
- }
- MemRefType inType = cast<MemRefType>(input.getType());
-
- // Get the actual shape to allocate the buffer.
- SmallVector<OpFoldResult> shape(inType.getRank());
- for (auto i = 0; i < inType.getRank(); ++i) {
- auto s = inputShape[i];
- if (ShapedType::isDynamic(s))
-        shape[i] = iBuilder.create<memref::DimOp>(input, i).getResult();
- else
- shape[i] = iBuilder.getIndexAttr(s);
- }
-
- // Allocate buffer and copy input to buffer.
- Value buffer = iBuilder.create<memref::AllocOp>(
- shape, cast<ShapedType>(op.getType()).getElementType());
- iBuilder.create<linalg::CopyOp>(input, buffer);
-
- // Get an MPI_Comm_split for the AllReduce operation.
- // The color is the linear index of the process in the mesh along the
- // non-reduced axes. The key is the linear index of the process in the mesh
- // along the reduced axes.
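-    // Illustrative example (assumed 2x3 mesh, redAxes = [1]): process
-    // [1, 2] gets myMultiIndex = [1, 0] and multiKey = [0, 2]; all
-    // processes in row 1 thus share a color, and the key orders them
-    // along the reduced axis.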
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
- iBuilder.getIndexType());
- SmallVector<Value> myMultiIndex =
- iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
- .getResult();
- Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
- SmallVector<Value> multiKey(myMultiIndex.size(), zero);
-
- auto redAxes = adaptor.getMeshAxes();
- for (auto axis : redAxes) {
- multiKey[axis] = myMultiIndex[axis];
- myMultiIndex[axis] = zero;
- }
-
- Value color =
- createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
- color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
- key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
-
- // Finally split the communicator
- auto commType = mpi::CommType::get(op->getContext());
- Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
- auto comm =
- iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
- .getNewcomm();
-
- Value buffer1d = buffer;
- // Collapse shape to 1d if needed
- if (inType.getRank() > 1) {
- ReassociationIndices reassociation(inType.getRank());
- std::iota(reassociation.begin(), reassociation.end(), 0);
- buffer1d = iBuilder.create<memref::CollapseShapeOp>(
- buffer, ArrayRef<ReassociationIndices>(reassociation));
- }
-
- // Create the MPI AllReduce operation.
- iBuilder.create<mpi::AllReduceOp>(
- TypeRange(), buffer1d, buffer1d,
- getMPIReductionOp(adaptor.getReductionAttr()), comm);
-
-    // If the result type is a tensor, cast the buffer back to a tensor.
- if (isa<RankedTensorType>(op.getType()))
- buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
- true);
-
- rewriter.replaceOp(op, buffer);
- return success();
- }
-};
-
-struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
- // The input/output memref is assumed to be in C memory order.
- // Halos are exchanged as 2 blocks per dimension (one for each side: down
- // and up). For each haloed dimension `d`, the exchanged blocks are
- // expressed as multi-dimensional subviews. The subviews include potential
- // halos of higher dimensions `dh > d`, no halos for the lower dimensions
- // `dl < d` and for dimension `d` the currently exchanged halo only.
-    // By iterating from higher to lower dimensions, this also updates the
-    // halos in the 'corners'.
- // memref.subview is used to read and write the halo data from and to the
- // local data. Because subviews and halos can have mixed dynamic and static
- // shapes, OpFoldResults are used whenever possible.
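-    // Illustrative example (assumed sizes): for a 10x12 array with halo
-    // sizes [1, 1] on dim 0 and [2, 2] on dim 1, dim 1 is exchanged first
-    // with 8x2 subviews (dim 0 shrunk by its halos), then dim 0 with 1x12
-    // subviews spanning the full, haloed dim 1, which fills the corners.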
-
- auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
- adaptor.getHaloSizes(), rewriter);
- if (haloSizes.empty()) {
- // no halos -> nothing to do
- rewriter.replaceOp(op, adaptor.getDestination());
- return success();
- }
-
- SymbolTableCollection symbolTableCollection;
- Location loc = op.getLoc();
-
-    // Convert an OpFoldResult into a Value.
- auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
- if (auto value = dyn_cast<Value>(v))
- return value;
- return rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
- };
-
- auto dest = adaptor.getDestination();
- auto dstShape = cast<ShapedType>(dest.getType()).getShape();
- Value array = dest;
- if (isa<RankedTensorType>(array.getType())) {
-      // If the destination is a tensor, cast it to a memref buffer.
-      auto memrefType = MemRefType::get(
-          dstShape, cast<ShapedType>(array.getType()).getElementType());
-      array =
-          rewriter.create<bufferization::ToBufferOp>(loc, memrefType, array);
- }
- auto rank = cast<ShapedType>(array.getType()).getRank();
- auto opSplitAxes = adaptor.getSplitAxes().getAxes();
- auto mesh = adaptor.getMesh();
- auto meshOp = getMesh(op, symbolTableCollection);
- // subviews need Index values
- for (auto &sz : haloSizes) {
- if (auto value = dyn_cast<Value>(sz))
- sz =
- rewriter
- .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
- .getResult();
- }
-
- // most of the offset/size/stride data is the same for all dims
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
- auto currHaloDim = -1; // halo sizes are provided for split dimensions only
- // we need the actual shape to compute offsets and sizes
- for (auto i = 0; i < rank; ++i) {
- auto s = dstShape[i];
- if (ShapedType::isDynamic(s))
-        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
- else
- shape[i] = rewriter.getIndexAttr(s);
-
- if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
- ++currHaloDim;
-      // the offsets for lower dims start after their down halo
- offsets[i] = haloSizes[currHaloDim * 2];
-
- // prepare shape and offsets of highest dim's halo exchange
- Value _haloSz = rewriter.create<arith::AddIOp>(
- loc, toValue(haloSizes[currHaloDim * 2]),
- toValue(haloSizes[currHaloDim * 2 + 1]));
-      // the halo shape of lower dims excludes the halos
- dimSizes[i] =
- rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
- .getResult();
- } else {
- dimSizes[i] = shape[i];
- }
- }
-
- auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
- auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
- auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
- auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
- rewriter.getIndexType());
- auto myMultiIndex =
- rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
- .getResult();
- // traverse all split axes from high to low dim
- for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
- auto splitAxes = opSplitAxes[dim];
- if (splitAxes.empty())
- continue;
- assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
- // Get the linearized ids of the neighbors (down and up) for the
- // given split
- auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
- splitAxes)
- .getResults();
- // MPI operates on i32...
- Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[0]),
- rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[1])};
-
- auto lowerRecvOffset = rewriter.getIndexAttr(0);
- auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
- auto upperRecvOffset = rewriter.create<arith::SubIOp>(
- loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
- auto upperSendOffset = rewriter.create<arith::SubIOp>(
- loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
-
- Value commWorld = rewriter.create<mpi::CommWorldOp>(
- loc, mpi::CommType::get(op->getContext()));
-
-      // Make sure we send/recv in a way that does not lead to a deadlock.
-      // The current approach is far from optimal; it should at least use a
-      // red-black pattern or MPI_Sendrecv.
-      // Also, buffers should be reused.
-      // Still using temporary contiguous buffers for MPI communication...
-      // Still yielding a "serialized" communication pattern...
- auto genSendRecv = [&](bool upperHalo) {
- auto orgOffset = offsets[dim];
- dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
- : haloSizes[currHaloDim * 2];
- // Check if we need to send and/or receive
- // Processes on the mesh borders have only one neighbor
- auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
- auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto hasFrom = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, from, zero);
- auto hasTo = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, to, zero);
- auto buffer = rewriter.create<memref::AllocOp>(
- loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
- // if has neighbor: copy halo data from array to buffer and send
- rewriter.create<scf::IfOp>(
- loc, hasTo, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
- : OpFoldResult(upperSendOffset);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, subview, buffer);
- builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
- commWorld);
- builder.create<scf::YieldOp>(loc);
- });
- // if has neighbor: receive halo data into buffer and copy to array
- rewriter.create<scf::IfOp>(
- loc, hasFrom, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
- : OpFoldResult(lowerRecvOffset);
- builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
- commWorld);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, buffer, subview);
- builder.create<scf::YieldOp>(loc);
- });
- rewriter.create<memref::DeallocOp>(loc, buffer);
- offsets[dim] = orgOffset;
- };
-
- auto doSendRecv = [&](int upOrDown) {
- OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
- Value haloSz = dyn_cast<Value>(v);
- if (!haloSz)
- haloSz = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
- auto hasSize = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, haloSz, zero);
- rewriter.create<scf::IfOp>(loc, hasSize,
- [&](OpBuilder &builder, Location loc) {
- genSendRecv(upOrDown > 0);
- builder.create<scf::YieldOp>(loc);
- });
- };
-
- doSendRecv(0);
- doSendRecv(1);
-
-      // the shape for lower dims includes higher dims' halos
- dimSizes[dim] = shape[dim];
- // -> the offset for higher dims is always 0
- offsets[dim] = rewriter.getIndexAttr(0);
- // on to next halo
- --currHaloDim;
- }
-
- if (isa<MemRefType>(op.getResult().getType())) {
- rewriter.replaceOp(op, array);
- } else {
- assert(isa<RankedTensorType>(op.getResult().getType()));
- rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
- loc, op.getResult().getType(), array,
- /*restrict=*/true, /*writable=*/true));
- }
- return success();
- }
-};
-
-struct ConvertMeshToMPIPass
- : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
- using Base::Base;
-
- /// Run the dialect converter on the module.
- void runOnOperation() override {
- auto *ctxt = &getContext();
- RewritePatternSet patterns(ctxt);
- ConversionTarget target(getContext());
-
- // Define a type converter to convert mesh::ShardingType,
- // mostly for use in return operations.
- TypeConverter typeConverter;
- typeConverter.addConversion([](Type type) { return type; });
-
- // convert mesh::ShardingType to a tuple of RankedTensorTypes
- typeConverter.addConversion(
- [](ShardingType type,
- SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- auto i16 = IntegerType::get(type.getContext(), 16);
- auto i64 = IntegerType::get(type.getContext(), 64);
- std::array<int64_t, 2> shp = {ShapedType::kDynamic,
- ShapedType::kDynamic};
- results.emplace_back(RankedTensorType::get(shp, i16));
- results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
- results.emplace_back(RankedTensorType::get(shp, i64));
- return success();
- });
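-
-    // Illustrative result (assumed notation): a !mesh.sharding value thus
-    // becomes the triple (tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>)
-    // holding split axes, halo sizes, and sharded dims offsets.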
-
-    // To 'extract' components, an UnrealizedConversionCastOp is expected
-    // to define the input.
- typeConverter.addTargetMaterialization(
- [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
- Location loc) {
- // Expecting a single input.
- if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
- return SmallVector<Value>();
- auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
- // Expecting an UnrealizedConversionCastOp.
- if (!castOp)
- return SmallVector<Value>();
- // Fill a vector with elements of the tuple/castOp.
- SmallVector<Value> results;
- for (auto oprnd : castOp.getInputs()) {
- if (!isa<RankedTensorType>(oprnd.getType()))
- return SmallVector<Value>();
- results.emplace_back(oprnd);
- }
- return results;
- });
-
-    // No mesh dialect ops should be left after conversion...
- target.addIllegalDialect<mesh::MeshDialect>();
-    // ...except the global MeshOp and MeshShapeOp, which get folded later.
- target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
- // Allow all the stuff that our patterns will convert to
- target.addLegalDialect<
- BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
- tensor::TensorDialect, bufferization::BufferizationDialect,
- linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
- // Make sure the function signature, calls etc. are legal
- target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return typeConverter.isSignatureLegal(op.getFunctionType());
- });
- target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
- [&](Operation *op) { return typeConverter.isLegal(op); });
-
- patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
- ConvertProcessMultiIndexOp, ConvertGetShardingOp,
- ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
- ConvertProcessLinearIndexOp>(typeConverter, ctxt);
-
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- patterns, typeConverter);
- populateCallOpTypeConversionPattern(patterns, typeConverter);
- populateReturnOpTypeConversionPattern(patterns, typeConverter);
-
- (void)applyPartialConversion(getOperation(), target, std::move(patterns));
-
- // Folding patterns cannot be mixed with conversion patterns -> extra pass.
- patterns.clear();
- SymbolTableCollection symbolTableCollection;
- mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
- }
-};
-
-} // namespace