aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp40
1 files changed, 23 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 240491a..ba448e4 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
} // namespace
+static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
+ // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
+ // llvm.loop_annotation attribute.
+ // LLVM requires the loop metadata to be attached on the "latch" block. Which
+ // is the back-edge to the header block (conditionBlock)
+ SmallVector<NamedAttribute> llvmAttrs;
+ llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
+ [](auto attr) {
+ return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
+ });
+ brOp->setDiscardableAttrs(llvmAttrs);
+}
+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
auto branchOp =
cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
- // llvm.loop_annotation attribute.
- // LLVM requires the loop metadata to be attached on the "latch" block. Which
- // is the back-edge to the header block (conditionBlock)
- SmallVector<NamedAttribute> llvmAttrs;
- llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
- [](auto attr) {
- return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
- });
- branchOp->setDiscardableAttrs(llvmAttrs);
-
+ propagateLoopAttrs(forOp, branchOp);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@@ -582,18 +585,20 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
+ SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
rewriter.setInsertionPointToEnd(after);
auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
- rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
- yieldOp.getResults());
+ auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
+ yieldOp.getResults());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
- rewriter.replaceOp(whileOp, condOp.getArgs());
+ rewriter.replaceOp(whileOp, args);
return success();
}
@@ -630,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ auto latch = cf::CondBranchOp::create(
+ rewriter, condOp.getLoc(), condOp.getCondition(), before,
+ condOp.getArgs(), continuation, ValueRange());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, condOp.getArgs());