//===- LegalizeDataValues.cpp - -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { namespace acc { #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" } // namespace acc } // namespace mlir using namespace mlir; namespace { static bool insideAccComputeRegion(mlir::Operation *op) { mlir::Operation *parent{op->getParentOp()}; while (parent) { if (isa(parent)) { return true; } parent = parent->getParentOp(); } return false; } static void collectVars(mlir::ValueRange operands, llvm::SmallVector> &values, bool hostToDevice) { for (auto operand : operands) { Value var = acc::getVar(operand.getDefiningOp()); Value accVar = acc::getAccVar(operand.getDefiningOp()); if (var && accVar) { if (hostToDevice) values.push_back({var, accVar}); else values.push_back({accVar, var}); } } } template static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, Region &outerRegion) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) { if constexpr (std::is_same_v || std::is_same_v) { // For data construct regions, only replace uses in contained compute // regions. if (insideAccComputeRegion(use.getOwner())) { use.set(replacement); } } else { use.set(replacement); } } } } template static void replaceAllUsesInUnstructuredComputeRegionWith( Op &op, llvm::SmallVector> &values, DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) { SmallVector exitOps; if constexpr (std::is_same_v) { // For declare enter/exit pairs, collect all exit ops for (auto *user : op.getToken().getUsers()) { if (auto declareExit = dyn_cast(user)) exitOps.push_back(declareExit); } if (exitOps.empty()) return; } for (auto p : values) { Value hostVal = std::get<0>(p); Value deviceVal = std::get<1>(p); for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) { Operation *owner = use.getOwner(); // Check It's the case that the acc entry operation dominates the use. if (!domInfo.dominates(op.getOperation(), owner)) continue; // Check It's the case that at least one of the acc exit operations // post-dominates the use bool hasPostDominatingExit = false; for (auto *exit : exitOps) { if (postDomInfo.postDominates(exit, owner)) { hasPostDominatingExit = true; break; } } if (!hasPostDominatingExit) continue; if (insideAccComputeRegion(owner)) use.set(deviceVal); } } } template static void collectAndReplaceInRegion(Op &op, bool hostToDevice, DominanceInfo *domInfo = nullptr, PostDominanceInfo *postDomInfo = nullptr) { llvm::SmallVector> values; if constexpr (std::is_same_v) { collectVars(op.getReductionOperands(), values, hostToDevice); collectVars(op.getPrivateOperands(), values, hostToDevice); } else { collectVars(op.getDataClauseOperands(), values, hostToDevice); if constexpr (!std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v) { collectVars(op.getReductionOperands(), values, hostToDevice); collectVars(op.getPrivateOperands(), values, hostToDevice); collectVars(op.getFirstprivateOperands(), values, hostToDevice); } } if constexpr (std::is_same_v) { assert(domInfo && postDomInfo && "Dominance info required for DeclareEnterOp"); replaceAllUsesInUnstructuredComputeRegionWith(op, values, *domInfo, *postDomInfo); } else { for (auto p : values) { replaceAllUsesInAccComputeRegionsWith(std::get<0>(p), std::get<1>(p), op.getRegion()); } } } class LegalizeDataValuesInRegion : public acc::impl::LegalizeDataValuesInRegionBase< LegalizeDataValuesInRegion> { public: using LegalizeDataValuesInRegionBase< LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase; void runOnOperation() override { func::FuncOp funcOp = getOperation(); bool replaceHostVsDevice = this->hostToDevice.getValue(); // Initialize dominance info DominanceInfo domInfo; PostDominanceInfo postDomInfo; bool computedDomInfo = false; funcOp.walk([&](Operation *op) { if (!isa(*op) && !(isa(*op) && applyToAccDataConstruct) && !isa(*op)) return; if (auto parallelOp = dyn_cast(*op)) { collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); } else if (auto serialOp = dyn_cast(*op)) { collectAndReplaceInRegion(serialOp, replaceHostVsDevice); } else if (auto kernelsOp = dyn_cast(*op)) { collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); } else if (auto loopOp = dyn_cast(*op)) { collectAndReplaceInRegion(loopOp, replaceHostVsDevice); } else if (auto dataOp = dyn_cast(*op)) { collectAndReplaceInRegion(dataOp, replaceHostVsDevice); } else if (auto declareOp = dyn_cast(*op)) { collectAndReplaceInRegion(declareOp, replaceHostVsDevice); } else if (auto hostDataOp = dyn_cast(*op)) { collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice); } else if (auto declareEnterOp = dyn_cast(*op)) { if (!computedDomInfo) { domInfo = DominanceInfo(funcOp); postDomInfo = PostDominanceInfo(funcOp); computedDomInfo = true; } collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo, &postDomInfo); } else { llvm_unreachable("unsupported acc region op"); } }); } }; } // end anonymous namespace