diff options
author | Kunwar Grover <groverkss@gmail.com> | 2024-10-29 22:47:44 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-29 22:47:44 +0000 |
commit | 2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340 (patch) | |
tree | a57a57d3f472a9f6455d0f59af814f88c4e629af | |
parent | ca998b071eba1c92bf8535964183c7c4c3b258c3 (diff) | |
download | llvm-2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340.zip llvm-2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340.tar.gz llvm-2c5eea0e88a6ef6bf932d90c67aaec2bcc59d340.tar.bz2 |
[mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts (#113828)
The current vector.insert folder tries to replace a scalar with a 0-rank
vector. This patch fixes this crash by not folding unless they types of
the result and replacement are same.
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 48 |
2 files changed, 41 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d71a236..1853ae0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, InsertOpConstantFolder>(context); } -// Eliminates insert operations that produce values identical to their source -// value. This happens when the source and destination vectors have identical -// sizes. OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { - if (getNumIndices() == 0) + // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to + // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>" + // (type mismatch). + if (getNumIndices() == 0 && getSourceType() == getType()) return getSource(); return {}; } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 6d6bc19..c963460 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -800,6 +800,43 @@ func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vecto // ----- +// CHECK-LABEL: func @extract_no_fold_scalar_to_0d( +// CHECK-SAME: %[[v:.*]]: vector<f32>) +// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32> +// CHECK: return %[[extract]] +func.func @extract_no_fold_scalar_to_0d(%v: vector<f32>) -> f32 { + %0 = vector.extract %v[] : f32 from vector<f32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func @insert_fold_same_rank( +// CHECK-SAME: %[[v:.*]]: vector<2x2xf32>) +// CHECK: %[[CST:.+]] = arith.constant +// CHECK-SAME: : vector<2x2xf32> +// CHECK-NOT: vector.insert +// CHECK: return %[[CST]] +func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<2x2xf32> + %0 = vector.insert %cst, %v [] : vector<2x2xf32> into vector<2x2xf32> + return %0 : vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @insert_no_fold_scalar_to_0d( +// CHECK-SAME: %[[v:.*]]: vector<f32>) +// CHECK: %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32> +// CHECK: return %[[extract]] +func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.insert %cst, %v [] : f32 into vector<f32> + return %0 : vector<f32> +} + +// ----- + // CHECK-LABEL: dont_fold_expand_collapse // CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32> // CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32> @@ -2606,17 +2643,6 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3 // ----- -// CHECK-LABEL: func @extract_from_0d_regression( -// CHECK-SAME: %[[v:.*]]: vector<f32>) -// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32> -// CHECK: return %[[extract]] -func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 { - %0 = vector.extract %v[] : f32 from vector<f32> - return %0 : f32 -} - -// ----- - // CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( // CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>) func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { |