Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/GPU/CMakeLists.txt                           |    1
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp         |   27
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp    |  174
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/Utils.cpp                     |   44
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp               |   91
-rw-r--r--  mlir/lib/Dialect/Mesh/IR/MeshOps.cpp                          |  213
-rw-r--r--  mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt               |    3
-rw-r--r--  mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp              |  639
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp |   28
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp   |  180
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp  |   44
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp                |   37
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp                 |    2
13 files changed, 1277 insertions(+), 206 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 8383e06..8f289ce 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupReduceLowering.cpp
+ Transforms/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 608d801..a75598a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -27,33 +27,6 @@ using namespace mlir;
namespace {
-static vector::CombiningKind
-convertReductionKind(gpu::AllReduceOperation mode) {
- switch (mode) {
-#define MAP_CASE(X) \
- case gpu::AllReduceOperation::X: \
- return vector::CombiningKind::X
-
- MAP_CASE(ADD);
- MAP_CASE(MUL);
- MAP_CASE(MINUI);
- MAP_CASE(MINSI);
- MAP_CASE(MINNUMF);
- MAP_CASE(MAXSI);
- MAP_CASE(MAXUI);
- MAP_CASE(MAXNUMF);
- MAP_CASE(AND);
- MAP_CASE(OR);
- MAP_CASE(XOR);
- MAP_CASE(MINIMUMF);
- MAP_CASE(MAXIMUMF);
-
-#undef MAP_CASE
- }
-
- llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
-}
-
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 61edce5..b00c65c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,13 +13,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <cstdint>
using namespace mlir;
@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth >= maxShuffleBitwidth)
return rewriter.notifyMatchFailure(
- op, llvm::formatv("element type too large {0}, cannot break down "
+ op, llvm::formatv("element type too large ({0}), cannot break down "
"into vectors of bitwidth {1} or less",
elemBitwidth, maxShuffleBitwidth));
@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
}
};
+/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+/// and `unpackFn` to convert to the native shuffle type and to the reduction
+/// type, respectively. For example, with `input` of type `f16`, `packFn` could
+/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
+/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
+/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+static Value createSubgroupShuffleReduction(
+ OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
+ unsigned subgroupSize, function_ref<Value(Value)> packFn,
+ function_ref<Value(Value)> unpackFn) {
+ assert(llvm::isPowerOf2_32(subgroupSize));
+ // Lane value always stays in the original type. We use it to perform arith
+ // reductions.
+ Value laneVal = input;
+ // Parallel reduction using butterfly shuffles.
+ for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+ Value shuffled = builder
+ .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+ /*width=*/subgroupSize,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = vector::makeArithReduction(builder, loc,
+ gpu::convertReductionKind(mode),
+ laneVal, unpackFn(shuffled));
+ assert(laneVal.getType() == input.getType());
+ }
+
+ return laneVal;
+}
+
+/// Lowers scalar gpu subgroup reductions to a series of shuffles.
+struct ScalarSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ Type valueTy = op.getType();
+ unsigned elemBitwidth =
+ getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
+ if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "value type is not a compatible scalar");
+
+ Location loc = op.getLoc();
+ // Since this is already a native shuffle scalar, no packing is necessary.
+ if (elemBitwidth == shuffleBitwidth) {
+ auto identityFn = [](Value v) { return v; };
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, identityFn, identityFn));
+ return success();
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto equivIntType = rewriter.getIntegerType(elemBitwidth);
+ auto packFn = [loc, &rewriter, equivIntType,
+ shuffleIntType](Value unpackedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
+ return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+ };
+ auto unpackFn = [loc, &rewriter, equivIntType,
+ valueTy](Value packedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
+ return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
+ };
+
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, packFn, unpackFn));
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
+
+/// Lowers vector gpu subgroup reductions to a series of shuffles.
+struct VectorSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy)
+ return rewriter.notifyMatchFailure(op, "value type is not a vector");
+
+ unsigned vecBitwidth =
+ vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+ if (vecBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op,
+ llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
+ "to shuffles of size {1}",
+ vecBitwidth, shuffleBitwidth));
+
+ unsigned elementsPerShuffle =
+ shuffleBitwidth / vecTy.getElementTypeBitWidth();
+ if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "shuffle bitwidth is not a multiple of the element bitwidth");
+
+ Location loc = op.getLoc();
+
+ // If the reduced type is smaller than the native shuffle size, extend it,
+ // perform the shuffles, and extract at the end.
+ auto extendedVecTy = VectorType::get(
+ static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
+ Value extendedInput = op.getValue();
+ if (vecBitwidth < shuffleBitwidth) {
+ auto zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(extendedVecTy));
+ extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto shuffleVecType = VectorType::get(1, shuffleIntType);
+
+ auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
+ return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
+ };
+ auto unpackFn = [loc, &rewriter, shuffleVecType,
+ extendedVecTy](Value packedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
+ return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
+ };
+
+ Value res =
+ createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
+ subgroupSize, packFn, unpackFn);
+
+ if (vecBitwidth < shuffleBitwidth) {
+ res = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
+ /*strides=*/1);
+ }
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
} // namespace
void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
maxShuffleBitwidth, benefit);
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}
+
+void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth, PatternBenefit benefit) {
+ patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+ patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+}
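
For reference, a minimal wiring sketch of the two entry points this file now exposes (populateGpuBreakDownSubgrupReducePatterns and the new populateGpuLowerSubgroupReduceToShufflePattenrs). This is not part of the patch: the helper name is made up, the parameter order is assumed to match the definitions above, and a greedy rewrite driver is assumed to be an acceptable way to apply the patterns.

#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: break wide vector reductions down first (higher
// benefit), then lower what remains to gpu.shuffle xor butterflies over
// 32 lanes with 32-bit shuffles.
static LogicalResult lowerSubgroupReducesTo32BitShuffles(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populateGpuBreakDownSubgrupReducePatterns(patterns,
                                            /*maxShuffleBitwidth=*/32,
                                            /*benefit=*/PatternBenefit(2));
  populateGpuLowerSubgroupReduceToShufflePattenrs(patterns,
                                                  /*subgroupSize=*/32,
                                                  /*shuffleBitwidth=*/32,
                                                  /*benefit=*/PatternBenefit(1));
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Giving the breakdown pattern the higher benefit lets it fire before the shuffle lowering, so a reduce over e.g. vector<4xf16> would first be split into 32-bit pieces that the scalar/vector shuffle patterns can then match.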
diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
new file mode 100644
index 0000000..e91aa18
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
@@ -0,0 +1,44 @@
+//===- Utils.cpp - GPU transforms utils -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements GPU dialect transforms utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir::gpu {
+
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
+ switch (mode) {
+#define MAP_CASE(X) \
+ case gpu::AllReduceOperation::X: \
+ return vector::CombiningKind::X
+
+ MAP_CASE(ADD);
+ MAP_CASE(MUL);
+ MAP_CASE(MINUI);
+ MAP_CASE(MINSI);
+ MAP_CASE(MINNUMF);
+ MAP_CASE(MAXSI);
+ MAP_CASE(MAXUI);
+ MAP_CASE(MAXNUMF);
+ MAP_CASE(AND);
+ MAP_CASE(OR);
+ MAP_CASE(XOR);
+ MAP_CASE(MINIMUMF);
+ MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+ }
+
+ llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
+} // namespace mlir::gpu
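
The mapping above is a pure 1:1 rename between the GPU and Vector enums, so a spot check is enough to illustrate it. A tiny sanity sketch, not part of the patch; the function name is made up.

#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include <cassert>

// Hypothetical spot check of the 1:1 GPU -> Vector reduction-kind mapping.
static void checkReductionKindMapping() {
  using mlir::gpu::AllReduceOperation;
  using mlir::vector::CombiningKind;
  assert(mlir::gpu::convertReductionKind(AllReduceOperation::ADD) ==
         CombiningKind::ADD);
  assert(mlir::gpu::convertReductionKind(AllReduceOperation::MAXIMUMF) ==
         CombiningKind::MAXIMUMF);
}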
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d3..2917840 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
+/// Given an `indexingMap` and its corresponding `iterators`, returns
+/// the positions of the iterators of type `iter` that are indexed by
+/// the `indexingMap` as a permutation. This is useful to infer various
+/// subcomputations on a `LinalgOp`. This is performed by looking up
+/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+findPermutationsIndexingOperand(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
utils::IteratorType iter) {
+ assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
@@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static FailureOr<SmallVector<utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+ if (!map.isProjectedPermutation())
+ return failure();
+ SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+ for (auto expr : map.getResults())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
+ iterators[dim.getPosition()] = par;
+ return iterators;
+}
+
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
@@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
-FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+static FailureOr<ContractionDimensions>
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<utils::IteratorType> iterators) {
+ llvm::SmallDenseSet<int64_t> a =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+ llvm::SmallDenseSet<int64_t> b =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+ llvm::SmallDenseSet<int64_t> c =
+ findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
- llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::SmallDenseSet<int64_t> ra =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+ llvm::SmallDenseSet<int64_t> rb =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);
// Return each set in sorted order.
@@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
return dimensions;
}
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+ return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+ linalgOp.getIteratorTypesArray());
+}
+
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return failure();
+ auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
+ if (failed(iterators))
+ return failure();
+ return inferContractionDimsImpl(indexingMaps, iterators.value());
+}
+
namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
@@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
ConvAccessExprWalker &inputExprWalker,
bool allowEmptyConvolvedDims) {
+ auto filterMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
+ auto outputMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
+ filterMap, linalgOp.getIteratorTypesArray(), par);
llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+ outputMap, linalgOp.getIteratorTypesArray(), par);
// unConvolvedDims & outputDims - filterDims are the batch iterators.
llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
@@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
llvm::SmallDenseSet<int64_t> filterReducedDims =
- findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
- red);
+ findPermutationsIndexingOperand(filterMap,
+ linalgOp.getIteratorTypesArray(), red);
// convolvedDims & filterReducedDims are the filter loop iterators.
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
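
The new map-only overload of inferContractionDims derives the iterator kinds from the init map (dims that appear in its results are parallel, the rest are reductions) and then runs the same implementation as the LinalgOp overload. Below is an illustrative sketch of what it is expected to return for plain matmul-style maps; the harness around it (standalone context, helper name) is made up for the example.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical sketch: infer contraction dims from indexing maps alone.
static void sketchInferMatmulDims() {
  MLIRContext ctx;
  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);
  SmallVector<AffineMap> maps = {
      AffineMap::get(3, 0, {m, k}, &ctx),  // LHS:  (d0, d2)
      AffineMap::get(3, 0, {k, n}, &ctx),  // RHS:  (d2, d1)
      AffineMap::get(3, 0, {m, n}, &ctx)}; // init: (d0, d1); d2 is inferred
                                           // to be a reduction iterator.
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(maps);
  // Expected: dims->m == {0}, dims->n == {1}, dims->k == {2}, batch empty.
  (void)dims;
}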
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d276755..de4f58d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
return vec;
}
-using MeshAxis = int16_t;
-
namespace {
struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTable) {
+ mesh::ClusterOp mesh =
+ symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+ if (!mesh) {
+ return op->emitError() << "Undefined required mesh symbol \""
+ << meshSymbol.getValue() << "\".";
+ }
+
+ return mesh;
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+ if (begin == end) {
+ return true;
+ }
+ It next = std::next(begin);
+ if (next == end) {
+ return true;
+ }
+ for (; next != end; ++next, ++begin) {
+ if (*begin == *next) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+ ClusterOp mesh) {
+ SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+ llvm::sort(sorted);
+ if (!isUnique(sorted.begin(), sorted.end())) {
+ return emitError(loc) << "Mesh axes contains duplicate elements.";
+ }
+
+ MeshAxis rank = mesh.getRank();
+ for (auto axis : axes) {
+ if (axis >= rank || axis < 0) {
+ return emitError(loc)
+ << "0-based mesh axis index " << axis
+ << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+ << "\" is of rank " << rank << ".";
+ }
+ }
+
+ return success();
+}
+
bool mesh::isReductionLoop(IteratorType iType) {
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
}
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
//===----------------------------------------------------------------------===//
LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+ MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
+ return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+ if (getCluster() != rhs.getCluster() ||
+ getPartialAxes() != rhs.getPartialAxes()) {
+ return false;
+ }
+
+ if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ return false;
+ }
+
+ auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+ if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
+ getSplitAxes().begin() + minSize),
+ llvm::make_range(rhs.getSplitAxes().begin(),
+ rhs.getSplitAxes().begin() + minSize))) {
+ return false;
+ }
+
+ return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
+ getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+ llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
+ rhs.getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
- }
-
- return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
- if (begin == end) {
- return true;
- }
- It next = std::next(begin);
- if (next == end) {
- return true;
- }
- for (; next != end; ++next, ++begin) {
- if (*begin == *next) {
- return false;
- }
- }
- return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
- llvm::sort(sorted);
- if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
- }
-
- MeshAxis rank = mesh.getRank();
- for (auto axis : axes) {
- if (axis >= rank || axis < 0) {
- return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
- << "\" is of rank " << rank << ".";
- }
- }
-
- return success();
-}
-
template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
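
The two new query ops come with convenience builders: with no axes they return one index-typed result per mesh dimension, and with an explicit axis list they return one result per requested axis; verifySymbolUses() checks the result count and verifyMeshAxes() rejects duplicate or out-of-range axes. A hedged builder sketch follows, not part of the patch; the helper name and the assumption of a rank-2 mesh are illustrative.

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical sketch: query the full process index and one mesh axis size.
static void sketchMeshQueries(ImplicitLocOpBuilder &b, mesh::ClusterOp mesh) {
  // No axes: one index result per mesh dimension.
  auto processIndex = b.create<mesh::ProcessIndexOp>(mesh);
  // Axes given by mesh symbol name: one result per requested axis; axes must
  // be unique and within [0, rank) to pass verification.
  auto axis1Size = b.create<mesh::ClusterShapeOp>(
      mesh.getSymName(), SmallVector<mesh::MeshAxis>({1}));
  (void)processIndex;
  (void)axis1Size;
}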
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 044b867..7a70c04 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
+ Spmdization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRSupport
+ MLIRTensorDialect
MLIRTosaShardingInterfaceImpl
)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
new file mode 100644
index 0000000..8d7e896
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -0,0 +1,639 @@
+//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+namespace mlir {
+namespace mesh {
+
+int64_t shardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ assert(dim % shardCount == 0);
+ return ceilDiv(dim, shardCount);
+}
+
+int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ return dim * shardCount;
+}
+
+template <typename MeshShape, typename SplitAxes>
+int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
+ int64_t res = 1;
+ for (auto splitAxis : splitAxes) {
+ int64_t meshDimSize = meshShape[splitAxis];
+ if (ShapedType::isDynamic(meshDimSize)) {
+ return ShapedType::kDynamic;
+ }
+ res *= meshDimSize;
+ }
+ return res;
+}
+
+// Compute the shape for the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
+// would result in a shape for each shard of ?x2x?.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+ typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+ const SplitAxes &splitAxes, OutShape &outShape) {
+ std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+ llvm::adl_begin(outShape));
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ outShape[tensorAxis] =
+ shardDimension(inShape[tensorAxis],
+ shardCount(meshShape, innerSplitAxes.asArrayRef()));
+ }
+}
+
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ MeshShardingAttr sharding) {
+ using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+ SmallVector<Dim> resShapeArr(shape.getShape().size());
+ shardShape(shape.getShape(), mesh.canonicalDimSizes(),
+ sharding.getSplitAxes(), resShapeArr);
+ return shape.clone(resShapeArr);
+}
+
+template <typename SourceAxes, typename TargetAxes>
+static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
+ const TargetAxes &targetAxes) {
+ return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
+ return sourceAxes.contains(targetAxis);
+ });
+}
+
+// Return the reduced value and its corresponding sharding.
+// Example:
+// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
+// targetSharding = <@mesh_1d, [[]]>
+// An all-reduce is then applied to the source value
+// and it is returned with the sharding <@mesh_1d, [[0]]>.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+handlePartialAxesDuringResharding(OpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceShard) {
+ if (sourceSharding.getPartialAxes().empty() &&
+ targetSharding.getPartialAxes().empty()) {
+ return {sourceShard, sourceSharding};
+ }
+ assert(targetSharding.getPartialAxes().empty() ||
+ (!sourceSharding.getPartialAxes().empty() &&
+ sourceSharding.getPartialType() == targetSharding.getPartialType()));
+ using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
+ using AxisSet = llvm::SmallDenseSet<Axis>;
+ AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
+ sourceSharding.getPartialAxes().end());
+ AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
+ targetSharding.getPartialAxes().end());
+ assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
+ targetShardingPartialAxesSet));
+ llvm::SmallVector<MeshAxis> allReduceMeshAxes;
+ llvm::copy_if(sourceShardingPartialAxesSet,
+ std::back_inserter(allReduceMeshAxes),
+ [&targetShardingPartialAxesSet](Axis a) {
+ return !targetShardingPartialAxesSet.contains(a);
+ });
+ if (allReduceMeshAxes.empty()) {
+ return {sourceShard, sourceSharding};
+ }
+
+ builder.setInsertionPointAfterValue(sourceShard);
+ TypedValue<ShapedType> resultValue =
+ builder
+ .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
+ sourceSharding.getCluster().getLeafReference(),
+ allReduceMeshAxes, sourceShard,
+ sourceSharding.getPartialType())
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+
+ llvm::SmallVector<int32_t> remainingPartialAxes;
+ llvm::copy_if(sourceShardingPartialAxesSet,
+ std::back_inserter(allReduceMeshAxes),
+ [&targetShardingPartialAxesSet](Axis a) {
+ return targetShardingPartialAxesSet.contains(a);
+ });
+ MeshShardingAttr resultSharding =
+ MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+ sourceSharding.getSplitAxes(), remainingPartialAxes,
+ sourceSharding.getPartialType());
+ return {resultValue, resultSharding};
+}
+
+static MeshShardingAttr
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+ splitTensorAxis) {
+ targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ }
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+ targetSplitAxes.push_back(splitMeshAxis);
+ targetShardingSplitAxes[splitTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
+ int64_t splitTensorAxis,
+ int64_t splitCount) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[splitTensorAxis] =
+ shardDimension(targetShape[splitTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+// Split a replicated tensor along a mesh axis.
+// e.g. [[0, 1]] -> [[0, 1, 2]].
+// Returns the spmdized target value with its sharding.
+//
+// The implementation is to extract the tensor slice corresponding
+// to the current device.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+ Value processIndexAlongAxis =
+ builder
+ .create<ProcessIndexOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+ ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
+ ShapedType targetShape =
+ targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
+ mesh.canonicalDimSizes()[splitMeshAxis]);
+
+ Value meshAxisSize =
+ builder
+ .create<ClusterShapeOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ Value sourceAxisSize =
+ builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
+ Value sourceAxisSizeModMeshAxisSize =
+ builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
+ Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
+ builder.create<cf::AssertOp>(
+ isTargetShapeExactlyDivisible,
+ "Sharding a tensor with axis size that is not exactly divisible by the "
+ "mesh axis size is not supported.");
+ Value targetAxisSize =
+ builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
+ Value axisOffset =
+ builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
+ SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
+ staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
+ DenseI64ArrayAttr staticOffsetsAttr =
+ DenseI64ArrayAttr::get(ctx, staticOffsets);
+ SmallVector<Value> dynamicOffsets(1, axisOffset);
+
+ DenseI64ArrayAttr staticSizesAttr =
+ DenseI64ArrayAttr::get(ctx, targetShape.getShape());
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < targetShape.getRank(); ++i) {
+ if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
+ if (i == splitTensorAxis) {
+ dynamicSizes.push_back(targetAxisSize);
+ } else {
+ Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
+ dynamicSizes.push_back(dimSize);
+ }
+ }
+ }
+
+ DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
+ ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
+ TypedValue<RankedTensorType> targetShard =
+ builder
+ .create<tensor::ExtractSliceOp>(
+ targetShape, sourceShard, dynamicOffsets, dynamicSizes,
+ SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
+ staticStridesAttr)
+ .getResult();
+ return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1]] -> [[0, 1, 2]].
+// If detected, returns the corresponding tensor axis mesh axis pair.
+// Does not detect insertions like
+// [[0, 1]] -> [[0, 2, 1]].
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
+ ++tensorAxis) {
+ if (sourceSharding.getSplitAxes().size() > tensorAxis) {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
+ targetSharding.getSplitAxes()[tensorAxis].size()) {
+ continue;
+ }
+ if (!llvm::equal(
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
+ llvm::make_range(
+ targetSharding.getSplitAxes()[tensorAxis]
+ .asArrayRef()
+ .begin(),
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+ 1))) {
+ continue;
+ }
+ } else {
+ if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
+ continue;
+ }
+ }
+ return std::make_tuple(
+ tensorAxis,
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+ }
+ return std::nullopt;
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
+ auto [tensorAxis, meshAxis] = detectRes.value();
+ return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
+ tensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1, 2]] -> [[0, 1]].
+// If detected, returns the corresponding tensor axis mesh axis pair.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
+ ++tensorAxis) {
+ if (targetSharding.getSplitAxes().size() > tensorAxis) {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
+ targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+ continue;
+ if (!llvm::equal(
+ llvm::make_range(
+ sourceSharding.getSplitAxes()[tensorAxis]
+ .asArrayRef()
+ .begin(),
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+ 1),
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+ continue;
+ } else {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+ continue;
+ }
+ return std::make_tuple(
+ tensorAxis,
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+ }
+ return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+ MeshShardingAttr sourceSharding,
+ int64_t splitTensorAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
+ splitTensorAxis);
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+
+ targetSplitAxes.pop_back();
+ targetShardingSplitAxes[splitTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allGatherResultShapeInUnsplitLastAxis(
+ ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[splitTensorAxis] =
+ unshardDimension(targetShape[splitTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ MeshShardingAttr targetSharding =
+ targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
+ ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
+ sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
+ splitTensorAxis);
+ Value allGatherResult = builder.create<AllGatherOp>(
+ RankedTensorType::get(allGatherResultShape.getShape(),
+ allGatherResultShape.getElementType()),
+ mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+ APInt(64, splitTensorAxis));
+ ShapedType targetShape =
+ shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ TypedValue<ShapedType> targetShard =
+ builder.create<tensor::CastOp>(targetShape, allGatherResult)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+ return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
+ auto [tensorAxis, meshAxis] = detectRes.value();
+ return unsplitLastAxisInResharding(builder, sourceSharding,
+ sourceUnshardedShape, sourceShard, mesh,
+ tensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1], [2]] -> [[0], [1, 2]].
+// Only moving the last axis counts.
+// If detected, returns the corresponding (source_tensor_axis,
+// target_tensor_axis, mesh_axis) tuple.
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
+detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t sourceTensorAxis = 0;
+ sourceTensorAxis < sourceSharding.getSplitAxes().size();
+ ++sourceTensorAxis) {
+ for (size_t targetTensorAxis = 0;
+ targetTensorAxis < targetSharding.getSplitAxes().size();
+ ++targetTensorAxis) {
+ if (sourceTensorAxis == targetTensorAxis)
+ continue;
+ if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
+ targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
+ sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
+ targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .back())
+ continue;
+ if (!llvm::equal(
+ llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
+ .asArrayRef()
+ .begin(),
+ sourceSharding.getSplitAxes()[sourceTensorAxis]
+ .asArrayRef()
+ .end() -
+ 1),
+ llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .begin(),
+ targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .end() -
+ 1)))
+ continue;
+ return std::make_tuple(
+ sourceTensorAxis, targetTensorAxis,
+ sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
+ }
+ }
+ return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+ targetTensorAxis) {
+ targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ }
+
+ auto sourceSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
+ assert(!sourceSplitAxes.empty());
+ auto meshAxis = sourceSplitAxes.back();
+ sourceSplitAxes.pop_back();
+ targetShardingSplitAxes[sourceTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
+ targetSplitAxes.push_back(meshAxis);
+ targetShardingSplitAxes[targetTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
+ int64_t splitCount,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[sourceTensorAxis] =
+ unshardDimension(targetShape[sourceTensorAxis], splitCount);
+ targetShape[targetTensorAxis] =
+ shardDimension(targetShape[targetTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis, MeshAxis meshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+ ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
+ ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
+ sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
+ sourceTensorAxis, targetTensorAxis);
+ Value allToAllResult = builder.create<AllToAllOp>(
+ RankedTensorType::get(allToAllResultShape.getShape(),
+ allToAllResultShape.getElementType()),
+ mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+ APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
+ ShapedType targetShape =
+ shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ TypedValue<ShapedType> targetShard =
+ builder.create<tensor::CastOp>(targetShape, allToAllResult)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+ return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
+ auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+ return moveLastSplitAxisInResharding(
+ builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
+ sourceTensorAxis, targetTensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Handles only resharding on a 1D mesh.
+// Currently the sharded tensor axes must be exactly divisible by the single
+// mesh axis size.
+static TypedValue<ShapedType>
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
+ assert(sourceShard.getType() ==
+ shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+ [[maybe_unused]] ShapedType targetShardType =
+ shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+ assert(sourceShard.getType().getRank() == targetShardType.getRank());
+ assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+
+ auto [reducedSourceShard, reducedSourceSharding] =
+ handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
+ sourceShard);
+
+ if (reducedSourceSharding == targetSharding) {
+ return reducedSourceShard;
+ }
+
+ TypedValue<ShapedType> targetShard;
+ MeshShardingAttr actualTargetSharding;
+ if (auto tryRes = tryMoveLastSplitAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = trySplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else {
+ assert(false && "Did not find any pattern to apply.");
+ }
+
+ assert(actualTargetSharding == targetSharding);
+ assert(targetShard.getType() == targetShardType);
+ return targetShard;
+}
+
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
+ // Resort to handling only 1D meshes since the general case is complicated if
+ // it needs to be communication efficient in terms of minimizing the data
+ // transferred between devices.
+ return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+ sourceUnshardedValue, sourceShard);
+}
+
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+ ShardOp source, ShardOp target,
+ TypedValue<ShapedType> sourceShardValue) {
+ assert(!source.getAnnotateForUsers());
+ assert(target.getAnnotateForUsers());
+ assert(source.getResult() == target.getOperand());
+ ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
+ return reshard(
+ implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
+ source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+}
+
+void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+ registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
+ cf::ControlFlowDialect>();
+}
+
+} // namespace mesh
+} // namespace mlir
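
The 1-D resharding above is driven by three rewrites: splitting the last axis lowers to a tensor.extract_slice of the local shard, unsplitting lowers to mesh.all_gather plus a tensor.cast, and moving the last split axis between tensor dimensions lowers to mesh.all_to_all; partial axes are handled first via mesh.all_reduce. The per-shard sizes come from shardDimension/unshardDimension. A small numeric sketch of that arithmetic, assuming a 1-D mesh of 4 devices and that these helpers are visible through the Spmdization.h header; the function name is made up.

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include <cassert>

// Hypothetical spot check of the shard-size arithmetic used when resharding.
static void sketchShardArithmetic() {
  using mlir::ShapedType;
  // Splitting a tensor axis of size 8 across 4 devices -> 2 elements/shard.
  assert(mlir::mesh::shardDimension(8, 4) == 2);
  // All-gathering it back multiplies by the shard count.
  assert(mlir::mesh::unshardDimension(2, 4) == 8);
  // Dynamic extents stay dynamic on both paths.
  assert(mlir::mesh::shardDimension(ShapedType::kDynamic, 4) ==
         ShapedType::kDynamic);
}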
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 8af3b69..87a37a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -448,6 +448,23 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
return false;
}
+/// Test for 2:4 matrix with suitable metadata.
+static bool isAdmissible24(SparseTensorType &aTp) {
+ return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
+ aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+}
+
+/// Test for conversion into 2:4 matrix.
+static bool isConversionInto24(Value v) {
+ if (auto cnv = v.getDefiningOp<ConvertOp>()) {
+ Value a = cnv.getResult();
+ Value d = cnv.getSource();
+ SparseTensorType aTp = getSparseTensorType(a);
+ return isDenseTensor(d) && isAdmissible24(aTp);
+ }
+ return false;
+}
+
/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
@@ -925,6 +942,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
Value C = op.getOperand(2); // we have C = AB
SmallVector<Value> tokens;
+ // The cuSparselt API currently only allows pruning and compression
+ // to occur on the device. So we recognize the pattern
+ // A' = convert A ; dense to 2:4
+ // C = A'B ; 2:4 matrix mult
+ // and then perform compression and matrix multiplication on device.
+ auto cnv = A.getDefiningOp<ConvertOp>();
+ assert(cnv);
+ A = cnv.getSource();
+
// All input should be dense tensors.
if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
return failure();
@@ -1260,7 +1286,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
return rewriteSpGEMM(rewriter, op, enableRT);
- if (op->getAttr("DENSE24"))
+ if (isConversionInto24(op.getOperand(0)))
return rewrite2To4SpMM(rewriter, op);
return rewriteSpMM(rewriter, op, enableRT);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e5..35eb4b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId curr, LatSetId lts) {
- assert(!env.getLoopVar(curr));
- // Emit invariants at this loop sequence level.
- genInvariants(env, builder, exp, curr, /*isStart=*/true);
- // Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, curr, /*isStart=*/true);
- // Emit further intitialization at this loop sequence level.
- const LatPointId l0 = env.set(lts)[0];
- bool needsUniv = false;
-
- SmallVector<TensorLevel> tidLvls;
- env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- LevelType lt, bool isIdxReduc) {
- assert(env.merger().loop(b) == curr);
- if (isDenseLT(lt) || isUndefLT(lt)) {
- if (tid == env.merger().getSynTensorID()) {
- // Needs loop emitter to set up loop bounds for synthetic tensor too if
- // there is a loop condition imposed on the synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
- }
- needsUniv = true;
- }
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt) || isIdxReduc) {
- // Only when this is a index reduction loop, can the lt be undefined.
- assert(!isUndefLT(lt) || isIdxReduc);
- // sparse/singleton levels, or a dense/sparse index reduction loop.
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
- }
- });
-
- env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
- // Maintain the universal index only if it is actually
- // consumed by a subsequent lattice point.
- if (needsUniv) {
- for (const LatPointId li : env.set(lts).drop_front())
- if (!env.merger().hasAnySparse(env.lat(li).simple))
- return true;
- }
- return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
- OpBuilder &builder, TensorId tid,
- Level startLvl) {
- // TODO: Handle affine expression on output tensor.
- linalg::GenericOp op = env.op();
- assert(tid < op.getNumDpsInputs());
- OpOperand *input = op.getDpsInputOperands()[tid];
- const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
- const auto enc = getSparseTensorEncoding(input->get().getType());
- if (enc) {
- const Location loc = op.getLoc();
- const TensorId tid = env.makeTensorId(input->getOperandNumber());
- const Level lvlRank = enc.getLvlRank();
- assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
- for (Level l = startLvl; l < lvlRank; l++) {
- AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
- builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
- else
- return; // break on first non-dense non-constant level
- }
- }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
- RewriterBase &rewriter) {
- for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
- genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
CodegenEnv &env, LatPointId li, LoopId curr,
- SmallVectorImpl<TensorLevel> &tidLvls,
- SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
LevelType lt, bool isIdxReduc) {
if (simple[b]) {
if (isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
return;
}
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
} else if (isDenseLT(lt) || isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+ callback(env.makeTensorLevel(tid, l), exp);
}
}
}
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by an *unused* operand; in
// this case, we just generate a dense "fake" loop by iterating over the
// synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
- env.getCurrentDepth()));
+ callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
numloopCond++;
}
// If we just need one loop condition and the condition is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
return numloopCond == 1 && !hasNonUnique;
}
+/// Starts a loop sequence at the given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId curr, LatSetId lts) {
+ assert(!env.getLoopVar(curr));
+ // Emit invariants at this loop sequence level.
+ genInvariants(env, builder, exp, curr, /*isStart=*/true);
+ // Emit access pattern expansion for sparse tensor output.
+ genExpand(env, builder, curr, /*isStart=*/true);
+ // Emit further initialization at this loop sequence level.
+ const LatPointId l0 = env.set(lts)[0];
+
+ SmallVector<TensorLevel> tidLvls;
+ getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ tidLvls.emplace_back(tl);
+ });
+
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+ // Maintain the universal index only if it is actually
+ // consumed by a subsequent lattice point.
+ for (const LatPointId li : env.set(lts).drop_front())
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
+ return true;
+
+ return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
+ // TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.op();
+ assert(tid < op.getNumDpsInputs());
+ OpOperand *input = op.getDpsInputOperands()[tid];
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+ const auto enc = getSparseTensorEncoding(input->get().getType());
+ if (enc) {
+ const Location loc = op.getLoc();
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
+ const Level lvlRank = enc.getLvlRank();
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ for (Level l = startLvl; l < lvlRank; l++) {
+ AffineExpr lvlExpr = lvlExprs[l];
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ env.emitter().genDenseAffineAddress(
+ builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+ else
+ return; // break on first non-dense non-constant level
+ }
+ }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+ RewriterBase &rewriter) {
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+ CodegenEnv &env, LatPointId li, LoopId curr,
+ SmallVectorImpl<TensorLevel> &tidLvls,
+ SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ return getAllTidLvlsInLatPoints(env, li, curr,
+ [&](TensorLevel tl, AffineExpr exp) {
+ if (exp)
+ affineTidLvls.emplace_back(tl, exp);
+ else
+ tidLvls.emplace_back(tl);
+ });
+}
+
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
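The Sparsification.cpp change above is a pure refactor: the traversal that was previously duplicated between startLoopSeq and translateBitsToTidLvlPairs now lives once in getAllTidLvlsInLatPoints, and the two callers differ only in the callback they pass. A minimal standalone analogue of that shape is sketched below; all names in it (Item, visitAll, collectPlain, collectSplit) are invented for the illustration and are not part of the patch.

// Standalone illustration of the callback-based split used above.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>

struct Item {
  int tensorLevel;        // stands in for TensorLevel
  const void *affineExpr; // stands in for AffineExpr (null when absent)
};

// Shared traversal, analogous to getAllTidLvlsInLatPoints: visits every item
// once and reports it through the callback.
static void visitAll(llvm::ArrayRef<Item> items,
                     llvm::function_ref<void(int, const void *)> callback) {
  for (const Item &it : items)
    callback(it.tensorLevel, it.affineExpr);
}

// Caller 1 (cf. startLoopSeq): only needs the levels themselves.
static void collectPlain(llvm::ArrayRef<Item> items,
                         llvm::SmallVectorImpl<int> &tidLvls) {
  visitAll(items, [&](int tl, const void *) { tidLvls.push_back(tl); });
}

// Caller 2 (cf. translateBitsToTidLvlPairs): splits levels with and without
// an attached affine expression into two vectors.
static void collectSplit(
    llvm::ArrayRef<Item> items, llvm::SmallVectorImpl<int> &tidLvls,
    llvm::SmallVectorImpl<std::pair<int, const void *>> &affineTidLvls) {
  visitAll(items, [&](int tl, const void *e) {
    if (e)
      affineTidLvls.emplace_back(tl, e);
    else
      tidLvls.push_back(tl);
  });
}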
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c..cfd838e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
}
};
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType, ArrayAttr reassociation) const {
+ if (operand.getType() == newOperandType)
+ return operand;
+ return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+ operand, reassociation);
+ }
+
+ LogicalResult matchAndRewrite(UnPackOp unpackOp,
+ PatternRewriter &rewriter) const override {
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no outer_dims_perm");
+ }
+
+ RankedTensorType sourceType = unpackOp.getSourceType();
+ RankedTensorType destType = unpackOp.getDestType();
+ if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+ return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+ ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ unpackOp, "expects unpacking at the innermost dimension");
+ }
+
+ auto reassociation =
+ getReassociationIndicesForReshape(sourceType, destType);
+ if (!reassociation)
+ return failure();
+ Value collapsed = insertCollapse(
+ rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+ getReassociationIndicesAttribute(rewriter, *reassociation));
+ rewriter.replaceOp(unpackOp, collapsed);
+ return success();
+ }
+};
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
}
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
- patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+ patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+ patterns.getContext());
}
} // namespace tensor
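A hedged usage sketch for the pattern registered above: any pass that already calls tensor::populateSimplifyPackAndUnpackPatterns now picks up SimplifyUnPackToCollapseShape automatically. The wrapper function name and the exact header paths below are assumptions made for the illustration, not part of the patch.

// Illustrative driver only; runSimplifyPackUnpack is an invented name.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

static void runSimplifyPackUnpack(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers SimplifyPackToExpandShape and the new
  // SimplifyUnPackToCollapseShape added in this patch.
  mlir::tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}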
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e42..aa4694c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -1975,6 +1976,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
}
//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ size_t numAssociations =
+ llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+ .Case([&](TransformHandleTypeInterface opHandle) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case([&](TransformValueHandleTypeInterface valueHandle) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case([&](TransformParamTypeInterface param) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](Type) {
+ llvm_unreachable("unknown kind of transform dialect type");
+ return 0;
+ });
+ results.setParams(getNum().cast<OpResult>(),
+ rewriter.getI64IntegerAttr(numAssociations));
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NumAssociationsOp::verify() {
+ // Verify that the result type accepts an i64 attribute as payload.
+ auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
+ return resultType
+ .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
+ .checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
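For readers unfamiliar with the dispatch idiom used in NumAssociationsOp::apply above, here is a minimal standalone llvm::TypeSwitch example. The function classify and the case types chosen are illustrative only and unrelated to the transform dialect.

// Minimal TypeSwitch illustration; everything here is invented for the sketch.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"

static llvm::StringRef classify(mlir::Type type) {
  return llvm::TypeSwitch<mlir::Type, llvm::StringRef>(type)
      .Case([](mlir::IntegerType) { return "integer"; })
      .Case([](mlir::FloatType) { return "float"; })
      // Like the .Default above, this catches every other kind of type.
      .Default([](mlir::Type) { return "something else"; });
}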
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 2ad992a..c1c0f54 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -271,7 +271,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
return false;
// Cond 1: A contiguous memref will always have a unit trailing stride.
- if (strides.back() != 1)
+ if (strides.empty() || strides.back() != 1)
return false;
// Cond 2: Strides of a contiguous memref have to match the flattened dims.
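The one-line VectorUtils.cpp fix above guards against an empty stride list, e.g. for a 0-d memref, before calling strides.back(). A hedged standalone sketch of the same guard follows; getStridesAndOffset is the standard MLIR utility, while hasUnitTrailingStride is an invented name for the illustration.

// Illustrative only; not the code from isContiguousSlice.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

static bool hasUnitTrailingStride(mlir::MemRefType type) {
  llvm::SmallVector<int64_t> strides;
  int64_t offset;
  if (mlir::failed(mlir::getStridesAndOffset(type, strides, offset)))
    return false; // non-strided layout
  // A 0-d memref (memref<f32>) has no strides at all; check before back().
  return !strides.empty() && strides.back() == 1;
}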