diff options
Diffstat (limited to 'mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp')
-rw-r--r-- | mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 240491a..ba448e4 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> { } // namespace +static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) { + // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the + // llvm.loop_annotation attribute. + // LLVM requires the loop metadata to be attached on the "latch" block. Which + // is the back-edge to the header block (conditionBlock) + SmallVector<NamedAttribute> llvmAttrs; + llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs), + [](auto attr) { + return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); + }); + brOp->setDiscardableAttrs(llvmAttrs); +} + LogicalResult ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); @@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, auto branchOp = cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); - // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the - // llvm.loop_annotation attribute. - // LLVM requires the loop metadata to be attached on the "latch" block. Which - // is the back-edge to the header block (conditionBlock) - SmallVector<NamedAttribute> llvmAttrs; - llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs), - [](auto attr) { - return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); - }); - branchOp->setDiscardableAttrs(llvmAttrs); - + propagateLoopAttrs(forOp, branchOp); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. @@ -582,18 +585,20 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); rewriter.setInsertionPointToEnd(after); auto yieldOp = cast<scf::YieldOp>(after->getTerminator()); - rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, - yieldOp.getResults()); + auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, + yieldOp.getResults()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. - rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } @@ -630,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); - cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(), - before, condOp.getArgs(), continuation, - ValueRange()); + auto latch = cf::CondBranchOp::create( + rewriter, condOp.getLoc(), condOp.getCondition(), before, + condOp.getArgs(), continuation, ValueRange()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, condOp.getArgs()); |