aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp37
1 files changed, 21 insertions, 16 deletions
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 807be7e..ba448e4 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
} // namespace
+static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
+ // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
+ // llvm.loop_annotation attribute.
+ // LLVM requires the loop metadata to be attached on the "latch" block. Which
+ // is the back-edge to the header block (conditionBlock)
+ SmallVector<NamedAttribute> llvmAttrs;
+ llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
+ [](auto attr) {
+ return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
+ });
+ brOp->setDiscardableAttrs(llvmAttrs);
+}
+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
auto branchOp =
cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
- // llvm.loop_annotation attribute.
- // LLVM requires the loop metadata to be attached on the "latch" block. Which
- // is the back-edge to the header block (conditionBlock)
- SmallVector<NamedAttribute> llvmAttrs;
- llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
- [](auto attr) {
- return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
- });
- branchOp->setDiscardableAttrs(llvmAttrs);
-
+ propagateLoopAttrs(forOp, branchOp);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.setInsertionPointToEnd(after);
auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
- rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
- yieldOp.getResults());
+ auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
+ yieldOp.getResults());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ auto latch = cf::CondBranchOp::create(
+ rewriter, condOp.getLoc(), condOp.getCondition(), before,
+ condOp.getArgs(), continuation, ValueRange());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, condOp.getArgs());