diff options
author | Artem Kroviakov <71938912+akroviakov@users.noreply.github.com> | 2024-05-28 14:54:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-28 14:54:37 +0200 |
commit | 01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf (patch) | |
tree | eff853ca054be796f115441a47734b6b1854a97a | |
parent | 1da52caf2946e56f69eae75a60088a54edda1db5 (diff) | |
download | llvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.zip llvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.tar.gz llvm-01fbc5658cdfa152519e2d0842ccf7d91aaeaeaf.tar.bz2 |
[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize (#92370)
Building on top of
[#88204](https://github.com/llvm/llvm-project/pull/88204), this PR adds
support for converting `vector.insert` into an equivalent
`vector.shuffle` operation that operates on linearized (1-D) vectors.
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 97 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/linearize.mlir | 29 |
2 files changed, 125 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 802a64b..156bf74 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -44,6 +44,19 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { return true; } +static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { + VectorType vecType = dyn_cast<VectorType>(t); + // Reject index since getElementTypeBitWidth will abort for Index types. + if (!vecType || vecType.getElementType().isIndex()) + return false; + // There are no dimension to fold if it is a 0-D vector. + if (vecType.getRank() == 0) + return false; + unsigned trailingVecDimBitWidth = + vecType.getShape().back() * vecType.getElementTypeBitWidth(); + return trailingVecDimBitWidth <= targetBitWidth; +} + namespace { struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { using OpConversionPattern::OpConversionPattern; @@ -358,6 +371,88 @@ struct LinearizeVectorExtract final private: unsigned targetVectorBitWidth; }; + +/// This pattern converts the InsertOp to a ShuffleOp that works on a +/// linearized vector. +/// Following, +/// vector.insert %source %destination [ position ] +/// is converted to : +/// %source_1d = vector.shape_cast %source +/// %destination_1d = vector.shape_cast %destination +/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d +/// ] %out_nd = vector.shape_cast %out_1d +/// `shuffle_indices_1d` is computed using the position of the original insert. +struct LinearizeVectorInsert final + : public OpConversionPattern<vector::InsertOp> { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorInsert( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + LogicalResult + matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType()); + assert(!(insertOp.getDestVectorType().isScalable() || + cast<VectorType>(dstTy).isScalable()) && + "scalable vectors are not supported."); + + if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), + targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + insertOp, "Can't flatten since targetBitWidth < OpSize"); + + // dynamic position is not supported + if (insertOp.hasDynamicPosition()) + return rewriter.notifyMatchFailure(insertOp, + "dynamic position is not supported."); + auto srcTy = insertOp.getSourceType(); + auto srcAsVec = dyn_cast<VectorType>(srcTy); + uint64_t srcSize = 0; + if (srcAsVec) { + srcSize = srcAsVec.getNumElements(); + } else { + return rewriter.notifyMatchFailure(insertOp, + "scalars are not supported."); + } + + auto dstShape = insertOp.getDestVectorType().getShape(); + const auto dstSize = insertOp.getDestVectorType().getNumElements(); + auto dstSizeForOffsets = dstSize; + + // compute linearized offset + int64_t linearizedOffset = 0; + auto offsetsNd = insertOp.getStaticPosition(); + for (auto [dim, offset] : llvm::enumerate(offsetsNd)) { + dstSizeForOffsets /= dstShape[dim]; + linearizedOffset += offset * dstSizeForOffsets; + } + + llvm::SmallVector<int64_t, 2> indices(dstSize); + auto origValsUntil = indices.begin(); + std::advance(origValsUntil, linearizedOffset); + std::iota(indices.begin(), origValsUntil, + 0); // original values that remain [0, offset) + auto newValsUntil = origValsUntil; + std::advance(newValsUntil, srcSize); + std::iota(origValsUntil, newValsUntil, + dstSize); // new values [offset, offset+srcNumElements) + std::iota(newValsUntil, indices.end(), + linearizedOffset + srcSize); // the rest of original values + // [offset+srcNumElements, end) + + rewriter.replaceOpWithNewOp<vector::ShuffleOp>( + insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), + rewriter.getI64ArrayAttr(indices)); + + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( @@ -410,6 +505,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( : true; }); patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, - LinearizeVectorExtractStridedSlice>( + LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>( typeConverter, patterns.getContext(), targetBitWidth); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index b29ceab..31a59b8 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> return %0 : vector<8x2xf32> } + +// ----- +// ALL-LABEL: test_vector_insert +// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { +func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { + // DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> + // DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> + // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] + // DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + // DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + // DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> + // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> + // DEFAULT: return %[[RES]] : vector<2x8x4xf32> + + // BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> + // BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> + // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] + // BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + // BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + // BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> + // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> + // BW-128: return %[[RES]] : vector<2x8x4xf32> + + // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32> + // BW-0: return %[[RES]] : vector<2x8x4xf32> + + %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> + return %0 : vector<2x8x4xf32> +} |