//===- ConstantArgumentGlobalisation.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace fir { #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir #define DEBUG_TYPE "flang-constant-argument-globalisation-opt" namespace { unsigned uniqueLitId = 1; class CallOpRewriter : public mlir::OpRewritePattern { protected: const mlir::DominanceInfo &di; public: using OpRewritePattern::OpRewritePattern; CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di) : OpRewritePattern(ctx), di(_di) {} llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n"); auto module = callOp->getParentOfType(); bool needUpdate = false; fir::FirOpBuilder builder(rewriter, module); llvm::SmallVector newOperands; llvm::SmallVector> allocas; for (const mlir::Value &a : callOp.getArgs()) { auto alloca = mlir::dyn_cast_or_null(a.getDefiningOp()); // We can convert arguments that are alloca, and that has // the value by reference attribute. All else is just added // to the argument list. if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) { newOperands.push_back(a); continue; } mlir::Type varTy = alloca.getInType(); assert(!fir::hasDynamicSize(varTy) && "only expect statically sized scalars to be by value"); // Find immediate store with const argument mlir::Operation *store = nullptr; for (mlir::Operation *s : alloca->getUsers()) { if (mlir::isa(s) && di.dominates(s, callOp)) { // We can only deal with ONE store - if already found one, // set to nullptr and exit the loop. if (store) { store = nullptr; break; } store = s; } } // If we didn't find any store, or multiple stores, add argument as is // and move on. if (!store) { newOperands.push_back(a); continue; } LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n"); mlir::Operation *definingOp = store->getOperand(0).getDefiningOp(); // If not a constant, add to operands and move on. if (!mlir::isa(definingOp)) { // Unable to remove alloca arg newOperands.push_back(a); continue; } LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n"); std::string globalName = "_global_const_." + std::to_string(uniqueLitId++); assert(!builder.getNamedGlobal(globalName) && "We should have a unique name here"); if (llvm::none_of(allocas, [alloca](auto x) { return x.first == alloca; })) { allocas.push_back(std::make_pair(alloca, store)); } auto loc = callOp.getLoc(); fir::GlobalOp global = builder.createGlobalConstant( loc, varTy, globalName, [&](fir::FirOpBuilder &builder) { mlir::Operation *cln = definingOp->clone(); builder.insert(cln); mlir::Value val = builder.createConvert(loc, varTy, cln->getResult(0)); fir::HasValueOp::create(builder, loc, val); }, builder.createInternalLinkage()); mlir::Value addr = fir::AddrOfOp::create( builder, loc, global.resultType(), global.getSymbol()); newOperands.push_back(addr); needUpdate = true; } if (needUpdate) { auto loc = callOp.getLoc(); llvm::SmallVector newResultTypes; newResultTypes.append(callOp.getResultTypes().begin(), callOp.getResultTypes().end()); fir::CallOp newOp = fir::CallOp::create(builder, loc, callOp.getCallee().has_value() ? callOp.getCallee().value() : mlir::SymbolRefAttr{}, newResultTypes, newOperands); // Copy all the attributes from the old to new op. newOp->setAttrs(callOp->getAttrs()); rewriter.replaceOp(callOp, newOp); for (auto a : allocas) { if (a.first->hasOneUse()) { // If the alloca is only used for a store and the call operand, the // store is no longer required. rewriter.eraseOp(a.second); rewriter.eraseOp(a.first); } } LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as " << newOp << '\n'); return mlir::success(); } // Failure here just means "we couldn't do the conversion", which is // perfectly acceptable to the upper layers of this function. return mlir::failure(); } }; // this pass attempts to convert immediate scalar literals in function calls // to global constants to allow transformations such as Dead Argument // Elimination class ConstantArgumentGlobalisationOpt : public fir::impl::ConstantArgumentGlobalisationOptBase< ConstantArgumentGlobalisationOpt> { public: ConstantArgumentGlobalisationOpt() = default; void runOnOperation() override { mlir::ModuleOp mod = getOperation(); mlir::DominanceInfo *di = &getAnalysis(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); mlir::GreedyRewriteConfig config; config.setRegionSimplificationLevel( mlir::GreedySimplifyRegionLevel::Disabled); config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps); patterns.insert(context, *di); if (mlir::failed( mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { mlir::emitError(mod.getLoc(), "error in constant globalisation optimization\n"); signalPassFailure(); } } }; } // namespace