diff options
author | Artem Kroviakov <71938912+akroviakov@users.noreply.github.com> | 2024-06-21 16:36:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-21 09:36:56 -0500 |
commit | 74a105ad80725b4a54ef76950c938ffe76796b3b (patch) | |
tree | ff61fcb697edbaf8178dc0650e575467e9909d57 /mlir | |
parent | db8c7e004a8acf74f40e0f7bc60066f26d43ccd9 (diff) | |
download | llvm-74a105ad80725b4a54ef76950c938ffe76796b3b.zip llvm-74a105ad80725b4a54ef76950c938ffe76796b3b.tar.gz llvm-74a105ad80725b4a54ef76950c938ffe76796b3b.tar.bz2 |
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize (#93590)
As it was [suggested](https://github.com/llvm/llvm-project/pull/92370#discussion_r1617592942), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 35 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/linearize.mlir | 32 |
2 files changed, 54 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 156bf74..a1bb81e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type dstType = getTypeConverter()->convertType(extractOp.getType()); - assert(!(extractOp.getVector().getType().isScalable() || - cast<VectorType>(dstType).isScalable()) && - "scalable vectors are not supported."); + VectorType dstType = + getTypeConverter()->convertType<VectorType>(extractOp.getType()); + assert(dstType && "vector type destination expected."); + if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) + return rewriter.notifyMatchFailure(extractOp, + "scalable vectors are not supported."); if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type dstType = getTypeConverter()->convertType(shuffleOp.getType()); + VectorType dstType = + getTypeConverter()->convertType<VectorType>(shuffleOp.getType()); + assert(dstType && "vector type destination expected."); + // The assert is used because vector.shuffle does not support scalable + // vectors. assert(!(shuffleOp.getV1VectorType().isScalable() || shuffleOp.getV2VectorType().isScalable() || - cast<VectorType>(dstType).isScalable()) && + dstType.isScalable()) && "scalable vectors are not supported."); if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) return rewriter.notifyMatchFailure( @@ -336,9 +342,10 @@ struct LinearizeVectorExtract final matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstTy = getTypeConverter()->convertType(extractOp.getType()); - assert(!(extractOp.getVector().getType().isScalable() || - cast<VectorType>(dstTy).isScalable()) && - "scalable vectors are not supported."); + if (extractOp.getVector().getType().isScalable() || + cast<VectorType>(dstTy).isScalable()) + return rewriter.notifyMatchFailure(extractOp, + "scalable vectors are not supported."); if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) return rewriter.notifyMatchFailure( extractOp, "Can't flatten since targetBitWidth <= OpSize"); @@ -394,10 +401,12 @@ struct LinearizeVectorInsert final LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType()); - assert(!(insertOp.getDestVectorType().isScalable() || - cast<VectorType>(dstTy).isScalable()) && - "scalable vectors are not supported."); + VectorType dstTy = getTypeConverter()->convertType<VectorType>( + insertOp.getDestVectorType()); + assert(dstTy && "vector type destination expected."); + if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable()) + return rewriter.notifyMatchFailure(insertOp, + "scalable vectors are not supported."); if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), targetVectorBitWidth)) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index ec3806c..916e3e5 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -177,6 +177,17 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf return %0 : vector<2x2xf32> } +// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable( +// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { +func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { + // ALL-NOT: vector.shuffle + // ALL-NOT: vector.shape_cast + // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32> + %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32> + // ALL: return %[[RES]] : vector<2x[8]xf32> + return %0 : vector<2x[8]xf32> +} + // ----- // ALL-LABEL: test_extract_strided_slice_2 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> { @@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { return %0 : vector<8x2xf32> } +// ALL-LABEL: func.func @test_vector_extract_scalable( +// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { +func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { + // ALL-NOT: vector.shuffle + // ALL-NOT: vector.shape_cast + // ALL: %[[RES:.*]] = vector.extract %[[VAL_0]][1] : vector<8x[2]xf32> from vector<2x8x[2]xf32> + %0 = vector.extract %arg0[1]: vector<8x[2]xf32> from vector<2x8x[2]xf32> + // ALL: return %[[RES]] : vector<8x[2]xf32> + return %0 : vector<8x[2]xf32> +} // ----- // ALL-LABEL: test_vector_insert // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { @@ -274,3 +295,14 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> return %0 : vector<2x8x4xf32> } + +// ALL-LABEL: func.func @test_vector_insert_scalable( +// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { +func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { + // ALL-NOT: vector.shuffle + // ALL-NOT: vector.shape_cast + // ALL: %[[RES:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<8x[4]xf32> into vector<2x8x[4]xf32> + %0 = vector.insert %arg1, %arg0[0]: vector<8x[4]xf32> into vector<2x8x[4]xf32> + // ALL: return %[[RES]] : vector<2x8x[4]xf32> + return %0 : vector<2x8x[4]xf32> +} |