aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArtem Kroviakov <71938912+akroviakov@users.noreply.github.com>2024-05-28 14:54:37 +0200
committerGitHub <noreply@github.com>2024-05-28 14:54:37 +0200
commit01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf (patch)
treeeff853ca054be796f115441a47734b6b1854a97a
parent1da52caf2946e56f69eae75a60088a54edda1db5 (diff)
downloadllvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.zip
llvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.tar.gz
llvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.tar.bz2
[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize (#92370)
Building on top of [#88204](https://github.com/llvm/llvm-project/pull/88204), this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp97
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir29
2 files changed, 125 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 802a64b..156bf74 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -44,6 +44,19 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
return true;
}
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+ VectorType vecType = dyn_cast<VectorType>(t);
+ // Reject index since getElementTypeBitWidth will abort for Index types.
+ if (!vecType || vecType.getElementType().isIndex())
+ return false;
+ // There are no dimension to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
+ unsigned trailingVecDimBitWidth =
+ vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ return trailingVecDimBitWidth <= targetBitWidth;
+}
+
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
@@ -358,6 +371,88 @@ struct LinearizeVectorExtract final
private:
unsigned targetVectorBitWidth;
};
+
+/// This pattern converts the InsertOp to a ShuffleOp that works on a
+/// linearized vector.
+/// Following,
+/// vector.insert %source %destination [ position ]
+/// is converted to :
+/// %source_1d = vector.shape_cast %source
+/// %destination_1d = vector.shape_cast %destination
+/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
+/// ] %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original insert.
+struct LinearizeVectorInsert final
+ : public OpConversionPattern<vector::InsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorInsert(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
+ assert(!(insertOp.getDestVectorType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable()) &&
+ "scalable vectors are not supported.");
+
+ if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+ targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ insertOp, "Can't flatten since targetBitWidth < OpSize");
+
+ // dynamic position is not supported
+ if (insertOp.hasDynamicPosition())
+ return rewriter.notifyMatchFailure(insertOp,
+ "dynamic position is not supported.");
+ auto srcTy = insertOp.getSourceType();
+ auto srcAsVec = dyn_cast<VectorType>(srcTy);
+ uint64_t srcSize = 0;
+ if (srcAsVec) {
+ srcSize = srcAsVec.getNumElements();
+ } else {
+ return rewriter.notifyMatchFailure(insertOp,
+ "scalars are not supported.");
+ }
+
+ auto dstShape = insertOp.getDestVectorType().getShape();
+ const auto dstSize = insertOp.getDestVectorType().getNumElements();
+ auto dstSizeForOffsets = dstSize;
+
+ // compute linearized offset
+ int64_t linearizedOffset = 0;
+ auto offsetsNd = insertOp.getStaticPosition();
+ for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
+ dstSizeForOffsets /= dstShape[dim];
+ linearizedOffset += offset * dstSizeForOffsets;
+ }
+
+ llvm::SmallVector<int64_t, 2> indices(dstSize);
+ auto origValsUntil = indices.begin();
+ std::advance(origValsUntil, linearizedOffset);
+ std::iota(indices.begin(), origValsUntil,
+ 0); // original values that remain [0, offset)
+ auto newValsUntil = origValsUntil;
+ std::advance(newValsUntil, srcSize);
+ std::iota(origValsUntil, newValsUntil,
+ dstSize); // new values [offset, offset+srcNumElements)
+ std::iota(newValsUntil, indices.end(),
+ linearizedOffset + srcSize); // the rest of original values
+ // [offset+srcNumElements, end)
+
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+ insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
+ rewriter.getI64ArrayAttr(indices));
+
+ return success();
+ }
+
+private:
+ unsigned targetVectorBitWidth;
+};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -410,6 +505,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
: true;
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorExtractStridedSlice>(
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index b29ceab..31a59b8 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
return %0 : vector<8x2xf32>
}
+
+// -----
+// ALL-LABEL: test_vector_insert
+// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+ // DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+ // DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+ // DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ // DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+ // DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+ // DEFAULT: return %[[RES]] : vector<2x8x4xf32>
+
+ // BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+ // BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+ // BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ // BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+ // BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+ // BW-128: return %[[RES]] : vector<2x8x4xf32>
+
+ // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
+ // BW-0: return %[[RES]] : vector<2x8x4xf32>
+
+ %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+ return %0 : vector<2x8x4xf32>
+}