Diffstat (limited to 'mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp')
-rw-r--r-- | mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 950 |
1 file changed, 0 insertions, 950 deletions
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
deleted file mode 100644
index b931284..0000000
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ /dev/null
@@ -1,950 +0,0 @@
-//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a translation of Mesh communication ops to MPI ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MPI/IR/MPI.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "mesh-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mesh;
-
-namespace {
-/// Converts a mix of static (int) and dynamic values into a vector of Values
-/// of the provided type.
-static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
-                                           llvm::ArrayRef<int64_t> statics,
-                                           ValueRange dynamics,
-                                           Type type = Type()) {
-  SmallVector<Value> values;
-  auto dyn = dynamics.begin();
-  Type i64 = b.getI64Type();
-  if (!type)
-    type = i64;
-  assert((i64 == type || b.getIndexType() == type) &&
-         "expected an i64 or an index type");
-  for (auto s : statics) {
-    if (s == ShapedType::kDynamic) {
-      values.emplace_back(*(dyn++));
-    } else {
-      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
-      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
-    }
-  }
-  return values;
-}
-
-/// Create operations converting a linear index to a multi-dimensional index.
-static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
-                                             Value linearIndex,
-                                             ValueRange dimensions) {
-  int n = dimensions.size();
-  SmallVector<Value> multiIndex(n);
-
-  for (int i = n - 1; i >= 0; --i) {
-    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0)
-      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-  }
-
-  return multiIndex;
-}
-
-/// Create operations converting a multi-dimensional index to a linear index.
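-/// E.g. for dimensions (2, 3), the multi-index (1, 2) maps to the linear
-/// index 1 * 3 + 2 = 5 (row-major order); linearToMultiIndex above inverts
-/// this mapping.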
-Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
-                         ValueRange dimensions) {
-
-  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
-
-  for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
-    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
-    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
-  }
-
-  return linearIndex;
-}
-
-/// Replace GetShardingOp with the related/dependent ShardingOp.
-struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
-    if (!shardOp)
-      return failure();
-    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
-    if (!shardingOp)
-      return failure();
-
-    rewriter.replaceOp(op, shardingOp.getResult());
-    return success();
-  }
-};
-
-/// Convert a sharding op to a tuple of tensors of its components
-/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
-/// as defined by the type converter.
-struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto splitAxes = op.getSplitAxes().getAxes();
-    int64_t maxNAxes = 0;
-    for (auto axes : splitAxes)
-      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
-
-    // To hold the split axes, create an empty 2d tensor with shape
-    // {splitAxes.size(), max-size-of-split-groups}.
-    // Set trailing elements for smaller split-groups to -1.
-    Location loc = op.getLoc();
-    auto i16 = rewriter.getI16Type();
-    auto i64 = rewriter.getI64Type();
-    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
-                                    maxNAxes};
-    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
-    auto attr = IntegerAttr::get(i16, -1);
-    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
-    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
-                       .getResult(0);
-
-    // Explicitly write values into the tensor row by row.
-    std::array<int64_t, 2> strides = {1, 1};
-    int64_t nSplits = 0;
-    ValueRange empty = {};
-    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
-      int64_t size = axes.size();
-      if (size > 0)
-        ++nSplits;
-      std::array<int64_t, 2> offs = {(int64_t)i, 0};
-      std::array<int64_t, 2> sizes = {1, size};
-      auto tensorType = RankedTensorType::get({size}, i16);
-      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
-      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
-      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
-          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
-    }
-
-    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
-    // Store the halo sizes in the tensor.
-    SmallVector<Value> haloSizes =
-        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
-                         adaptor.getDynamicHaloSizes());
-    auto type = RankedTensorType::get({nSplits, 2}, i64);
-    Value resHaloSizes =
-        haloSizes.empty()
-            ? rewriter
-                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
-                                           i64)
-                  .getResult()
-            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
-                  .getResult();
-
-    // To hold sharded dims offsets, create a tensor with shape {nSplits,
-    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
-    // elements for smaller split-groups to a sentinel (kDynamic). Computing
-    // the max size of the split groups requires collectiveProcessGroupSize
-    // (which needs the MeshOp).
-    Value resOffsets;
-    if (adaptor.getStaticShardedDimsOffsets().empty()) {
-      resOffsets = rewriter.create<tensor::EmptyOp>(
-          loc, std::array<int64_t, 2>{0, 0}, i64);
-    } else {
-      SymbolTableCollection symbolTableCollection;
-      auto meshOp = getMesh(op, symbolTableCollection);
-      int64_t maxSplitSize = 0;
-      for (auto axes : splitAxes) {
-        int64_t splitSize =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
-        assert(splitSize != ShapedType::kDynamic);
-        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
-      }
-      assert(maxSplitSize);
-      ++maxSplitSize; // add one for the total size
-
-      resOffsets = rewriter.create<tensor::EmptyOp>(
-          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
-      Value fillVal = rewriter.create<arith::ConstantOp>(
-          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
-      resOffsets =
-          rewriter.create<linalg::FillOp>(loc, fillVal, resOffsets).getResult(0);
-      SmallVector<Value> offsets =
-          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
-                           adaptor.getDynamicShardedDimsOffsets());
-      int64_t curr = 0;
-      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
-        int64_t splitSize =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
-        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
-        ++splitSize; // add one for the total size
-        ArrayRef<Value> values(&offsets[curr], splitSize);
-        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
-        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
-        std::array<int64_t, 2> sizes = {1, splitSize};
-        resOffsets = rewriter.create<tensor::InsertSliceOp>(
-            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
-        curr += splitSize;
-      }
-    }
-
-    // Return a tuple of tensors as defined by the type converter.
-    SmallVector<Type> resTypes;
-    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
-                                               resTypes)))
-      return failure();
-
-    resSplitAxes =
-        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
-    resHaloSizes =
-        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
-    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
-
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TupleType::get(op.getContext(), resTypes),
-        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
-
-    return success();
-  }
-};
-
-struct ConvertProcessMultiIndexOp
-    : public OpConversionPattern<ProcessMultiIndexOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // Currently converts its linear index to a multi-dimensional index.
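-    // E.g. on a 2x3 mesh, the MPI rank 5 becomes the multi-index (1, 2).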
-
-    SymbolTableCollection symbolTableCollection;
-    Location loc = op.getLoc();
-    auto meshOp = getMesh(op, symbolTableCollection);
-    // For now we only support static mesh shapes.
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
-      return failure();
-
-    SmallVector<Value> dims;
-    llvm::transform(
-        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
-          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
-        });
-    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
-    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
-
-    // Optionally extract a subset of mesh axes.
-    auto axes = adaptor.getAxes();
-    if (!axes.empty()) {
-      SmallVector<Value> subIndex;
-      for (auto axis : axes) {
-        subIndex.emplace_back(mIdx[axis]);
-      }
-      mIdx = std::move(subIndex);
-    }
-
-    rewriter.replaceOp(op, mIdx);
-    return success();
-  }
-};
-
-class ConvertProcessLinearIndexOp
-    : public OpConversionPattern<ProcessLinearIndexOp> {
-
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Create mpi::CommRankOp.
-    Location loc = op.getLoc();
-    auto ctx = op.getContext();
-    Value commWorld =
-        rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                loc,
-                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
-                commWorld)
-            .getRank();
-    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
-                                                    rank);
-    return success();
-  }
-};
-
-struct ConvertNeighborsLinearIndicesOp
-    : public OpConversionPattern<NeighborsLinearIndicesOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // Computes the neighbors' indices along a split axis by simply
-    // adding/subtracting 1 to/from the current index in that dimension.
-    // Assigns -1 if the neighbor is out of bounds.
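-    // E.g. on a 2x4 mesh, the process at (1, 0) split along axis 1 has the
-    // down neighbor -1 (border) and the up neighbor (1, 1), linear index 5.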
-
-    auto axes = adaptor.getSplitAxes();
-    // For now only single-axis sharding is supported.
-    if (axes.size() != 1)
-      return failure();
-
-    Location loc = op.getLoc();
-    SymbolTableCollection symbolTableCollection;
-    auto meshOp = getMesh(op, symbolTableCollection);
-    auto mIdx = adaptor.getDevice();
-    auto orgIdx = mIdx[axes[0]];
-    SmallVector<Value> dims;
-    llvm::transform(
-        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
-          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
-        });
-    Value dimSz = dims[axes[0]];
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
-    Value atBorder = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, orgIdx,
-        rewriter.create<arith::ConstantIndexOp>(loc, 0));
-    auto down = rewriter.create<scf::IfOp>(
-        loc, atBorder,
-        [&](OpBuilder &builder, Location loc) {
-          builder.create<scf::YieldOp>(loc, minus1);
-        },
-        [&](OpBuilder &builder, Location loc) {
-          SmallVector<Value> tmp = mIdx;
-          tmp[axes[0]] =
-              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
-          builder.create<scf::YieldOp>(
-              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
-        });
-    atBorder = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, orgIdx,
-        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
-    auto up = rewriter.create<scf::IfOp>(
-        loc, atBorder,
-        [&](OpBuilder &builder, Location loc) {
-          builder.create<scf::YieldOp>(loc, minus1);
-        },
-        [&](OpBuilder &builder, Location loc) {
-          SmallVector<Value> tmp = mIdx;
-          tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
-          builder.create<scf::YieldOp>(
-              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
-        });
-    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return success();
-  }
-};
-
-struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
-    if (!sharding) {
-      return op->emitError()
-             << "Expected ShardingOp as defining op for sharding"
-             << " but found " << adaptor.getSharding()[0].getDefiningOp();
-    }
-
-    // Compute the sharded shape by applying the sharding to the input shape.
-    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
-    // computed by dividing the dimension size by the number of shards in that
-    // dimension (which is given by the size of the mesh axes provided in
-    // split-axes). Odd elements get distributed to trailing shards. If a
-    // shardedDimsOffsets is provided, the shard shape is computed by
-    // subtracting the offset of the current shard from the offset of the next
-    // shard.
-
-    Location loc = op.getLoc();
-    Type index = rewriter.getIndexType();
-
-    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
-    // The operands in the adaptor are a vector<ValueRange>. For dims and
-    // device we have a 1:1 conversion.
-    // For simpler access, fill a vector with the dynamic dims.
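-    // Each converted operand arrives as a ValueRange; for the 1:1-converted
-    // dims/device operands that range holds exactly one value.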
-    SmallVector<Value> dynDims, dynDevice;
-    for (auto dim : adaptor.getDimsDynamic()) {
-      // Type conversion should be 1:1 for ints.
-      dynDims.emplace_back(llvm::getSingleElement(dim));
-    }
-    // Same for device.
-    for (auto device : adaptor.getDeviceDynamic()) {
-      dynDevice.emplace_back(llvm::getSingleElement(device));
-    }
-
-    // To keep the code simple, convert dims/device to values when they are
-    // attributes. Count on canonicalization to fold static values.
-    SmallVector<Value> shape =
-        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
-    SmallVector<Value> multiIdx =
-        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
-
-    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
-    SymbolTableCollection symbolTableCollection;
-    auto meshOp = getMesh(sharding, symbolTableCollection);
-    // For now we only support static mesh shapes.
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
-      return failure();
-
-    auto splitAxes = sharding.getSplitAxes().getAxes();
-    // shardedDimsOffsets are optional and might be Values (not attributes).
-    // Also, the shardId might be dynamic, which means the position in the
-    // shardedDimsOffsets is not statically known. Create a tensor of the
-    // shardedDimsOffsets and later extract the offsets for computing the
-    // local shard size.
-    Value shardedDimsOffs;
-    {
-      SmallVector<Value> tmp = getMixedAsValues(
-          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
-          sharding.getDynamicShardedDimsOffsets(), index);
-      if (!tmp.empty())
-        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
-            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
-    }
-
-    // With a static mesh shape the sizes of the split axes are known.
-    // Hence the start/pos for each split axis in shardDimsOffsets can be
-    // computed statically.
-    int64_t pos = 0;
-    SmallVector<Value> shardShape;
-    Value zero =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
-    Value one =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
-
-    // Iterate over the dimensions of the tensor shape, get their split axes,
-    // and compute the sharded shape.
-    for (auto [i, dim] : llvm::enumerate(shape)) {
-      // Trailing dimensions might not be annotated.
-      if (i < splitAxes.size() && !splitAxes[i].empty()) {
-        auto axes = splitAxes[i];
-        // The current dimension might not be sharded.
-        // Create a value from the static position in shardDimsOffsets.
-        Value posVal =
-            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
-        // Get the index of the local shard in the mesh axis.
-        Value idx = multiIdx[axes[0]];
-        auto numShards =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
-        if (shardedDimsOffs) {
-          // If sharded dims offsets are provided, use them to compute the
-          // sharded shape.
-          if (axes.size() > 1) {
-            return op->emitError() << "Only single-axis sharding is "
-                                   << "supported for each dimension.";
-          }
-          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
-          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
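-          // E.g. offsets [0, 3, 8] for a dimension split over 2 shards give
-          // the local sizes 3 - 0 = 3 and 8 - 3 = 5.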
-          Value off =
-              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
-          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
-          Value nextOff =
-              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
-          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
-          shardShape.emplace_back(sz);
-        } else {
-          Value numShardsVal = rewriter.create<arith::ConstantOp>(
-              loc, rewriter.getIndexAttr(numShards));
-          // Compute the shard dim size by distributing odd elements to
-          // trailing shards:
-          //   sz = dim / numShards
-          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
-          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
-          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
-          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
-          auto cond = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::sge, idx, sz1);
-          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
-          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
-          shardShape.emplace_back(sz);
-        }
-        pos += numShards + 1; // add one for the total size
-      } // else no sharding if split axis is empty or no split axis
-      // If no size was added -> no sharding in this dimension.
-      if (shardShape.size() <= i)
-        shardShape.emplace_back(dim);
-    }
-    assert(shardShape.size() == shape.size());
-    rewriter.replaceOp(op, shardShape);
-    return success();
-  }
-};
-
-static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
-  auto ctx = kind.getContext();
-  auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
-  };
-
-  switch (kind.getValue()) {
-  case ReductionKind::Sum:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
-  case ReductionKind::Product:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
-  case ReductionKind::Min:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
-  case ReductionKind::Max:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
-  case ReductionKind::BitwiseAnd:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
-  case ReductionKind::BitwiseOr:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
-  case ReductionKind::BitwiseXor:
-    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
-  default:
-    llvm_unreachable("Unknown/unsupported reduction kind");
-  }
-}
-
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SymbolTableCollection symbolTableCollection;
-    auto mesh = adaptor.getMesh();
-    mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
-    if (!meshOp)
-      return op->emitError() << "No mesh found for AllReduceOp";
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
-      return op->emitError()
-             << "Dynamic mesh shape not supported in AllReduceOp";
-
-    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
-    Value input = adaptor.getInput();
-    auto inputShape = cast<ShapedType>(input.getType()).getShape();
-
-    // If the input is a tensor, cast it to a (memref) buffer.
-    if (isa<RankedTensorType>(input.getType())) {
-      auto memrefType = MemRefType::get(
-          inputShape, cast<ShapedType>(input.getType()).getElementType());
-      input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
-    }
-    MemRefType inType = cast<MemRefType>(input.getType());
-
-    // Get the actual shape to allocate the buffer.
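-    // Static extents are used directly as attributes; dynamic extents are
-    // queried from the input buffer via memref.dim.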
-    SmallVector<OpFoldResult> shape(inType.getRank());
-    for (auto i = 0; i < inType.getRank(); ++i) {
-      auto s = inputShape[i];
-      if (ShapedType::isDynamic(s))
-        shape[i] = iBuilder.create<memref::DimOp>(input, i).getResult();
-      else
-        shape[i] = iBuilder.getIndexAttr(s);
-    }
-
-    // Allocate a buffer and copy the input into it.
-    Value buffer = iBuilder.create<memref::AllocOp>(
-        shape, cast<ShapedType>(op.getType()).getElementType());
-    iBuilder.create<linalg::CopyOp>(input, buffer);
-
-    // Get an MPI_Comm_split for the AllReduce operation.
-    // The color is the linear index of the process in the mesh along the
-    // non-reduced axes. The key is the linear index of the process in the mesh
-    // along the reduced axes.
-    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
-                                       iBuilder.getIndexType());
-    SmallVector<Value> myMultiIndex =
-        iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
-            .getResult();
-    Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
-    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
-
-    auto redAxes = adaptor.getMeshAxes();
-    for (auto axis : redAxes) {
-      multiKey[axis] = myMultiIndex[axis];
-      myMultiIndex[axis] = zero;
-    }
-
-    Value color =
-        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
-    color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
-    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
-    key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
-
-    // Finally split the communicator.
-    auto commType = mpi::CommType::get(op->getContext());
-    Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
-    auto comm =
-        iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
-            .getNewcomm();
-
-    Value buffer1d = buffer;
-    // Collapse the shape to 1d if needed.
-    if (inType.getRank() > 1) {
-      ReassociationIndices reassociation(inType.getRank());
-      std::iota(reassociation.begin(), reassociation.end(), 0);
-      buffer1d = iBuilder.create<memref::CollapseShapeOp>(
-          buffer, ArrayRef<ReassociationIndices>(reassociation));
-    }
-
-    // Create the MPI AllReduce operation.
-    iBuilder.create<mpi::AllReduceOp>(
-        TypeRange(), buffer1d, buffer1d,
-        getMPIReductionOp(adaptor.getReductionAttr()), comm);
-
-    // If the result is a tensor, cast the buffer back to a tensor.
-    if (isa<RankedTensorType>(op.getType()))
-      buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
-                                                          true);
-
-    rewriter.replaceOp(op, buffer);
-    return success();
-  }
-};
-
-struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // The input/output memref is assumed to be in C memory order.
-    // Halos are exchanged as 2 blocks per dimension (one for each side: down
-    // and up). For each haloed dimension `d`, the exchanged blocks are
-    // expressed as multi-dimensional subviews. The subviews include potential
-    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
-    // `dl < d`, and for dimension `d` the currently exchanged halo only.
-    // By iterating from higher to lower dimensions this also updates the
-    // halos in the 'corners'.
-    // memref.subview is used to read and write the halo data from and to the
-    // local data. Because subviews and halos can have mixed dynamic and static
-    // shapes, OpFoldResults are used whenever possible.
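-    // E.g. for halo sizes (2, 3) on a split dimension of extent 16, the
-    // regions [0, 2) and [13, 16) receive from the down and up neighbors,
-    // while the owned region [2, 13) provides the data that is sent.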
-
-    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
-                                    adaptor.getHaloSizes(), rewriter);
-    if (haloSizes.empty()) {
-      // No halos -> nothing to do.
-      rewriter.replaceOp(op, adaptor.getDestination());
-      return success();
-    }
-
-    SymbolTableCollection symbolTableCollection;
-    Location loc = op.getLoc();
-
-    // Convert an OpFoldResult into a Value.
-    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
-      if (auto value = dyn_cast<Value>(v))
-        return value;
-      return rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(
-                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
-    };
-
-    auto dest = adaptor.getDestination();
-    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
-    Value array = dest;
-    if (isa<RankedTensorType>(array.getType())) {
-      // If the destination is a tensor, we need to cast it to a memref.
-      auto memrefType = MemRefType::get(
-          dstShape, cast<ShapedType>(array.getType()).getElementType());
-      array =
-          rewriter.create<bufferization::ToBufferOp>(loc, memrefType, array);
-    }
-    auto rank = cast<ShapedType>(array.getType()).getRank();
-    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
-    auto mesh = adaptor.getMesh();
-    auto meshOp = getMesh(op, symbolTableCollection);
-    // Subviews need Index values.
-    for (auto &sz : haloSizes) {
-      if (auto value = dyn_cast<Value>(sz))
-        sz =
-            rewriter
-                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
-                .getResult();
-    }
-
-    // Most of the offset/size/stride data is the same for all dims.
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
-    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
-    // We need the actual shape to compute offsets and sizes.
-    for (auto i = 0; i < rank; ++i) {
-      auto s = dstShape[i];
-      if (ShapedType::isDynamic(s))
-        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
-      else
-        shape[i] = rewriter.getIndexAttr(s);
-
-      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
-        ++currHaloDim;
-        // The offsets for lower dims start after their down halo.
-        offsets[i] = haloSizes[currHaloDim * 2];
-
-        // Prepare shape and offsets for the highest dim's halo exchange.
-        Value _haloSz = rewriter.create<arith::AddIOp>(
-            loc, toValue(haloSizes[currHaloDim * 2]),
-            toValue(haloSizes[currHaloDim * 2 + 1]));
-        // The halo shape of lower dims excludes the halos.
-        dimSizes[i] =
-            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
-                .getResult();
-      } else {
-        dimSizes[i] = shape[i];
-      }
-    }
-
-    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
-    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
-    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v < 0
-    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-
-    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
-                                       rewriter.getIndexType());
-    auto myMultiIndex =
-        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
-            .getResult();
-    // Traverse all split axes from high to low dim.
-    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
-      auto splitAxes = opSplitAxes[dim];
-      if (splitAxes.empty())
-        continue;
-      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
-      // Get the linearized ids of the neighbors (down and up) for the
-      // given split.
-      auto tmp = rewriter
-                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
-                                                       splitAxes)
-                     .getResults();
-      // MPI operates on i32...
-      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
-                                   loc, rewriter.getI32Type(), tmp[0]),
-                               rewriter.create<arith::IndexCastOp>(
-                                   loc, rewriter.getI32Type(), tmp[1])};
-
-      auto lowerRecvOffset = rewriter.getIndexAttr(0);
-      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
-      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
-          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
-      auto upperSendOffset = rewriter.create<arith::SubIOp>(
-          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
-
-      Value commWorld = rewriter.create<mpi::CommWorldOp>(
-          loc, mpi::CommType::get(op->getContext()));
-
-      // Make sure we send/recv in a way that does not lead to a deadlock.
-      // The current approach is by far not optimal; it should at least use
-      // a red-black pattern or MPI_Sendrecv.
-      // Also, buffers should be reused.
-      // Still using temporary contiguous buffers for MPI communication...
-      // Still yielding a "serialized" communication pattern...
-      auto genSendRecv = [&](bool upperHalo) {
-        auto orgOffset = offsets[dim];
-        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
-                                  : haloSizes[currHaloDim * 2];
-        // Check if we need to send and/or receive.
-        // Processes on the mesh borders have only one neighbor.
-        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
-        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-        auto hasFrom = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::sge, from, zero);
-        auto hasTo = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::sge, to, zero);
-        auto buffer = rewriter.create<memref::AllocOp>(
-            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
-        // If there is a neighbor: copy halo data from the array into the
-        // buffer and send it.
-        rewriter.create<scf::IfOp>(
-            loc, hasTo, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
-                                       : OpFoldResult(upperSendOffset);
-              auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, dimSizes, strides);
-              builder.create<memref::CopyOp>(loc, subview, buffer);
-              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
-                                          commWorld);
-              builder.create<scf::YieldOp>(loc);
-            });
-        // If there is a neighbor: receive halo data into the buffer and copy
-        // it into the array.
-        rewriter.create<scf::IfOp>(
-            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
-                                       : OpFoldResult(lowerRecvOffset);
-              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
-                                          commWorld);
-              auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, dimSizes, strides);
-              builder.create<memref::CopyOp>(loc, buffer, subview);
-              builder.create<scf::YieldOp>(loc);
-            });
-        rewriter.create<memref::DeallocOp>(loc, buffer);
-        offsets[dim] = orgOffset;
-      };
-
-      auto doSendRecv = [&](int upOrDown) {
-        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
-        Value haloSz = dyn_cast<Value>(v);
-        if (!haloSz)
-          haloSz = rewriter.create<arith::ConstantOp>(
-              loc, rewriter.getI32IntegerAttr(
-                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
-        auto hasSize = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::sgt, haloSz, zero);
-        rewriter.create<scf::IfOp>(loc, hasSize,
-                                   [&](OpBuilder &builder, Location loc) {
-                                     genSendRecv(upOrDown > 0);
-                                     builder.create<scf::YieldOp>(loc);
-                                   });
-      };
-
-      doSendRecv(0);
-      doSendRecv(1);
-
-      // The shape for lower dims includes higher dims' halos.
-      dimSizes[dim] = shape[dim];
-      // -> the offset for higher dims is always 0.
-      offsets[dim] = rewriter.getIndexAttr(0);
-      // On to the next halo.
-      --currHaloDim;
-    }
-
-    if (isa<MemRefType>(op.getResult().getType())) {
-      rewriter.replaceOp(op, array);
-    } else {
-      assert(isa<RankedTensorType>(op.getResult().getType()));
-      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
-                                 loc, op.getResult().getType(), array,
-                                 /*restrict=*/true, /*writable=*/true));
-    }
-    return success();
-  }
-};
-
-struct ConvertMeshToMPIPass
-    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
-  using Base::Base;
-
-  /// Run the dialect converter on the module.
-  void runOnOperation() override {
-    auto *ctxt = &getContext();
-    RewritePatternSet patterns(ctxt);
-    ConversionTarget target(getContext());
-
-    // Define a type converter to convert mesh::ShardingType,
-    // mostly for use in return operations.
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type type) { return type; });
-
-    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
-    typeConverter.addConversion(
-        [](ShardingType type,
-           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-          auto i16 = IntegerType::get(type.getContext(), 16);
-          auto i64 = IntegerType::get(type.getContext(), 64);
-          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
-                                        ShapedType::kDynamic};
-          results.emplace_back(RankedTensorType::get(shp, i16));
-          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
-          results.emplace_back(RankedTensorType::get(shp, i64));
-          return success();
-        });
-
-    // To 'extract' components, an UnrealizedConversionCastOp is expected
-    // to define the input.
-    typeConverter.addTargetMaterialization(
-        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
-            Location loc) {
-          // Expecting a single input.
-          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
-            return SmallVector<Value>();
-          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
-          // Expecting an UnrealizedConversionCastOp.
-          if (!castOp)
-            return SmallVector<Value>();
-          // Fill a vector with the elements of the tuple/castOp.
-          SmallVector<Value> results;
-          for (auto oprnd : castOp.getInputs()) {
-            if (!isa<RankedTensorType>(oprnd.getType()))
-              return SmallVector<Value>();
-            results.emplace_back(oprnd);
-          }
-          return results;
-        });
-
-    // No mesh dialect op should be left after conversion...
-    target.addIllegalDialect<mesh::MeshDialect>();
-    // ...except the global MeshOp and MeshShapeOp, which will get folded
-    // later.
-    target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
-    // Allow all dialects that our patterns convert to.
-    target.addLegalDialect<
-        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
-        tensor::TensorDialect, bufferization::BufferizationDialect,
-        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
-    // Make sure the function signature, calls, etc. are legal.
-    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
-        [&](Operation *op) { return typeConverter.isLegal(op); });
-
-    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
-                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
-                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
-
-    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, typeConverter);
-    populateCallOpTypeConversionPattern(patterns, typeConverter);
-    populateReturnOpTypeConversionPattern(patterns, typeConverter);
-
-    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
-
-    // Folding patterns cannot be mixed with conversion patterns -> extra pass.
-    patterns.clear();
-    SymbolTableCollection symbolTableCollection;
-    mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
-    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-  }
-};
-
-} // namespace