Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/GPU/CMakeLists.txt | 1
-rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp | 27
-rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 174
-rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/Utils.cpp | 44
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 91
-rw-r--r-- | mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 213
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt | 3
-rw-r--r-- | mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 639
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 28
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 180
-rw-r--r-- | mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp | 44
-rw-r--r-- | mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 37
-rw-r--r-- | mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2
13 files changed, 1277 insertions, 206 deletions
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index 8383e06..8f289ce 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms Transforms/ShuffleRewriter.cpp Transforms/SPIRVAttachTarget.cpp Transforms/SubgroupReduceLowering.cpp + Transforms/Utils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp index 608d801..a75598a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -27,33 +27,6 @@ using namespace mlir; namespace { -static vector::CombiningKind -convertReductionKind(gpu::AllReduceOperation mode) { - switch (mode) { -#define MAP_CASE(X) \ - case gpu::AllReduceOperation::X: \ - return vector::CombiningKind::X - - MAP_CASE(ADD); - MAP_CASE(MUL); - MAP_CASE(MINUI); - MAP_CASE(MINSI); - MAP_CASE(MINNUMF); - MAP_CASE(MAXSI); - MAP_CASE(MAXUI); - MAP_CASE(MAXNUMF); - MAP_CASE(AND); - MAP_CASE(OR); - MAP_CASE(XOR); - MAP_CASE(MINIMUMF); - MAP_CASE(MAXIMUMF); - -#undef MAP_CASE - } - - llvm_unreachable("Vector and GPU reduction kinds should match 1:1"); -} - struct GpuAllReduceRewriter { using AccumulatorFactory = std::function<Value(Value, Value)>; diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 61edce5..b00c65c 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -13,13 +13,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include <cassert> +#include <cstdint> using namespace mlir; @@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); if (elemBitwidth >= maxShuffleBitwidth) return rewriter.notifyMatchFailure( - op, llvm::formatv("element type too large {0}, cannot break down " + op, llvm::formatv("element type too large ({0}), cannot break down " "into vectors of bitwidth {1} or less", elemBitwidth, maxShuffleBitwidth)); @@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final } }; +/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn` +/// and `unpackFn` to convert to the native shuffle type and to the reduction +/// type, respectively. For example, with `input` of type `f16`, `packFn` could +/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn` +/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that +/// the subgroup is `subgroupSize` lanes wide and reduces across all of them. 
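As an illustration (not part of the patch), here is a small self-contained C++ simulation of the butterfly pattern described above for an 8-lane subgroup doing an integer add reduction: each lane combines with the lane at `lane ^ offset` for offsets 1, 2, 4, after which every lane holds the full sum. The same loop structure is what `createSubgroupShuffleReduction` below emits as IR.

  #include <array>

  // Host-side sketch of the XOR-butterfly reduction; not MLIR code.
  int reduceAcrossLanes(std::array<int, 8> lanes) {
    for (unsigned offset = 1; offset < lanes.size(); offset <<= 1) {
      std::array<int, 8> shuffled = lanes;
      for (unsigned lane = 0; lane < lanes.size(); ++lane)
        shuffled[lane] = lanes[lane ^ offset]; // analogue of gpu.shuffle xor
      for (unsigned lane = 0; lane < lanes.size(); ++lane)
        lanes[lane] += shuffled[lane];         // analogue of the arith reduction
    }
    return lanes[0]; // every lane now holds the same value (36 for lanes 1..8)
  }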
+static Value createSubgroupShuffleReduction( + OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode, + unsigned subgroupSize, function_ref<Value(Value)> packFn, + function_ref<Value(Value)> unpackFn) { + assert(llvm::isPowerOf2_32(subgroupSize)); + // Lane value always stays in the original type. We use it to perform arith + // reductions. + Value laneVal = input; + // Parallel reduction using butterfly shuffles. + for (unsigned i = 1; i < subgroupSize; i <<= 1) { + Value shuffled = builder + .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, + /*width=*/subgroupSize, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); + laneVal = vector::makeArithReduction(builder, loc, + gpu::convertReductionKind(mode), + laneVal, unpackFn(shuffled)); + assert(laneVal.getType() == input.getType()); + } + + return laneVal; +} + +/// Lowers scalar gpu subgroup reductions to a series of shuffles. +struct ScalarSubgroupReduceToShuffles final + : OpRewritePattern<gpu::SubgroupReduceOp> { + ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, + unsigned shuffleBitwidth, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), + shuffleBitwidth(shuffleBitwidth) {} + + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, + PatternRewriter &rewriter) const override { + Type valueTy = op.getType(); + unsigned elemBitwidth = + getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth(); + if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, "value type is not a compatible scalar"); + + Location loc = op.getLoc(); + // Since this is already a native shuffle scalar, no packing is necessary. + if (elemBitwidth == shuffleBitwidth) { + auto identityFn = [](Value v) { return v; }; + rewriter.replaceOp(op, createSubgroupShuffleReduction( + rewriter, loc, op.getValue(), op.getOp(), + subgroupSize, identityFn, identityFn)); + return success(); + } + + auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); + auto equivIntType = rewriter.getIntegerType(elemBitwidth); + auto packFn = [loc, &rewriter, equivIntType, + shuffleIntType](Value unpackedVal) -> Value { + auto asInt = + rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal); + return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt); + }; + auto unpackFn = [loc, &rewriter, equivIntType, + valueTy](Value packedVal) -> Value { + auto asInt = + rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal); + return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt); + }; + + rewriter.replaceOp(op, createSubgroupShuffleReduction( + rewriter, loc, op.getValue(), op.getOp(), + subgroupSize, packFn, unpackFn)); + return success(); + } + +private: + unsigned subgroupSize = 0; + unsigned shuffleBitwidth = 0; +}; + +/// Lowers vector gpu subgroup reductions to a series of shuffles. 
+struct VectorSubgroupReduceToShuffles final + : OpRewritePattern<gpu::SubgroupReduceOp> { + VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, + unsigned shuffleBitwidth, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), + shuffleBitwidth(shuffleBitwidth) {} + + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, + PatternRewriter &rewriter) const override { + auto vecTy = dyn_cast<VectorType>(op.getType()); + if (!vecTy) + return rewriter.notifyMatchFailure(op, "value type is not a vector"); + + unsigned vecBitwidth = + vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + if (vecBitwidth > shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, + llvm::formatv("vector type bitwidth too large ({0}), cannot lower " + "to shuffles of size {1}", + vecBitwidth, shuffleBitwidth)); + + unsigned elementsPerShuffle = + shuffleBitwidth / vecTy.getElementTypeBitWidth(); + if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, "shuffle bitwidth is not a multiple of the element bitwidth"); + + Location loc = op.getLoc(); + + // If the reduced type is smaller than the native shuffle size, extend it, + // perform the shuffles, and extract at the end. + auto extendedVecTy = VectorType::get( + static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType()); + Value extendedInput = op.getValue(); + if (vecBitwidth < shuffleBitwidth) { + auto zero = rewriter.create<arith::ConstantOp>( + loc, rewriter.getZeroAttr(extendedVecTy)); + extendedInput = rewriter.create<vector::InsertStridedSliceOp>( + loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1); + } + + auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); + auto shuffleVecType = VectorType::get(1, shuffleIntType); + + auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value { + auto asIntVec = + rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal); + return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0); + }; + auto unpackFn = [loc, &rewriter, shuffleVecType, + extendedVecTy](Value packedVal) -> Value { + auto asIntVec = + rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal); + return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec); + }; + + Value res = + createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(), + subgroupSize, packFn, unpackFn); + + if (vecBitwidth < shuffleBitwidth) { + res = rewriter.create<vector::ExtractStridedSliceOp>( + loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(), + /*strides=*/1); + } + + rewriter.replaceOp(op, res); + return success(); + } + +private: + unsigned subgroupSize = 0; + unsigned shuffleBitwidth = 0; +}; } // namespace void mlir::populateGpuBreakDownSubgrupReducePatterns( @@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns( maxShuffleBitwidth, benefit); patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit); } + +void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs( + RewritePatternSet &patterns, unsigned subgroupSize, + unsigned shuffleBitwidth, PatternBenefit benefit) { + patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( + patterns.getContext(), subgroupSize, shuffleBitwidth, benefit); +} diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp new file mode 100644 index 0000000..e91aa18 --- /dev/null +++ 
b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp @@ -0,0 +1,44 @@ +//===- Utils.cpp - GPU transforms utils -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements GPU dialect transforms utils. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/GPU/Transforms/Utils.h" +#include "llvm/Support/ErrorHandling.h" + +namespace mlir::gpu { + +vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) { + switch (mode) { +#define MAP_CASE(X) \ + case gpu::AllReduceOperation::X: \ + return vector::CombiningKind::X + + MAP_CASE(ADD); + MAP_CASE(MUL); + MAP_CASE(MINUI); + MAP_CASE(MINSI); + MAP_CASE(MINNUMF); + MAP_CASE(MAXSI); + MAP_CASE(MAXUI); + MAP_CASE(MAXNUMF); + MAP_CASE(AND); + MAP_CASE(OR); + MAP_CASE(XOR); + MAP_CASE(MINIMUMF); + MAP_CASE(MAXIMUMF); + +#undef MAP_CASE + } + + llvm_unreachable("Vector and GPU reduction kinds should match 1:1"); +} + +} // namespace mlir::gpu diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index ba419d3..2917840 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) { return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>); } -/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the -/// iterators of type `iter` that index the `opOperand` as a permutation. -/// This is useful to infer various subcomputations on a given `linalgOp`. -/// This is performed by looking up each result in the matching indexing map and -/// determining whether: +/// Given an `indexingMap` and its corresponding `iterators`, returns +/// the positions of the iterators of type `iter` that are indexed by +/// the `indexingMap` as a permutation. This is useful to infer various +/// subcomputations on a `LinalgOp`. This is performed by looking up +/// each result in the `indexingMap` and determining whether: /// - It is a single AffineDimExpr. /// - It is the only result involving this AffineDimExpr. static llvm::SmallDenseSet<int64_t> -findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, +findPermutationsIndexingOperand(AffineMap indexingMap, + ArrayRef<utils::IteratorType> iterators, utils::IteratorType iter) { + assert(iterators.size() == indexingMap.getNumDims()); llvm::SmallDenseSet<int64_t> res; - assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); for (AffineExpr e : indexingMap.getResults()) { if (auto d = dyn_cast<AffineDimExpr>(e)) { - if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && + if (iterators[d.getPosition()] == iter && llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { return e.isFunctionOfDim(d.getPosition()); }) == 1) @@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel; auto red = utils::IteratorType::reduction; } // namespace +/// Infer the iterator types from the init affine map. 
This looks at which dims +/// are present in the map results, and returns an iterator types array with +/// parallel types for dims that are present, and reduction types for dims that +/// are not present. +static FailureOr<SmallVector<utils::IteratorType>> +inferIteratorsFromOutMap(AffineMap map) { + if (!map.isProjectedPermutation()) + return failure(); + SmallVector<utils::IteratorType> iterators(map.getNumDims(), red); + for (auto expr : map.getResults()) + if (auto dim = dyn_cast<AffineDimExpr>(expr)) + iterators[dim.getPosition()] = par; + return iterators; +} + /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form /// a matmul subcomputation within `linalgOp`. These dimensions are such that: /// 1. The m dimension is involved in an outer-product along LHS @@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction; /// 5. Optional batch dimensions that appear in all operands are captured. /// This allows e.g. detecting that some contraction is embedded within /// `linalgOp` with some orthogonal heuristic. -FailureOr<ContractionDimensions> -mlir::linalg::inferContractionDims(LinalgOp linalgOp) { - if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) - return failure(); - - llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), par); - llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), par); - llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInitOperand(0), par); +static FailureOr<ContractionDimensions> +inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, + ArrayRef<utils::IteratorType> iterators) { + llvm::SmallDenseSet<int64_t> a = + findPermutationsIndexingOperand(indexingMaps[0], iterators, par); + llvm::SmallDenseSet<int64_t> b = + findPermutationsIndexingOperand(indexingMaps[1], iterators, par); + llvm::SmallDenseSet<int64_t> c = + findPermutationsIndexingOperand(indexingMaps[2], iterators, par); // A & C - B are the iterators involved in an outer-product along A (the LHS). llvm::SmallDenseSet<int64_t> ac = a; @@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) { llvm::set_intersect(batches, c); // A & B red are the reduction dimensions. - llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(0), red); - llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::SmallDenseSet<int64_t> ra = + findPermutationsIndexingOperand(indexingMaps[0], iterators, red); + llvm::SmallDenseSet<int64_t> rb = + findPermutationsIndexingOperand(indexingMaps[1], iterators, red); llvm::set_intersect(ra, rb); // Return each set in sorted order. 
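As an illustration of the map-based inference (a sketch only, not part of the patch; it assumes an `MLIRContext *ctx` and the usual `using namespace mlir`), ordinary matmul maps make m and n parallel because they appear in the output map, while the absent k becomes a reduction, and the `ArrayRef<AffineMap>` entry point added below recovers exactly these roles:

  // Maps over (d0, d1, d2) = (m, n, k): A -> (m, k), B -> (k, n), C -> (m, n).
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k);
  SmallVector<AffineMap> maps = {AffineMap::get(3, 0, {m, k}, ctx),
                                 AffineMap::get(3, 0, {k, n}, ctx),
                                 AffineMap::get(3, 0, {m, n}, ctx)};
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(maps);
  // Expected on success: dims->m == {0}, dims->n == {1}, dims->k == {2},
  // and no batch dimensions.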
@@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) { return dimensions; } +FailureOr<ContractionDimensions> +mlir::linalg::inferContractionDims(LinalgOp linalgOp) { + if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) + return failure(); + return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(), + linalgOp.getIteratorTypesArray()); +} + +FailureOr<ContractionDimensions> +mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) { + if (indexingMaps.size() != 3) + return failure(); + auto iterators = inferIteratorsFromOutMap(indexingMaps[2]); + if (failed(iterators)) + return failure(); + return inferContractionDimsImpl(indexingMaps, iterators.value()); +} + namespace mlir::linalg::detail { enum class MatchContractionResult { Success = 0, @@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions> inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims) { + auto filterMap = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1)); + auto outputMap = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0)); llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInputOperand(1), par); + filterMap, linalgOp.getIteratorTypesArray(), par); llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand( - linalgOp, linalgOp.getDpsInitOperand(0), par); + outputMap, linalgOp.getIteratorTypesArray(), par); // unConvolvedDims & outputDims - filterDims are the batch iterators. llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims; @@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp, llvm::set_intersect(depth, inputExprWalker.unConvolvedDims); llvm::SmallDenseSet<int64_t> filterReducedDims = - findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1), - red); + findPermutationsIndexingOperand(filterMap, + linalgOp.getIteratorTypesArray(), red); // convolvedDims & filterReducedDims are the filter loop iterators. 
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims; diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index d276755..de4f58d 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" @@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) { return vec; } -using MeshAxis = int16_t; - namespace { struct DimensionSize { @@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, // Mesh utilities //===----------------------------------------------------------------------===// +static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, + SymbolTableCollection &symbolTable) { + mesh::ClusterOp mesh = + symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol); + if (!mesh) { + return op->emitError() << "Undefined required mesh symbol \"" + << meshSymbol.getValue() << "\"."; + } + + return mesh; +} + +template <typename It> +bool isUnique(It begin, It end) { + if (begin == end) { + return true; + } + It next = std::next(begin); + if (next == end) { + return true; + } + for (; next != end; ++next, ++begin) { + if (*begin == *next) { + return false; + } + } + return true; +} + +static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, + ClusterOp mesh) { + SmallVector<MeshAxis> sorted = llvm::to_vector(axes); + llvm::sort(sorted); + if (!isUnique(sorted.begin(), sorted.end())) { + return emitError(loc) << "Mesh axes contains duplicate elements."; + } + + MeshAxis rank = mesh.getRank(); + for (auto axis : axes) { + if (axis >= rank || axis < 0) { + return emitError(loc) + << "0-based mesh axis index " << axis + << " is out of bounds. The referenced mesh \"" << mesh.getSymName() + << "\" is of rank " << rank << "."; + } + } + + return success(); +} + bool mesh::isReductionLoop(IteratorType iType) { return iType != IteratorType::Parallel && iType != IteratorType::Invalid; } @@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() { } //===----------------------------------------------------------------------===// -// mesh.shard op +// mesh.cluster_shape op +//===----------------------------------------------------------------------===// + +LogicalResult +ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + return failure(); + } + + size_t expectedResultsCount = + getAxes().empty() ? mesh->getRank() : getAxes().size(); + if (getResult().size() != expectedResultsCount) { + return emitError() << "Unexpected number of results " << getResult().size() + << ". 
Expected " << expectedResultsCount << "."; + } + + return success(); +} + +void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + ClusterOp mesh) { + build(odsBuilder, odsState, + SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), + mesh.getSymName(), MeshAxesAttr()); +} + +void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef mesh, ArrayRef<MeshAxis> axes) { + build(odsBuilder, odsState, + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, + MeshAxesAttr::get(odsBuilder.getContext(), axes)); +} + +//===----------------------------------------------------------------------===// +// mesh.shard attr //===----------------------------------------------------------------------===// LogicalResult @@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } +bool MeshShardingAttr::operator==(Attribute rhs) const { + MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>(); + return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr; +} + +bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const { + if (getCluster() != rhs.getCluster() || + getPartialAxes() != rhs.getPartialAxes()) { + return false; + } + + if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) { + return false; + } + + auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size()); + if (!llvm::equal(llvm::make_range(getSplitAxes().begin(), + getSplitAxes().begin() + minSize), + llvm::make_range(rhs.getSplitAxes().begin(), + rhs.getSplitAxes().begin() + minSize))) { + return false; + } + + return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize, + getSplitAxes().end()), + std::mem_fn(&DenseI32ArrayAttr::empty)) && + llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize, + rhs.getSplitAxes().end()), + std::mem_fn(&DenseI32ArrayAttr::empty)); +} + +//===----------------------------------------------------------------------===// +// mesh.process_index op +//===----------------------------------------------------------------------===// + +LogicalResult +ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable); + if (failed(mesh)) { + return failure(); + } + if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + return failure(); + } + + size_t expectedResultsCount = + getAxes().empty() ? mesh->getRank() : getAxes().size(); + if (getResult().size() != expectedResultsCount) { + return emitError() << "Unexpected number of results " << getResult().size() + << ". 
Expected " << expectedResultsCount << "."; + } + + return success(); +} + +void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, + ClusterOp mesh) { + build(odsBuilder, odsState, + SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), + mesh.getSymName(), MeshAxesAttr()); +} + +void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef mesh, ArrayRef<MeshAxis> axes) { + build(odsBuilder, odsState, + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, + MeshAxesAttr::get(odsBuilder.getContext(), axes)); +} + //===----------------------------------------------------------------------===// // collective communication ops //===----------------------------------------------------------------------===// @@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, return success(); } -static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, - SymbolTableCollection &symbolTable) { - mesh::ClusterOp mesh = - symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol); - if (!mesh) { - return op->emitError() << "Undefined required mesh symbol \"" - << meshSymbol.getValue() << "\"."; - } - - return mesh; -} - -template <typename It> -bool isUnique(It begin, It end) { - if (begin == end) { - return true; - } - It next = std::next(begin); - if (next == end) { - return true; - } - for (; next != end; ++next, ++begin) { - if (*begin == *next) { - return false; - } - } - return true; -} - -static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, - ClusterOp mesh) { - SmallVector<MeshAxis> sorted = llvm::to_vector(axes); - llvm::sort(sorted); - if (!isUnique(sorted.begin(), sorted.end())) { - return emitError(loc) << "Mesh axes contains duplicate elements."; - } - - MeshAxis rank = mesh.getRank(); - for (auto axis : axes) { - if (axis >= rank || axis < 0) { - return emitError(loc) - << "0-based mesh axis index " << axis - << " is out of bounds. The referenced mesh \"" << mesh.getSymName() - << "\" is of rank " << rank << "."; - } - } - - return success(); -} - template <typename Op> static FailureOr<ClusterOp> getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt index 044b867..7a70c04 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRMeshTransforms Simplifications.cpp ShardingPropagation.cpp + Spmdization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh @@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms LINK_LIBS PUBLIC MLIRArithDialect + MLIRControlFlowDialect MLIRFuncDialect MLIRIR MLIRMeshDialect MLIRPass MLIRShardingInterface MLIRSupport + MLIRTensorDialect MLIRTosaShardingInterfaceImpl ) diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp new file mode 100644 index 0000000..8d7e896 --- /dev/null +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -0,0 +1,639 @@ +//===- Spmdization.cpp --------------------------------------------- C++ --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Mesh/Transforms/Spmdization.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/MathExtras.h" +#include "llvm/ADT/ADL.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include <algorithm> +#include <iterator> +#include <numeric> +#include <optional> +#include <tuple> +#include <type_traits> + +namespace mlir { +namespace mesh { + +int64_t shardDimension(int64_t dim, int64_t shardCount) { + if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount)) + return ShapedType::kDynamic; + + assert(dim % shardCount == 0); + return ceilDiv(dim, shardCount); +} + +int64_t unshardDimension(int64_t dim, int64_t shardCount) { + if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount)) + return ShapedType::kDynamic; + + return dim * shardCount; +} + +template <typename MeshShape, typename SplitAxes> +int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) { + int64_t res = 1; + for (auto splitAxis : splitAxes) { + int64_t meshDimSize = meshShape[splitAxis]; + if (ShapedType::isDynamic(meshDimSize)) { + return ShapedType::kDynamic; + } + res *= meshDimSize; + } + return res; +} + +// Compute the shape for the tensor on each device in the mesh. +// Example: +// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 +// would result in a shape for each shard of ?x2x?. +template <typename InShape, typename MeshShape, typename SplitAxes, + typename OutShape> +static void shardShape(const InShape &inShape, const MeshShape &meshShape, + const SplitAxes &splitAxes, OutShape &outShape) { + std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), + llvm::adl_begin(outShape)); + for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { + outShape[tensorAxis] = + shardDimension(inShape[tensorAxis], + shardCount(meshShape, innerSplitAxes.asArrayRef())); + } +} + +ShapedType shardShapedType(ShapedType shape, ClusterOp mesh, + MeshShardingAttr sharding) { + using Dim = std::decay_t<decltype(shape.getDimSize(0))>; + SmallVector<Dim> resShapeArr(shape.getShape().size()); + shardShape(shape.getShape(), mesh.canonicalDimSizes(), + sharding.getSplitAxes(), resShapeArr); + return shape.clone(resShapeArr); +} + +template <typename SourceAxes, typename TargetAxes> +static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, + const TargetAxes &targetAxes) { + return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) { + return sourceAxes.contains(targetAxis); + }); +} + +// Return the reduced value and its corresponding sharding. 
+// Example: +// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]> +// targetSharding = <@mesh_1d, [[]]> +// Then will apply all-reduce on the source value +// and return it with the sharding <@mesh_1d, [[0]]>. +static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> +handlePartialAxesDuringResharding(OpBuilder &builder, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + TypedValue<ShapedType> sourceShard) { + if (sourceSharding.getPartialAxes().empty() && + targetSharding.getPartialAxes().empty()) { + return {sourceShard, sourceSharding}; + } + assert(targetSharding.getPartialAxes().empty() || + (!sourceSharding.getPartialAxes().empty() && + sourceSharding.getPartialType() == targetSharding.getPartialType())); + using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>; + using AxisSet = llvm::SmallDenseSet<Axis>; + AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(), + sourceSharding.getPartialAxes().end()); + AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(), + targetSharding.getPartialAxes().end()); + assert(arePartialAxesCompatible(sourceShardingPartialAxesSet, + targetShardingPartialAxesSet)); + llvm::SmallVector<MeshAxis> allReduceMeshAxes; + llvm::copy_if(sourceShardingPartialAxesSet, + std::back_inserter(allReduceMeshAxes), + [&targetShardingPartialAxesSet](Axis a) { + return !targetShardingPartialAxesSet.contains(a); + }); + if (allReduceMeshAxes.empty()) { + return {sourceShard, sourceSharding}; + } + + builder.setInsertionPointAfterValue(sourceShard); + TypedValue<ShapedType> resultValue = + builder + .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(), + sourceSharding.getCluster().getLeafReference(), + allReduceMeshAxes, sourceShard, + sourceSharding.getPartialType()) + .getResult() + .cast<TypedValue<ShapedType>>(); + + llvm::SmallVector<int32_t> remainingPartialAxes; + llvm::copy_if(sourceShardingPartialAxesSet, + std::back_inserter(allReduceMeshAxes), + [&targetShardingPartialAxesSet](Axis a) { + return targetShardingPartialAxesSet.contains(a); + }); + MeshShardingAttr resultSharding = + MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(), + sourceSharding.getSplitAxes(), remainingPartialAxes, + sourceSharding.getPartialType()); + return {resultValue, resultSharding}; +} + +static MeshShardingAttr +targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, + int64_t splitTensorAxis, MeshAxis splitMeshAxis) { + SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes = + llvm::to_vector(sourceSharding.getSplitAxes()); + while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= + splitTensorAxis) { + targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {})); + } + auto targetSplitAxes = + llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); + targetSplitAxes.push_back(splitMeshAxis); + targetShardingSplitAxes[splitTensorAxis] = + DenseI32ArrayAttr::get(ctx, targetSplitAxes); + return MeshShardingAttr::get( + ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); +} + +static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape, + int64_t splitTensorAxis, + int64_t splitCount) { + SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); + targetShape[splitTensorAxis] = + shardDimension(targetShape[splitTensorAxis], splitCount); + return sourceShape.cloneWith(targetShape, 
sourceShape.getElementType()); +} + +// Split a replicated tensor along a mesh axis. +// e.g. [[0, 1]] -> [[0, 1, 2]]. +// Returns the spmdized target value with its sharding. +// +// The implementation is the extract the tensor slice corresponding +// to the current device. +static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> +splitLastAxisInResharding(ImplicitLocOpBuilder &builder, + MeshShardingAttr sourceSharding, + TypedValue<ShapedType> sourceShard, ClusterOp mesh, + int64_t splitTensorAxis, MeshAxis splitMeshAxis) { + MLIRContext *ctx = builder.getContext(); + builder.setInsertionPointAfterValue(sourceShard); + + Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0)); + + Value processIndexAlongAxis = + builder + .create<ProcessIndexOp>(mesh.getSymName(), + SmallVector<MeshAxis>({splitMeshAxis})) + .getResult()[0]; + + MeshShardingAttr targetSharding = targetShardingInSplitLastAxis( + ctx, sourceSharding, splitTensorAxis, splitMeshAxis); + ShapedType targetShape = + targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis, + mesh.canonicalDimSizes()[splitMeshAxis]); + + Value meshAxisSize = + builder + .create<ClusterShapeOp>(mesh.getSymName(), + SmallVector<MeshAxis>({splitMeshAxis})) + .getResult()[0]; + + Value sourceAxisSize = + builder.create<tensor::DimOp>(sourceShard, splitTensorAxis); + Value sourceAxisSizeModMeshAxisSize = + builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize); + Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>( + arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero); + builder.create<cf::AssertOp>( + isTargetShapeExactlyDivisible, + "Sharding a tensor with axis size that is not exactly divisible by the " + "mesh axis size is not supported."); + Value targetAxisSize = + builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize); + Value axisOffset = + builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis); + SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0); + staticOffsets[splitTensorAxis] = ShapedType::kDynamic; + DenseI64ArrayAttr staticOffsetsAttr = + DenseI64ArrayAttr::get(ctx, staticOffsets); + SmallVector<Value> dynamicOffsets(1, axisOffset); + + DenseI64ArrayAttr staticSizesAttr = + DenseI64ArrayAttr::get(ctx, targetShape.getShape()); + SmallVector<Value> dynamicSizes; + for (int64_t i = 0; i < targetShape.getRank(); ++i) { + if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) { + if (i == splitTensorAxis) { + dynamicSizes.push_back(targetAxisSize); + } else { + Value dimSize = builder.create<tensor::DimOp>(sourceShard, i); + dynamicSizes.push_back(dimSize); + } + } + } + + DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get( + ctx, SmallVector<int64_t>(targetShape.getRank(), 1)); + TypedValue<RankedTensorType> targetShard = + builder + .create<tensor::ExtractSliceOp>( + targetShape, sourceShard, dynamicOffsets, dynamicSizes, + SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr, + staticStridesAttr) + .getResult(); + return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding}; +} + +// Detect if the resharding is of type e.g. +// [[0, 1]] -> [[0, 1, 2]]. +// If detected, returns the corresponding tensor axis mesh axis pair. +// Does not detect insertions like +// [[0, 1]] -> [[0, 2, 1]]. 
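For intuition about the shape arithmetic used by the split path above, here is a tiny static-size C++ sketch (illustration only, mirroring the semantics of `shardDimension` and `unshardDimension`): splitting axis 0 of a 6x5 tensor across a mesh axis of size 2 yields a 3x5 shard on each device, and gathering restores 6x5.

  #include <cassert>
  #include <cstdint>

  // Static-size analogue of shardDimension/unshardDimension above.
  int64_t shardDim(int64_t dim, int64_t shardCount) {
    assert(dim % shardCount == 0 && "only exactly divisible axes are handled");
    return dim / shardCount;
  }
  int64_t unshardDim(int64_t dim, int64_t shardCount) { return dim * shardCount; }
  // shardDim(6, 2) == 3 and unshardDim(3, 2) == 6.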
+static std::optional<std::tuple<int64_t, MeshAxis>> +detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding) { + for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); + ++tensorAxis) { + if (sourceSharding.getSplitAxes().size() > tensorAxis) { + if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 != + targetSharding.getSplitAxes()[tensorAxis].size()) { + continue; + } + if (!llvm::equal( + sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(), + llvm::make_range( + targetSharding.getSplitAxes()[tensorAxis] + .asArrayRef() + .begin(), + targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - + 1))) { + continue; + } + } else { + if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) { + continue; + } + } + return std::make_tuple( + tensorAxis, + targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); + } + return std::nullopt; +} + +static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> +trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + TypedValue<ShapedType> sourceShard) { + if (auto detectRes = + detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { + auto [tensorAxis, meshAxis] = detectRes.value(); + return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, + tensorAxis, meshAxis); + } + + return std::nullopt; +} + +// Detect if the resharding is of type e.g. +// [[0, 1, 2]] -> [[0, 1]]. +// If detected, returns the corresponding tensor axis mesh axis pair. +static std::optional<std::tuple<int64_t, MeshAxis>> +detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding) { + for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); + ++tensorAxis) { + if (targetSharding.getSplitAxes().size() > tensorAxis) { + if (sourceSharding.getSplitAxes()[tensorAxis].size() != + targetSharding.getSplitAxes()[tensorAxis].size() + 1) + continue; + if (!llvm::equal( + llvm::make_range( + sourceSharding.getSplitAxes()[tensorAxis] + .asArrayRef() + .begin(), + sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - + 1), + targetSharding.getSplitAxes()[tensorAxis].asArrayRef())) + continue; + } else { + if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1) + continue; + } + return std::make_tuple( + tensorAxis, + sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); + } + return std::nullopt; +} + +static MeshShardingAttr +targetShardingInUnsplitLastAxis(MLIRContext *ctx, + MeshShardingAttr sourceSharding, + int64_t splitTensorAxis) { + SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes = + llvm::to_vector(sourceSharding.getSplitAxes()); + assert(static_cast<int64_t>(targetShardingSplitAxes.size()) > + splitTensorAxis); + auto targetSplitAxes = + llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); + + targetSplitAxes.pop_back(); + targetShardingSplitAxes[splitTensorAxis] = + DenseI32ArrayAttr::get(ctx, targetSplitAxes); + return MeshShardingAttr::get( + ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); +} + +static ShapedType allGatherResultShapeInUnsplitLastAxis( + ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { + SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); + targetShape[splitTensorAxis] = + 
unshardDimension(targetShape[splitTensorAxis], splitCount); + return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); +} + +static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> +unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, + MeshShardingAttr sourceSharding, + ShapedType sourceUnshardedShape, + TypedValue<ShapedType> sourceShard, ClusterOp mesh, + int64_t splitTensorAxis, MeshAxis splitMeshAxis) { + MLIRContext *ctx = builder.getContext(); + builder.setInsertionPointAfterValue(sourceShard); + + MeshShardingAttr targetSharding = + targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis); + ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( + sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis], + splitTensorAxis); + Value allGatherResult = builder.create<AllGatherOp>( + RankedTensorType::get(allGatherResultShape.getShape(), + allGatherResultShape.getElementType()), + mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard, + APInt(64, splitTensorAxis)); + ShapedType targetShape = + shardShapedType(sourceUnshardedShape, mesh, targetSharding); + TypedValue<ShapedType> targetShard = + builder.create<tensor::CastOp>(targetShape, allGatherResult) + .getResult() + .cast<TypedValue<ShapedType>>(); + return {targetShard, targetSharding}; +} + +static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> +tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + ShapedType sourceUnshardedShape, + TypedValue<ShapedType> sourceShard) { + if (auto detectRes = + detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { + auto [tensorAxis, meshAxis] = detectRes.value(); + return unsplitLastAxisInResharding(builder, sourceSharding, + sourceUnshardedShape, sourceShard, mesh, + tensorAxis, meshAxis); + } + + return std::nullopt; +} + +// Detect if the resharding is of type e.g. +// [[0, 1], [2]] -> [[0], [1, 2]]. +// Only moving the last axis counts. +// If detected, returns the corresponding (source_tensor_axis, +// target_tensor_axis, mesh_axis) tuple. 
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>> +detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding) { + for (size_t sourceTensorAxis = 0; + sourceTensorAxis < sourceSharding.getSplitAxes().size(); + ++sourceTensorAxis) { + for (size_t targetTensorAxis = 0; + targetTensorAxis < targetSharding.getSplitAxes().size(); + ++targetTensorAxis) { + if (sourceTensorAxis == targetTensorAxis) + continue; + if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() || + targetSharding.getSplitAxes()[targetTensorAxis].empty() || + sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != + targetSharding.getSplitAxes()[targetTensorAxis] + .asArrayRef() + .back()) + continue; + if (!llvm::equal( + llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis] + .asArrayRef() + .begin(), + sourceSharding.getSplitAxes()[sourceTensorAxis] + .asArrayRef() + .end() - + 1), + llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis] + .asArrayRef() + .begin(), + targetSharding.getSplitAxes()[targetTensorAxis] + .asArrayRef() + .end() - + 1))) + continue; + return std::make_tuple( + sourceTensorAxis, targetTensorAxis, + sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back()); + } + } + return std::nullopt; +} + +static MeshShardingAttr +targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, + int64_t sourceTensorAxis, + int64_t targetTensorAxis) { + SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes = + llvm::to_vector(sourceSharding.getSplitAxes()); + while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= + targetTensorAxis) { + targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {})); + } + + auto sourceSplitAxes = + llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); + assert(!sourceSplitAxes.empty()); + auto meshAxis = sourceSplitAxes.back(); + sourceSplitAxes.pop_back(); + targetShardingSplitAxes[sourceTensorAxis] = + DenseI32ArrayAttr::get(ctx, sourceSplitAxes); + + auto targetSplitAxes = + llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); + targetSplitAxes.push_back(meshAxis); + targetShardingSplitAxes[targetTensorAxis] = + DenseI32ArrayAttr::get(ctx, targetSplitAxes); + + return MeshShardingAttr::get( + ctx, sourceSharding.getCluster(), targetShardingSplitAxes, + sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); +} + +static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, + int64_t splitCount, + int64_t sourceTensorAxis, + int64_t targetTensorAxis) { + SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); + targetShape[sourceTensorAxis] = + unshardDimension(targetShape[sourceTensorAxis], splitCount); + targetShape[targetTensorAxis] = + shardDimension(targetShape[targetTensorAxis], splitCount); + return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); +} + +static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> +moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + ShapedType sourceUnshardedShape, + TypedValue<ShapedType> sourceShard, + int64_t sourceTensorAxis, + int64_t targetTensorAxis, MeshAxis meshAxis) { + MLIRContext *ctx = builder.getContext(); + builder.setInsertionPointAfterValue(sourceShard); + + MeshShardingAttr targetSharding = targetShardingInMoveLastAxis( + ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); + ShapedType allToAllResultShape = 
allToAllResultShapeInMoveLastAxis( + sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis], + sourceTensorAxis, targetTensorAxis); + Value allToAllResult = builder.create<AllToAllOp>( + RankedTensorType::get(allToAllResultShape.getShape(), + allToAllResultShape.getElementType()), + mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard, + APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); + ShapedType targetShape = + shardShapedType(sourceUnshardedShape, mesh, targetSharding); + TypedValue<ShapedType> targetShard = + builder.create<tensor::CastOp>(targetShape, allToAllResult) + .getResult() + .cast<TypedValue<ShapedType>>(); + return {targetShard, targetSharding}; +} + +static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> +tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + ShapedType sourceUnshardedShape, + TypedValue<ShapedType> sourceShard) { + if (auto detectRes = + detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { + auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); + return moveLastSplitAxisInResharding( + builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, + sourceTensorAxis, targetTensorAxis, meshAxis); + } + + return std::nullopt; +} + +// Handles only resharding on a 1D mesh. +// Currently the sharded tensor axes must be exactly divisible by the single +// mesh axis size. +static TypedValue<ShapedType> +reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + TypedValue<ShapedType> sourceUnshardedValue, + TypedValue<ShapedType> sourceShard) { + assert(sourceShard.getType() == + shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); + [[maybe_unused]] ShapedType targetShardType = + shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); + assert(sourceShard.getType().getRank() == targetShardType.getRank()); + assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); + + auto [reducedSourceShard, reducedSourceSharding] = + handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding, + sourceShard); + + if (reducedSourceSharding == targetSharding) { + return reducedSourceShard; + } + + TypedValue<ShapedType> targetShard; + MeshShardingAttr actualTargetSharding; + if (auto tryRes = tryMoveLastSplitAxisInResharding( + builder, mesh, reducedSourceSharding, targetSharding, + sourceUnshardedValue.getType(), reducedSourceShard)) { + std::tie(targetShard, actualTargetSharding) = tryRes.value(); + } else if (auto tryRes = trySplitLastAxisInResharding( + builder, mesh, reducedSourceSharding, targetSharding, + reducedSourceShard)) { + std::tie(targetShard, actualTargetSharding) = tryRes.value(); + } else if (auto tryRes = tryUnsplitLastAxisInResharding( + builder, mesh, reducedSourceSharding, targetSharding, + sourceUnshardedValue.getType(), reducedSourceShard)) { + std::tie(targetShard, actualTargetSharding) = tryRes.value(); + } else { + assert(false && "Did not find any pattern to apply."); + } + + assert(actualTargetSharding == targetSharding); + assert(targetShard.getType() == targetShardType); + return targetShard; +} + +TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh, + MeshShardingAttr sourceSharding, + MeshShardingAttr targetSharding, + TypedValue<ShapedType> sourceUnshardedValue, + TypedValue<ShapedType> sourceShard) { 
+ // Resort to handling only 1D meshes since the general case is complicated if + // it needs to be communication efficient in terms of minimizing the data + // transfered between devices. + return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, + sourceUnshardedValue, sourceShard); +} + +TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh, + ShardOp source, ShardOp target, + TypedValue<ShapedType> sourceShardValue) { + assert(!source.getAnnotateForUsers()); + assert(target.getAnnotateForUsers()); + assert(source.getResult() == target.getOperand()); + ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); + return reshard( + implicitLocOpBuilder, mesh, source.getShard(), target.getShard(), + source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue); +} + +void reshardingRegisterDependentDialects(DialectRegistry &registry) { + registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect, + cf::ControlFlowDialect>(); +} + +} // namespace mesh +} // namespace mlir diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index 8af3b69..87a37a7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -448,6 +448,23 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) { return false; } +/// Test for 2:4 matrix with suitable metadata. +static bool isAdmissible24(SparseTensorType &aTp) { + return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && + aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp); +} + +/// Test for conversion into 2:4 matrix. +static bool isConversionInto24(Value v) { + if (auto cnv = v.getDefiningOp<ConvertOp>()) { + Value a = cnv.getResult(); + Value d = cnv.getSource(); + SparseTensorType aTp = getSparseTensorType(a); + return isDenseTensor(d) && isAdmissible24(aTp); + } + return false; +} + /// Returns a suitable sparse format for the operation and given operand /// types with cuSparse, or kNone if none is available. static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, @@ -925,6 +942,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, Value C = op.getOperand(2); // we have C = AB SmallVector<Value> tokens; + // The cuSparselt API currently only allows pruning and compression + // to occur on the device. So we recognize the pattern + // A' = convert A ; dense to 2:4 + // C = A'B ; 2:4 matrix mult + // and then perform compression and matrix multiplication on device. + auto cnv = A.getDefiningOp<ConvertOp>(); + assert(cnv); + A = cnv.getSource(); + // All input should be dense tensors.
if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) return failure(); @@ -1260,7 +1286,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> { maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1))) return rewriteSpGEMM(rewriter, op, enableRT); - if (op->getAttr("DENSE24")) + if (isConversionInto24(op.getOperand(0))) return rewrite2To4SpMM(rewriter, op); return rewriteSpMM(rewriter, op, enableRT); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 934e1e5..35eb4b4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, // Sparsifier synthesis methods (loop sequence). //===----------------------------------------------------------------------===// -/// Starts a loop sequence at given level. Returns true if -/// the universal loop index must be maintained at this level. -static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, - LoopId curr, LatSetId lts) { - assert(!env.getLoopVar(curr)); - // Emit invariants at this loop sequence level. - genInvariants(env, builder, exp, curr, /*isStart=*/true); - // Emit access pattern expansion for sparse tensor output. - genExpand(env, builder, curr, /*isStart=*/true); - // Emit further intitialization at this loop sequence level. - const LatPointId l0 = env.set(lts)[0]; - bool needsUniv = false; - - SmallVector<TensorLevel> tidLvls; - env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid, - std::optional<Level> lvl, - LevelType lt, bool isIdxReduc) { - assert(env.merger().loop(b) == curr); - if (isDenseLT(lt) || isUndefLT(lt)) { - if (tid == env.merger().getSynTensorID()) { - // Needs loop emitter to set up loop bounds for synthetic tensor too if - // there is a loop condition imposed on the synthetic tensor. - tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth())); - } - needsUniv = true; - } - if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || - is2OutOf4LT(lt) || isIdxReduc) { - // Only when this is a index reduction loop, can the lt be undefined. - assert(!isUndefLT(lt) || isIdxReduc); - // sparse/singleton levels, or a dense/sparse index reduction loop. - tidLvls.push_back(env.makeTensorLevel(tid, *lvl)); - } - }); - - env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls); - - // Maintain the universal index only if it is actually - // consumed by a subsequent lattice point. - if (needsUniv) { - for (const LatPointId li : env.set(lts).drop_front()) - if (!env.merger().hasAnySparse(env.lat(li).simple)) - return true; - } - return false; -} - -// Generates dense affine address for encoding. -static void genConstantDenseAddressFromLevel(CodegenEnv &env, - OpBuilder &builder, TensorId tid, - Level startLvl) { - // TODO: Handle affine expression on output tensor. 
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@
                                            LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
-        tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+        callback(env.makeTensorLevel(tid, *lvl), nullptr);
         numloopCond++;
         return;
       }
@@ -1072,10 +987,10 @@
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
       numloopCond++;
     } else if (isDenseLT(lt) || isIdxReduc) {
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
     } else {
       assert(isUndefLT(lt));
       linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@
           // level. We need to generate the address according to the
           // affine expression. This is also the best place we can do it
           // to avoid putting it inside inner loops.
-          affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+          callback(env.makeTensorLevel(tid, l), exp);
         }
       }
     }
@@ -1120,15 +1035,14 @@
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }

   if (numloopCond == 0) {
    // Corner cases where the loop bound is defined by an *unused* operand; in
    // this case, we just generate a dense "fake" loop by iterating over the
    // synthetic tensor.
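    // (Hypothetical illustration: a linalg.generic that indexes an input via
    //  (d0) -> (d0) but never reads the value in its body still needs a loop
    //  on d0; since no stored level supplies a condition for it, the
    //  synthetic tensor provides the iteration space.)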
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need one loop condition and the condition is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
   return numloopCond == 1 && !hasNonUnique;
 }

+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
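 /// A reading of the surrounding code rather than a documented contract: the
 /// returned Operation* appears to be the emitted loop op, and the bool to
 /// mirror translateBitsToTidLvlPairs above, i.e. whether a single, unique
 /// loop condition permits a plain for loop instead of a while loop.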
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c..cfd838e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   }
 };

+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+                       Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+                                                    operand, reassociation);
+  }
+
+  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override {
+    if (!unpackOp.getOuterDimsPerm().empty()) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expects no outer_dims_perm");
+    }
+
+    RankedTensorType sourceType = unpackOp.getSourceType();
+    RankedTensorType destType = unpackOp.getDestType();
+    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          unpackOp, "expects unpacking at the innermost dimension");
+    }
+
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value collapsed = insertCollapse(
+        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(unpackOp, collapsed);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
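 /// For instance (an illustrative sketch; names and the elided pieces are
 /// hypothetical, not from this patch):
 ///   %p = tensor.pad %src low[0, 0] high[2, 2] { tensor.yield %cst : f32 } ...
 ///   %r = tensor.pack %p padding_value(%cst : f32) ...
 /// folds to a single tensor.pack of %src, because the pack pads with the same
 /// %cst and the pad has zero low padding.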
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
 }

 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+      patterns.getContext());
 }

 } // namespace tensor
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e42..aa4694c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>

@@ -1975,6 +1976,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
 }

 //===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  size_t numAssociations =
+      llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+          .Case([&](TransformHandleTypeInterface opHandle) {
+            return llvm::range_size(state.getPayloadOps(getHandle()));
+          })
+          .Case([&](TransformValueHandleTypeInterface valueHandle) {
+            return llvm::range_size(state.getPayloadValues(getHandle()));
+          })
+          .Case([&](TransformParamTypeInterface param) {
+            return llvm::range_size(state.getParams(getHandle()));
+          })
+          .Default([](Type) {
+            llvm_unreachable("unknown kind of transform dialect type");
+            return 0;
+          });
+  results.setParams(getNum().cast<OpResult>(),
+                    rewriter.getI64IntegerAttr(numAssociations));
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NumAssociationsOp::verify() {
+  // Verify that the result type accepts an i64 attribute as payload.
+  auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
+  return resultType
+      .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
+      .checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 2ad992a..c1c0f54 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -271,7 +271,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;

   // Cond 1: A contiguous memref will always have a unit trailing stride.
-  if (strides.back() != 1)
+  if (strides.empty() || strides.back() != 1)
     return false;

   // Cond 2: Strides of a contiguous memref have to match the flattened dims.
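   // (The `strides.empty()` guard added in Cond 1 covers 0-D memrefs such as
   //  memref<f32>, whose stride list from getStridesAndOffset is empty; it
   //  keeps strides.back() from being called on an empty container.)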
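A usage sketch for the new transform.num_associations op above; the assembly
format is inferred from the accessors shown and is hypothetical, not quoted
from this patch:

  %n = transform.num_associations %handle
       : (!transform.any_op) -> !transform.param<i64>

Per the TypeSwitch in apply(), counting works the same way for value handles
and param handles, with the count returned as an i64 param.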