//===-- FIRToSCF.cpp ------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace fir { #define GEN_PASS_DEF_FIRTOSCFPASS #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir namespace { class FIRToSCFPass : public fir::impl::FIRToSCFPassBase { public: void runOnOperation() override; }; struct DoLoopConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = doLoopOp.getLoc(); bool hasFinalValue = doLoopOp.getFinalValue().has_value(); // Get loop values from the DoLoopOp mlir::Value low = doLoopOp.getLowerBound(); mlir::Value high = doLoopOp.getUpperBound(); assert(low && high && "must be a Value"); mlir::Value step = doLoopOp.getStep(); mlir::SmallVector iterArgs; if (hasFinalValue) iterArgs.push_back(low); iterArgs.append(doLoopOp.getIterOperands().begin(), doLoopOp.getIterOperands().end()); // fir.do_loop iterates over the interval [%l, %u], and the step may be // negative. But scf.for iterates over the interval [%l, %u), and the step // must be a positive value. // For easier conversion, we calculate the trip count and use a canonical // induction variable. auto diff = mlir::arith::SubIOp::create(rewriter, loc, high, low); auto distance = mlir::arith::AddIOp::create(rewriter, loc, diff, step); auto tripCount = mlir::arith::DivSIOp::create(rewriter, loc, distance, step); auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); auto scfForOp = mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); auto &loopOps = doLoopOp.getBody()->getOperations(); auto resultOp = mlir::cast(doLoopOp.getBody()->getTerminator()); auto results = resultOp.getOperands(); mlir::Block *loweredBody = scfForOp.getBody(); loweredBody->getOperations().splice(loweredBody->begin(), loopOps, loopOps.begin(), std::prev(loopOps.end())); rewriter.setInsertionPointToStart(loweredBody); mlir::Value iv = mlir::arith::MulIOp::create( rewriter, loc, scfForOp.getInductionVar(), step); iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv); if (!results.empty()) { rewriter.setInsertionPointToEnd(loweredBody); mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results); } doLoopOp.getInductionVar().replaceAllUsesWith(iv); rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(), hasFinalValue ? scfForOp.getRegionIterArgs().drop_front() : scfForOp.getRegionIterArgs()); // Copy all the attributes from the old to new op. scfForOp->setAttrs(doLoopOp->getAttrs()); rewriter.replaceOp(doLoopOp, scfForOp); return mlir::success(); } }; struct IterWhileConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::IterWhileOp iterWhileOp, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = iterWhileOp.getLoc(); mlir::Value lowerBound = iterWhileOp.getLowerBound(); mlir::Value upperBound = iterWhileOp.getUpperBound(); mlir::Value step = iterWhileOp.getStep(); mlir::Value okInit = iterWhileOp.getIterateIn(); mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); mlir::SmallVector initVals; initVals.push_back(lowerBound); initVals.push_back(okInit); initVals.append(iterArgs.begin(), iterArgs.end()); mlir::SmallVector loopTypes; loopTypes.push_back(lowerBound.getType()); loopTypes.push_back(okInit.getType()); for (auto val : iterArgs) loopTypes.push_back(val.getType()); auto scfWhileOp = mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals); auto &beforeBlock = *rewriter.createBlock( &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes, mlir::SmallVector(loopTypes.size(), loc)); mlir::Region::BlockArgListType argsInBefore = scfWhileOp.getBefore().getArguments(); auto ivInBefore = argsInBefore[0]; auto earlyExitInBefore = argsInBefore[1]; rewriter.setInsertionPointToStart(&beforeBlock); mlir::Value inductionCmp = mlir::arith::CmpIOp::create( rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound); mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, earlyExitInBefore); mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore); rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(), scfWhileOp.getAfter().begin()); auto *afterBody = scfWhileOp.getAfterBody(); auto resultOp = mlir::cast(afterBody->getTerminator()); mlir::SmallVector results(resultOp->getOperands()); mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; rewriter.setInsertionPointToStart(afterBody); results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); rewriter.setInsertionPointToEnd(afterBody); rewriter.replaceOpWithNewOp(resultOp, results); scfWhileOp->setAttrs(iterWhileOp->getAttrs()); rewriter.replaceOp(iterWhileOp, scfWhileOp); return mlir::success(); } }; void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter, mlir::Block &srcBlock, mlir::Block &dstBlock) { mlir::Operation *srcTerminator = srcBlock.getTerminator(); auto resultOp = mlir::cast(srcTerminator); dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(), srcBlock.begin(), std::prev(srcBlock.end())); if (!resultOp->getOperands().empty()) { rewriter.setInsertionPointToEnd(&dstBlock); mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands()); } rewriter.eraseOp(srcTerminator); } struct IfConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::IfOp ifOp, mlir::PatternRewriter &rewriter) const override { bool hasElse = !ifOp.getElseRegion().empty(); auto scfIfOp = mlir::scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), ifOp.getCondition(), hasElse); copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(), scfIfOp.getThenRegion().front()); if (hasElse) { copyBlockAndTransformResult(rewriter, ifOp.getElseRegion().front(), scfIfOp.getElseRegion().front()); } scfIfOp->setAttrs(ifOp->getAttrs()); rewriter.replaceOp(ifOp, scfIfOp); return mlir::success(); } }; } // namespace void FIRToSCFPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); patterns.add( patterns.getContext()); walkAndApplyPatterns(getOperation(), std::move(patterns)); } std::unique_ptr fir::createFIRToSCFPass() { return std::make_unique(); }