aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorArtem Kroviakov <71938912+akroviakov@users.noreply.github.com>2024-06-21 16:36:56 +0200
committerGitHub <noreply@github.com>2024-06-21 09:36:56 -0500
commit74a105ad80725b4a54ef76950c938ffe76796b3b (patch)
treeff61fcb697edbaf8178dc0650e575467e9909d57 /mlir
parentdb8c7e004a8acf74f40e0f7bc60066f26d43ccd9 (diff)
downloadllvm-74a105ad80725b4a54ef76950c938ffe76796b3b.zip
llvm-74a105ad80725b4a54ef76950c938ffe76796b3b.tar.gz
llvm-74a105ad80725b4a54ef76950c938ffe76796b3b.tar.bz2
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize (#93590)
As it was [suggested](https://github.com/llvm/llvm-project/pull/92370#discussion_r1617592942), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp35
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir32
2 files changed, 54 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf74..a1bb81e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type dstType = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstType).isScalable()) &&
- "scalable vectors are not supported.");
+ VectorType dstType =
+ getTypeConverter()->convertType<VectorType>(extractOp.getType());
+ assert(dstType && "vector type destination expected.");
+ if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
+ VectorType dstType =
+ getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
+ assert(dstType && "vector type destination expected.");
+ // The assert is used because vector.shuffle does not support scalable
+ // vectors.
assert(!(shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
- cast<VectorType>(dstType).isScalable()) &&
+ dstType.isScalable()) &&
"scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
@@ -336,9 +342,10 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ if (extractOp.getVector().getType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -394,10 +401,12 @@ struct LinearizeVectorInsert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
- assert(!(insertOp.getDestVectorType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ VectorType dstTy = getTypeConverter()->convertType<VectorType>(
+ insertOp.getDestVectorType());
+ assert(dstTy && "vector type destination expected.");
+ if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
+ return rewriter.notifyMatchFailure(insertOp,
+ "scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
targetVectorBitWidth))
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index ec3806c..916e3e5 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -177,6 +177,17 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
return %0 : vector<2x2xf32>
}
+// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
+// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+ // ALL-NOT: vector.shuffle
+ // ALL-NOT: vector.shape_cast
+ // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+ %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
+ // ALL: return %[[RES]] : vector<2x[8]xf32>
+ return %0 : vector<2x[8]xf32>
+}
+
// -----
// ALL-LABEL: test_extract_strided_slice_2
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
return %0 : vector<8x2xf32>
}
+// ALL-LABEL: func.func @test_vector_extract_scalable(
+// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
+func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
+ // ALL-NOT: vector.shuffle
+ // ALL-NOT: vector.shape_cast
+ // ALL: %[[RES:.*]] = vector.extract %[[VAL_0]][1] : vector<8x[2]xf32> from vector<2x8x[2]xf32>
+ %0 = vector.extract %arg0[1]: vector<8x[2]xf32> from vector<2x8x[2]xf32>
+ // ALL: return %[[RES]] : vector<8x[2]xf32>
+ return %0 : vector<8x[2]xf32>
+}
// -----
// ALL-LABEL: test_vector_insert
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -274,3 +295,14 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
return %0 : vector<2x8x4xf32>
}
+
+// ALL-LABEL: func.func @test_vector_insert_scalable(
+// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
+func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
+ // ALL-NOT: vector.shuffle
+ // ALL-NOT: vector.shape_cast
+ // ALL: %[[RES:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<8x[4]xf32> into vector<2x8x[4]xf32>
+ %0 = vector.insert %arg1, %arg0[0]: vector<8x[4]xf32> into vector<2x8x[4]xf32>
+ // ALL: return %[[RES]] : vector<2x8x[4]xf32>
+ return %0 : vector<2x8x[4]xf32>
+}