aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2022-07-09 09:15:36 +0200
committerMatthias Springer <springerm@google.com>2022-07-09 09:16:52 +0200
commitfc9b37dd532dc68018c0c5947030b34ebcf68d14 (patch)
treeeba5666af3f91b45424df6933c511a4fe3c54c2a
parente1272ab6ec8ddb56b87cf0d67bb385c5c850e677 (diff)
downloadllvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.zip
llvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.tar.gz
llvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.tar.bz2
[mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))
This is a partial revert of D128615. to_memref(to_tensor(x)) always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed. Differential Revision: https://reviews.llvm.org/D129354
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp16
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir3
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir3
3 files changed, 5 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 35f6f1b..4ab904e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
}
namespace {
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
-struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
- using OpRewritePattern<ToTensorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
- PatternRewriter &rewriter) const final {
- auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
- if (!toMemrefOp)
- return failure();
- rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
- return success();
- }
-};
-
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
+ results.add<DimOfToTensorFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 535a007..8e087fc 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
}
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
- // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+ // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
+ // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f24048e..df55b83 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,7 +109,8 @@
// CHECK: scf.yield %[[VAL_84]] : f64
// CHECK: }
// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: return %[[VAL_0]] : tensor<f64>
+// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
+// CHECK: return %[[VAL_87]] : tensor<f64>
// CHECK: }
func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
%arga: tensor<64x32xf64, #SparseMatrix>,