aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordonald chen <chenxunyu1993@gmail.com>2024-07-15 12:52:46 +0800
committerGitHub <noreply@github.com>2024-07-15 12:52:46 +0800
commit662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553 (patch)
tree756239602d52c8c6c953a174fc771d742bbfac22
parent36984536be147ce4f53d7b05a77c8ccc79227271 (diff)
downloadllvm-662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553.zip
llvm-662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553.tar.gz
llvm-662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553.tar.bz2
[mlir] [bufferize] fix bufferize deallocation error in nest symbol table (#98476)
In nested symbols, the dealloc_helper function generated by lower deallocations pass was incorrectly positioned, causing calls fail. This patch fixes this issue.
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h5
-rw-r--r--mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp21
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp41
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir41
4 files changed, 82 insertions, 26 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c..c12ed7f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -18,6 +18,9 @@ class FuncOp;
namespace bufferization {
struct OneShotBufferizationOptions;
+/// Maps from symbol table to its corresponding dealloc helper function.
+using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -46,7 +49,7 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
/// given pattern set for use in other transformation passes.
void populateBufferizationDeallocLoweringPattern(
- RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+ RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap);
/// Construct the library function needed for the fully generic
/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f..f990307 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
return;
}
- func::FuncOp helperFuncOp;
+ bufferization::DeallocHelperMap deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
- SymbolTable symbolTable(module);
// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
- if (deallocOp.getMemrefs().size() > 1) {
- helperFuncOp = bufferization::buildDeallocationLibraryFunction(
- builder, getOperation()->getLoc(), symbolTable);
- return WalkResult::interrupt();
+ Operation *symtableOp =
+ deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (deallocOp.getMemrefs().size() > 1 &&
+ !deallocHelperFuncMap.contains(symtableOp)) {
+ SymbolTable symbolTable(symtableOp);
+ func::FuncOp helperFuncOp =
+ bufferization::buildDeallocationLibraryFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
- return WalkResult::advance();
});
}
RewritePatternSet patterns(&getContext());
patterns.add<CloneOpConversion>(patterns.getContext());
- bufferization::populateBufferizationDeallocLoweringPattern(patterns,
- helperFuncOp);
+ bufferization::populateBufferizationDeallocLoweringPattern(
+ patterns, deallocHelperFuncMap);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb4691..9e2c91b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
retainCondsMemref);
+ Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
rewriter.create<func::CallOp>(
- op.getLoc(), deallocHelperFunc,
+ op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
castedCondsMemref, castedDeallocCondsMemref,
castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
}
public:
- DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+ DeallocOpConversion(
+ MLIRContext *context,
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap)
: OpConversionPattern<bufferization::DeallocOp>(context),
- deallocHelperFunc(deallocHelperFunc) {}
+ deallocHelperFuncMap(deallocHelperFuncMap) {}
LogicalResult
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ public:
if (adaptor.getMemrefs().size() == 1)
return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
- if (!deallocHelperFunc)
+ Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!deallocHelperFuncMap.contains(symtableOp))
return op->emitError(
"library function required for generic lowering, but cannot be "
"automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ public:
}
private:
- func::FuncOp deallocHelperFunc;
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap;
};
} // namespace
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
return;
}
- func::FuncOp helperFuncOp;
+ bufferization::DeallocHelperMap deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
- SymbolTable symbolTable(module);
// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
- if (deallocOp.getMemrefs().size() > 1) {
- helperFuncOp = bufferization::buildDeallocationLibraryFunction(
- builder, getOperation()->getLoc(), symbolTable);
- return WalkResult::interrupt();
+ Operation *symtableOp =
+ deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (deallocOp.getMemrefs().size() > 1 &&
+ !deallocHelperFuncMap.contains(symtableOp)) {
+ SymbolTable symbolTable(symtableOp);
+ func::FuncOp helperFuncOp =
+ bufferization::buildDeallocationLibraryFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
- return WalkResult::advance();
});
}
RewritePatternSet patterns(&getContext());
- bufferization::populateBufferizationDeallocLoweringPattern(patterns,
- helperFuncOp);
+ bufferization::populateBufferizationDeallocLoweringPattern(
+ patterns, deallocHelperFuncMap);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
}
void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
- RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
- patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+ RewritePatternSet &patterns,
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
+ patterns.add<DeallocOpConversion>(patterns.getContext(),
+ deallocHelperFuncMap);
}
std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45..edffcbd 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// This test check dealloc_helper function is generated on each nested symbol
+// table operation when needed and only generated once.
+module @conversion_nest_module_dealloc_helper {
+ func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ module @nested_module_not_need_dealloc_helper {
+ func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+ return %0#0, %0#1 : i1, i1
+ }
+ }
+ module @nested_module_need_dealloc_helper {
+ func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ }
+}
+
+// CHECK: module @conversion_nest_module_dealloc_helper {
+// CHECK: func.func @top_level_func
+// CHECK: call @dealloc_helper
+// CHECK: module @nested_module_not_need_dealloc_helper {
+// CHECK: func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT: @dealloc_helper
+// CHECK: module @nested_module_need_dealloc_helper {
+// CHECK: func.func @nested_module_need_dealloc_helper_func0
+// CHECK: call @dealloc_helper
+// CHECK: func.func @nested_module_need_dealloc_helper_func1
+// CHECK: call @dealloc_helper
+// CHECK: func.func private @dealloc_helper
+// CHECK: func.func private @dealloc_helper