diff options
author | Matthias Springer <springerm@google.com> | 2022-07-09 09:15:36 +0200 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2022-07-09 09:16:52 +0200 |
commit | fc9b37dd532dc68018c0c5947030b34ebcf68d14 (patch) | |
tree | eba5666af3f91b45424df6933c511a4fe3c54c2a | |
parent | e1272ab6ec8ddb56b87cf0d67bb385c5c850e677 (diff) | |
download | llvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.zip llvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.tar.gz llvm-fc9b37dd532dc68018c0c5947030b34ebcf68d14.tar.bz2 |
[mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))
This is a partial revert of D128615.
to_memref(to_tensor(x)) always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed.
Differential Revision: https://reviews.llvm.org/D129354
-rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp | 16 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/canonicalize.mlir | 3 | ||||
-rw-r--r-- | mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir | 3 |
3 files changed, 5 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 35f6f1b..4ab904e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) { } namespace { -/// Canonicalize bufferization.to_tensor + bufferization.to_memref. -struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> { - using OpRewritePattern<ToTensorOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ToTensorOp toTensorOp, - PatternRewriter &rewriter) const final { - auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>(); - if (!toMemrefOp) - return failure(); - rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor()); - return success(); - } -}; - struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> { using OpRewritePattern<tensor::DimOp>::OpRewritePattern; @@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> { void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context); + results.add<DimOfToTensorFolder>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 535a007..8e087fc 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>, } // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32> - // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> + // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32> + // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> } diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir index f24048e..df55b83 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir @@ -109,7 +109,8 @@ // CHECK: scf.yield %[[VAL_84]] : f64 // CHECK: } // CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64> -// CHECK: return %[[VAL_0]] : tensor<f64> +// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64> +// CHECK: return %[[VAL_87]] : tensor<f64> // CHECK: } func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true}, %arga: tensor<64x32xf64, #SparseMatrix>, |