Diffstat (limited to 'mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp')
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp  953
1 file changed, 953 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
new file mode 100644
index 0000000..fa9e544
--- /dev/null
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -0,0 +1,953 @@
+//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Shard communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "shard-to-mpi"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace shard;
+
+namespace {
+/// Converts a mixed list of static integers and dynamic Values into a vector
+/// of Values of the provided type.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type type = Type()) {
+ SmallVector<Value> values;
+ auto dyn = dynamics.begin();
+ Type i64 = b.getI64Type();
+ if (!type)
+ type = i64;
+  assert((i64 == type || b.getIndexType() == type) &&
+         "expected an i64 or an index type");
+ for (auto s : statics) {
+ if (s == ShapedType::kDynamic) {
+ values.emplace_back(*(dyn++));
+ } else {
+ TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+ values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
+ }
+ }
+ return values;
+}
+
+/// Create operations converting a linear index to a multi-dimensional index.
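+/// For example, for dimensions = [2, 3, 4] (row-major / C order) a linear
+/// index of 23 decomposes into the multi-index [1, 2, 3], since
+/// 23 = 1 * 12 + 2 * 4 + 3.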
+static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
+ Value linearIndex,
+ ValueRange dimensions) {
+ int n = dimensions.size();
+ SmallVector<Value> multiIndex(n);
+
+ for (int i = n - 1; i >= 0; --i) {
+ multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
+ if (i > 0)
+ linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
+ }
+
+ return multiIndex;
+}
+
+/// Create operations converting a multi-dimensional index to a linear index.
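+/// This is the inverse of linearToMultiIndex: for dimensions = [2, 3, 4] the
+/// multi-index [1, 2, 3] maps back to 1 * 12 + 2 * 4 + 3 = 23.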
+Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
+ ValueRange dimensions) {
+
+ Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
+ Value stride = arith::ConstantIndexOp::create(b, loc, 1);
+
+ for (int i = multiIndex.size() - 1; i >= 0; --i) {
+ Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
+ linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
+ stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
+ }
+
+ return linearIndex;
+}
+
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp)
+ return failure();
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp)
+ return failure();
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return success();
+ }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by the type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes)
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+
+    // To hold the split axes, create an empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
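+    // For example, split_axes = [[0], [], [1, 2]] becomes the 3x2 i16 tensor
+    //   [[0, -1], [-1, -1], [1, 2]].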
+ Location loc = op.getLoc();
+ auto i16 = rewriter.getI16Type();
+ auto i64 = rewriter.getI64Type();
+ std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
+ maxNAxes};
+ Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
+ auto attr = IntegerAttr::get(i16, -1);
+ Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
+ resSplitAxes =
+ linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
+ .getResult(0);
+
+ // explicitly write values into tensor row by row
+ std::array<int64_t, 2> strides = {1, 1};
+ int64_t nSplits = 0;
+ ValueRange empty = {};
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t size = axes.size();
+ if (size > 0)
+ ++nSplits;
+ std::array<int64_t, 2> offs = {(int64_t)i, 0};
+ std::array<int64_t, 2> sizes = {1, size};
+ auto tensorType = RankedTensorType::get({size}, i16);
+ auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+ auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
+ resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resSplitAxes, empty, empty,
+ empty, offs, sizes, strides);
+ }
+
+    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
+ // Store the halo sizes in the tensor.
+ SmallVector<Value> haloSizes =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
+ auto type = RankedTensorType::get({nSplits, 2}, i64);
+ Value resHaloSizes =
+ haloSizes.empty()
+ ? tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64)
+ .getResult()
+ : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
+ .getResult();
+
+    // To hold the sharded dims offsets, create a tensor with shape
+    // {nSplits, maxSplitSize + 1}. Store the offsets in the tensor and pad
+    // trailing elements of smaller split-groups with ShapedType::kDynamic.
+    // Computing the max size of the split groups requires
+    // collectiveProcessGroupSize (which in turn requires the GridOp).
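+    // For example, a split-group of 2 processes sharding a dimension of size
+    // 10 as [4, 6] contributes a row of the form [0, 4, 10, <pad>...]
+    // (offsets plus total size), padded to maxSplitSize + 1 elements.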
+ Value resOffsets;
+ if (adaptor.getStaticShardedDimsOffsets().empty()) {
+ resOffsets = tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64);
+ } else {
+ SymbolTableCollection symbolTableCollection;
+ auto gridOp = getGrid(op, symbolTableCollection);
+ int64_t maxSplitSize = 0;
+ for (auto axes : splitAxes) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
+ assert(splitSize != ShapedType::kDynamic);
+ maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+ }
+ assert(maxSplitSize);
+ ++maxSplitSize; // add one for the total size
+
+ resOffsets = tensor::EmptyOp::create(
+ rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value fill = arith::ConstantOp::create(
+          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          linalg::FillOp::create(rewriter, loc, fill, resOffsets).getResult(0);
+ SmallVector<Value> offsets =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
+ int64_t curr = 0;
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
+ assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+ ++splitSize; // add one for the total size
+ ArrayRef<Value> values(&offsets[curr], splitSize);
+ Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
+ std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
+ std::array<int64_t, 2> sizes = {1, splitSize};
+ resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resOffsets, empty, empty,
+ empty, offs, sizes, strides);
+ curr += splitSize;
+ }
+ }
+
+    // return a tuple of tensors as defined by the type converter
+ SmallVector<Type> resTypes;
+ if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+ resTypes)))
+ return failure();
+
+ resSplitAxes =
+ tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
+ resHaloSizes =
+ tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
+ resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
+
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TupleType::get(op.getContext(), resTypes),
+ ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+ return success();
+ }
+};
+
+struct ConvertProcessMultiIndexOp
+ : public OpConversionPattern<ProcessMultiIndexOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Currently converts its linear index to a multi-dimensional index.
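+    // For example, on a grid of shape 2x4 the process with linear index 5
+    // gets the multi-index [1, 1].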
+
+ SymbolTableCollection symbolTableCollection;
+ Location loc = op.getLoc();
+ auto gridOp = getGrid(op, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
+ return failure();
+
+ SmallVector<Value> dims;
+ llvm::transform(
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
+ });
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
+ auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
+
+ // optionally extract subset of grid axes
+ auto axes = adaptor.getAxes();
+ if (!axes.empty()) {
+ SmallVector<Value> subIndex;
+ for (auto axis : axes) {
+ subIndex.emplace_back(mIdx[axis]);
+ }
+ mIdx = std::move(subIndex);
+ }
+
+ rewriter.replaceOp(op, mIdx);
+ return success();
+ }
+};
+
+class ConvertProcessLinearIndexOp
+ : public OpConversionPattern<ProcessLinearIndexOp> {
+
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Create mpi::CommRankOp
+ Location loc = op.getLoc();
+ auto ctx = op.getContext();
+ Value commWorld =
+ mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
+ auto rank = mpi::CommRankOp::create(
+ rewriter, loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
+ rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
+ rank);
+ return success();
+ }
+};
+
+struct ConvertNeighborsLinearIndicesOp
+ : public OpConversionPattern<NeighborsLinearIndicesOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+    // Computes the neighbor indices along a split axis by adding/subtracting
+    // 1 to/from the current index in that dimension.
+    // Assigns -1 if the neighbor is out of bounds.
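+    // For example, on a 2x3 grid the device [1, 1] split along axis 1 has
+    // the down neighbor [1, 0] (linear index 3) and the up neighbor [1, 2]
+    // (linear index 5).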
+
+ auto axes = adaptor.getSplitAxes();
+ // For now only single axis sharding is supported
+ if (axes.size() != 1)
+ return failure();
+
+ Location loc = op.getLoc();
+ SymbolTableCollection symbolTableCollection;
+ auto gridOp = getGrid(op, symbolTableCollection);
+ auto mIdx = adaptor.getDevice();
+ auto orgIdx = mIdx[axes[0]];
+ SmallVector<Value> dims;
+ llvm::transform(
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
+ });
+ Value dimSz = dims[axes[0]];
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
+ Value atBorder =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+ auto down = scf::IfOp::create(
+ rewriter, loc, atBorder,
+ [&](OpBuilder &builder, Location loc) {
+ scf::YieldOp::create(builder, loc, minus1);
+ },
+ [&](OpBuilder &builder, Location loc) {
+ SmallVector<Value> tmp = mIdx;
+ tmp[axes[0]] =
+ arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
+ .getResult();
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
+ });
+ atBorder = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
+ arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
+ auto up = scf::IfOp::create(
+ rewriter, loc, atBorder,
+ [&](OpBuilder &builder, Location loc) {
+ scf::YieldOp::create(builder, loc, minus1);
+ },
+ [&](OpBuilder &builder, Location loc) {
+ SmallVector<Value> tmp = mIdx;
+ tmp[axes[0]] =
+ arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
+ });
+ rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
+ return success();
+ }
+};
+
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+ if (!sharding) {
+ return op->emitError()
+ << "Expected ShardingOp as defining op for sharding"
+ << " but found " << adaptor.getSharding()[0].getDefiningOp();
+ }
+
+ // Compute the sharded shape by applying the sharding to the input shape.
+ // If shardedDimsOffsets is not defined in the sharding, the shard shape is
+ // computed by dividing the dimension size by the number of shards in that
+ // dimension (which is given by the size of the grid axes provided in
+ // split-axes). Odd elements get distributed to trailing shards. If a
+ // shardedDimsOffsets is provided, the shard shape is computed by
+ // subtracting the offset of the current shard from the offset of the next
+ // shard.
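+    // For example, a dimension of size 10 split across 4 shards (without
+    // shardedDimsOffsets) yields the shard sizes [2, 2, 3, 3]: 10 / 4 = 2,
+    // and the 10 % 4 = 2 leftover elements go to the trailing shards.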
+
+ Location loc = op.getLoc();
+ Type index = rewriter.getIndexType();
+
+ // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and device
+ // we have a 1:1 conversion.
+ // For simpler access fill a vector with the dynamic dims.
+ SmallVector<Value> dynDims, dynDevice;
+ for (auto dim : adaptor.getDimsDynamic()) {
+ // type conversion should be 1:1 for ints
+ dynDims.emplace_back(llvm::getSingleElement(dim));
+ }
+ // same for device
+ for (auto device : adaptor.getDeviceDynamic()) {
+ dynDevice.emplace_back(llvm::getSingleElement(device));
+ }
+
+ // To keep the code simple, convert dims/device to values when they are
+ // attributes. Count on canonicalization to fold static values.
+ SmallVector<Value> shape =
+ getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+ SmallVector<Value> multiIdx =
+ getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+ // Get the GridOp, the grid shape is needed to compute the sharded shape.
+ SymbolTableCollection symbolTableCollection;
+ auto gridOp = getGrid(sharding, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
+ return failure();
+
+ auto splitAxes = sharding.getSplitAxes().getAxes();
+ // shardedDimsOffsets are optional and might be Values (not attributes).
+ // Also, the shardId might be dynamic which means the position in the
+ // shardedDimsOffsets is not statically known. Create a tensor of the
+ // shardedDimsOffsets and later extract the offsets for computing the
+ // local shard-size.
+ Value shardedDimsOffs;
+ {
+ SmallVector<Value> tmp = getMixedAsValues(
+ rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+ sharding.getDynamicShardedDimsOffsets(), index);
+ if (!tmp.empty())
+ shardedDimsOffs = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
+ tmp);
+ }
+
+ // With static grid shape the sizes of the split axes are known.
+    // Hence the start position for each split axis in shardedDimsOffsets can
+    // be computed statically.
+ int64_t pos = 0;
+ SmallVector<Value> shardShape;
+ Value zero =
+ arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
+ Value one =
+ arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
+
+ // Iterate over the dimensions of the tensor shape, get their split Axes,
+ // and compute the sharded shape.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ // Trailing dimensions might not be annotated.
+ if (i < splitAxes.size() && !splitAxes[i].empty()) {
+ auto axes = splitAxes[i];
+        // Create a value from the static position in shardedDimsOffsets.
+ Value posVal = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIndexAttr(pos));
+ // Get the index of the local shard in the grid axis.
+ Value idx = multiIdx[axes[0]];
+ auto numShards =
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
+ if (shardedDimsOffs) {
+ // If sharded dims offsets are provided, use them to compute the
+ // sharded shape.
+ if (axes.size() > 1) {
+ return op->emitError() << "Only single axis sharding is "
+ << "supported for each dimension.";
+ }
+ idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
+ // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+ Value off =
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ idx = arith::AddIOp::create(rewriter, loc, idx, one);
+ Value nextOff =
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
+ shardShape.emplace_back(sz);
+ } else {
+ Value numShardsVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(numShards));
+ // Compute shard dim size by distributing odd elements to trailing
+ // shards:
+ // sz = dim / numShards
+ // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
+ Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
+ Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
+ sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
+ auto cond = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
+ Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
+ sz = arith::AddIOp::create(rewriter, loc, sz, odd);
+ shardShape.emplace_back(sz);
+ }
+ pos += numShards + 1; // add one for the total size.
+ } // else no sharding if split axis is empty or no split axis
+ // If no size was added -> no sharding in this dimension.
+ if (shardShape.size() <= i)
+ shardShape.emplace_back(dim);
+ }
+ assert(shardShape.size() == shape.size());
+ rewriter.replaceOp(op, shardShape);
+ return success();
+ }
+};
+
+static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
+ auto ctx = kind.getContext();
+ auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
+ return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
+ };
+
+ switch (kind.getValue()) {
+ case ReductionKind::Sum:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
+ case ReductionKind::Product:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
+ case ReductionKind::Min:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
+ case ReductionKind::Max:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
+ case ReductionKind::BitwiseAnd:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
+ case ReductionKind::BitwiseOr:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
+ case ReductionKind::BitwiseXor:
+ return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
+ default:
+ llvm_unreachable("Unknown/unsupported reduction kind");
+ }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ auto grid = adaptor.getGrid();
+ mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "No grid found for AllReduceOp";
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
+ return op->emitError()
+ << "Dynamic grid shape not supported in AllReduceOp";
+
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = adaptor.getInput();
+ auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+    // If the input is a tensor, convert it to a memref buffer.
+ if (isa<RankedTensorType>(input.getType())) {
+ auto memrefType = MemRefType::get(
+ inputShape, cast<ShapedType>(input.getType()).getElementType());
+ input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+ }
+ MemRefType inType = cast<MemRefType>(input.getType());
+
+ // Get the actual shape to allocate the buffer.
+ SmallVector<OpFoldResult> shape(inType.getRank());
+ for (auto i = 0; i < inType.getRank(); ++i) {
+ auto s = inputShape[i];
+ if (ShapedType::isDynamic(s))
+        shape[i] = memref::DimOp::create(iBuilder, input, i).getResult();
+ else
+ shape[i] = iBuilder.getIndexAttr(s);
+ }
+
+ // Allocate buffer and copy input to buffer.
+ Value buffer = memref::AllocOp::create(
+ iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
+ linalg::CopyOp::create(iBuilder, input, buffer);
+
+ // Get an MPI_Comm_split for the AllReduce operation.
+ // The color is the linear index of the process in the grid along the
+ // non-reduced axes. The key is the linear index of the process in the grid
+ // along the reduced axes.
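+    // For example, on a 2x3 grid reduced over grid axis 1, the process with
+    // multi-index [1, 2] computes its color from [1, 0] (non-reduced
+    // coordinates) and its key from [0, 2] (reduced coordinates).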
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
+ iBuilder.getIndexType());
+ SmallVector<Value> myMultiIndex =
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
+ .getResult();
+ Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
+ SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+ auto redAxes = adaptor.getGridAxes();
+ for (auto axis : redAxes) {
+ multiKey[axis] = myMultiIndex[axis];
+ myMultiIndex[axis] = zero;
+ }
+
+ Value color =
+ createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
+ color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
+ Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
+ key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
+
+ // Finally split the communicator
+ auto commType = mpi::CommType::get(op->getContext());
+ Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
+ auto comm =
+ mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
+ .getNewcomm();
+
+ Value buffer1d = buffer;
+ // Collapse shape to 1d if needed
+ if (inType.getRank() > 1) {
+ ReassociationIndices reassociation(inType.getRank());
+ std::iota(reassociation.begin(), reassociation.end(), 0);
+ buffer1d = memref::CollapseShapeOp::create(
+ iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
+ }
+
+ // Create the MPI AllReduce operation.
+ mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+ getMPIReductionOp(adaptor.getReductionAttr()),
+ comm);
+
+    // If the result type is a tensor, convert the buffer back to a tensor.
+ if (isa<RankedTensorType>(op.getType()))
+ buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
+ true);
+
+ rewriter.replaceOp(op, buffer);
+ return success();
+ }
+};
+
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // The input/output memref is assumed to be in C memory order.
+ // Halos are exchanged as 2 blocks per dimension (one for each side: down
+ // and up). For each haloed dimension `d`, the exchanged blocks are
+ // expressed as multi-dimensional subviews. The subviews include potential
+ // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+ // `dl < d` and for dimension `d` the currently exchanged halo only.
+    // By iterating from higher to lower dimensions this also updates the halos
+    // in the 'corners'.
+ // memref.subview is used to read and write the halo data from and to the
+ // local data. Because subviews and halos can have mixed dynamic and static
+ // shapes, OpFoldResults are used whenever possible.
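+    // For example, a 1-d buffer of size 10 with halo sizes [2, 3] keeps its
+    // owned elements at indices [2, 7): updating the upper halo receives 3
+    // elements at offset 7 from the up neighbor, while the 3 lowest owned
+    // elements (offset 2) are sent to the down neighbor to fill its upper
+    // halo.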
+
+ auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
+ adaptor.getHaloSizes(), rewriter);
+ if (haloSizes.empty()) {
+ // no halos -> nothing to do
+ rewriter.replaceOp(op, adaptor.getDestination());
+ return success();
+ }
+
+ SymbolTableCollection symbolTableCollection;
+ Location loc = op.getLoc();
+
+ // convert a OpFoldResult into a Value
+ auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
+ if (auto value = dyn_cast<Value>(v))
+ return value;
+ return arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ };
+
+ auto dest = adaptor.getDestination();
+ auto dstShape = cast<ShapedType>(dest.getType()).getShape();
+ Value array = dest;
+ if (isa<RankedTensorType>(array.getType())) {
+      // If the destination is a tensor, convert it to a memref buffer.
+      auto memrefType = MemRefType::get(
+          dstShape, cast<ShapedType>(array.getType()).getElementType());
+      array =
+          bufferization::ToBufferOp::create(rewriter, loc, memrefType, array);
+ }
+ auto rank = cast<ShapedType>(array.getType()).getRank();
+ auto opSplitAxes = adaptor.getSplitAxes().getAxes();
+ auto grid = adaptor.getGrid();
+ auto gridOp = getGrid(op, symbolTableCollection);
+ // subviews need Index values
+ for (auto &sz : haloSizes) {
+ if (auto value = dyn_cast<Value>(sz))
+ sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+ value)
+ .getResult();
+ }
+
+ // most of the offset/size/stride data is the same for all dims
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+ auto currHaloDim = -1; // halo sizes are provided for split dimensions only
+ // we need the actual shape to compute offsets and sizes
+ for (auto i = 0; i < rank; ++i) {
+ auto s = dstShape[i];
+ if (ShapedType::isDynamic(s))
+        shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
+ else
+ shape[i] = rewriter.getIndexAttr(s);
+
+ if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+ ++currHaloDim;
+      // the offsets for lower dims start after their down halo
+ offsets[i] = haloSizes[currHaloDim * 2];
+
+ // prepare shape and offsets of highest dim's halo exchange
+ Value _haloSz = arith::AddIOp::create(
+ rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
+      // the halo shape of lower dims excludes the halos
+ dimSizes[i] =
+ arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
+ .getResult();
+ } else {
+ dimSizes[i] = shape[i];
+ }
+ }
+
+ auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+ auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
+ auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+ auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
+ rewriter.getIndexType());
+ auto myMultiIndex =
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
+ .getResult();
+ // traverse all split axes from high to low dim
+ for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
+ auto splitAxes = opSplitAxes[dim];
+ if (splitAxes.empty())
+ continue;
+ assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
+ // Get the linearized ids of the neighbors (down and up) for the
+ // given split
+ auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+ myMultiIndex, splitAxes)
+ .getResults();
+ // MPI operates on i32...
+ Value neighbourIDs[2] = {
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[0]),
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[1])};
+
+ auto lowerRecvOffset = rewriter.getIndexAttr(0);
+ auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+ auto upperRecvOffset =
+ arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
+ auto upperSendOffset = arith::SubIOp::create(
+ rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
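+      // With the 1-d example above (size 10, halos [2, 3]) these evaluate to
+      // lowerRecvOffset = 0, lowerSendOffset = 2, upperRecvOffset = 7 and
+      // upperSendOffset = 5.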
+
+ Value commWorld = mpi::CommWorldOp::create(
+ rewriter, loc, mpi::CommType::get(op->getContext()));
+
+      // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is by far not optimal; it should at least use a
+      // red-black pattern or MPI_Sendrecv.
+      // Also, buffers should be re-used.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+ auto genSendRecv = [&](bool upperHalo) {
+ auto orgOffset = offsets[dim];
+ dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+ : haloSizes[currHaloDim * 2];
+ // Check if we need to send and/or receive
+ // Processes on the grid borders have only one neighbor
+ auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+ auto hasFrom = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::sge, to, zero);
+ auto buffer = memref::AllocOp::create(
+ rewriter, loc, dimSizes,
+ cast<ShapedType>(array.getType()).getElementType());
+ // if has neighbor: copy halo data from array to buffer and send
+ scf::IfOp::create(
+ rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
+ : OpFoldResult(upperSendOffset);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, subview, buffer);
+ mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
+ commWorld);
+ scf::YieldOp::create(builder, loc);
+ });
+ // if has neighbor: receive halo data into buffer and copy to array
+ scf::IfOp::create(
+ rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
+ : OpFoldResult(lowerRecvOffset);
+ mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
+ commWorld);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, buffer, subview);
+ scf::YieldOp::create(builder, loc);
+ });
+ memref::DeallocOp::create(rewriter, loc, buffer);
+ offsets[dim] = orgOffset;
+ };
+
+ auto doSendRecv = [&](int upOrDown) {
+ OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
+        Value haloSz = dyn_cast<Value>(v);
+        // Dynamic halo sizes are index-typed here (see the cast above); cast
+        // them to i32 so the comparison against the i32 zero is well-typed.
+        if (haloSz)
+          haloSz = arith::IndexCastOp::create(rewriter, loc,
+                                              rewriter.getI32Type(), haloSz);
+        else
+          haloSz = arith::ConstantOp::create(
+              rewriter, loc,
+              rewriter.getI32IntegerAttr(
+                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ auto hasSize = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ scf::IfOp::create(rewriter, loc, hasSize,
+ [&](OpBuilder &builder, Location loc) {
+ genSendRecv(upOrDown > 0);
+ scf::YieldOp::create(builder, loc);
+ });
+ };
+
+ doSendRecv(0);
+ doSendRecv(1);
+
+      // the shape for lower dims includes higher dims' halos
+ dimSizes[dim] = shape[dim];
+ // -> the offset for higher dims is always 0
+ offsets[dim] = rewriter.getIndexAttr(0);
+ // on to next halo
+ --currHaloDim;
+ }
+
+ if (isa<MemRefType>(op.getResult().getType())) {
+ rewriter.replaceOp(op, array);
+ } else {
+ assert(isa<RankedTensorType>(op.getResult().getType()));
+ rewriter.replaceOp(op, bufferization::ToTensorOp::create(
+ rewriter, loc, op.getResult().getType(), array,
+ /*restrict=*/true, /*writable=*/true));
+ }
+ return success();
+ }
+};
+
+struct ConvertShardToMPIPass
+ : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
+ using Base::Base;
+
+ /// Run the dialect converter on the module.
+ void runOnOperation() override {
+ auto *ctxt = &getContext();
+ RewritePatternSet patterns(ctxt);
+ ConversionTarget target(getContext());
+
+ // Define a type converter to convert shard::ShardingType,
+ // mostly for use in return operations.
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) { return type; });
+
+ // convert shard::ShardingType to a tuple of RankedTensorTypes
+ typeConverter.addConversion(
+ [](ShardingType type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp = {ShapedType::kDynamic,
+ ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return success();
+ });
+
+    // To 'extract' components, an UnrealizedConversionCastOp is expected
+    // to define the input.
+ typeConverter.addTargetMaterialization(
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc) {
+ // Expecting a single input.
+ if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
+ return SmallVector<Value>();
+ auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+ // Expecting an UnrealizedConversionCastOp.
+ if (!castOp)
+ return SmallVector<Value>();
+ // Fill a vector with elements of the tuple/castOp.
+ SmallVector<Value> results;
+ for (auto oprnd : castOp.getInputs()) {
+ if (!isa<RankedTensorType>(oprnd.getType()))
+ return SmallVector<Value>();
+ results.emplace_back(oprnd);
+ }
+ return results;
+ });
+
+    // No shard dialect ops should be left after conversion...
+ target.addIllegalDialect<shard::ShardDialect>();
+    // ...except the global GridOp and GridShapeOp, which gets folded later.
+ target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
+ // Allow all the stuff that our patterns will convert to
+ target.addLegalDialect<
+ BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
+ tensor::TensorDialect, bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
+ // Make sure the function signature, calls etc. are legal
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
+
+ patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+ ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+ ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
+ ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+
+ // Folding patterns cannot be mixed with conversion patterns -> extra pass.
+ patterns.clear();
+ SymbolTableCollection symbolTableCollection;
+ mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
+} // namespace