//===-- FIROpenACCOpsInterfaces.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of external operation interfaces for FIR (plus the related
// HLFIR and OpenACC operations these models are specialized for).
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h"

#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallSet.h"

namespace fir::acc {

// NOTE(review): in this copy of the file the explicit template argument lists
// appear to be missing. This affects the model specializations,
// mlir::cast / mlir::dyn_cast_if_present, SymbolTable::lookup,
// llvm::SmallSet and llvm::SmallVectorImpl. mlir::cast cannot deduce its
// target type, so the code will not compile as shown. Restore the concrete
// operation/element types from the interface declarations in
// FIROpenACCOpsInterfaces.h — TODO confirm.

//===----------------------------------------------------------------------===//
// PartialEntityAccessModel
//===----------------------------------------------------------------------===//

// Base entity of this partial access: the operation's memref operand.
template <>
mlir::Value PartialEntityAccessModel::getBaseEntity(
    mlir::Operation *op) const {
  return mlir::cast(op).getMemref();
}

// Base entity of this partial access: the operation's ref operand.
template <>
mlir::Value PartialEntityAccessModel::getBaseEntity(
    mlir::Operation *op) const {
  return mlir::cast(op).getRef();
}

// Base entity of this partial access: the operation's memref operand.
template <>
mlir::Value PartialEntityAccessModel::getBaseEntity(
    mlir::Operation *op) const {
  return mlir::cast(op).getMemref();
}

// Base entity of a declare-style operation.
// Returns the storage operand when one is attached; in that case the declared
// variable is only a partial view into that storage. Otherwise it returns the
// declared memref itself.
mlir::Value PartialEntityAccessModel::getBaseEntity(
    mlir::Operation *op) const {
  auto declareOp = mlir::cast(op);
  // If storage is present, return it (partial view case)
  if (mlir::Value storage = declareOp.getStorage())
    return storage;

  // Otherwise return the memref (complete view case)
  return declareOp.getMemref();
}

// A declare-style operation is a complete view exactly when it has no
// storage operand. This mirrors the choice made in getBaseEntity above.
bool PartialEntityAccessModel::isCompleteView(
    mlir::Operation *op) const {
  // Complete view if storage is absent
  return !mlir::cast(op).getStorage();
}

// Same storage-vs-memref rule as the previous declare-style model, applied to
// the second declare operation this file supports.
mlir::Value PartialEntityAccessModel::getBaseEntity(
    mlir::Operation *op) const {
  auto declareOp = mlir::cast(op);
  // If storage is present, return it (partial view case)
  if (mlir::Value storage = declareOp.getStorage())
    return storage;

  // Otherwise return the memref (complete view case)
  return declareOp.getMemref();
}

// Complete view iff no storage operand is attached (see getBaseEntity above).
bool PartialEntityAccessModel::isCompleteView(
    mlir::Operation *op) const {
  // Complete view if storage is absent
  return !mlir::cast(op).getStorage();
}

//===----------------------------------------------------------------------===//
// AddressOfGlobalModel / GlobalVariableModel
//===----------------------------------------------------------------------===//

// Returns the symbol reference held by the address-of operation, i.e. the
// global whose address is taken.
mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const {
  return mlir::cast(op).getSymbolAttr();
}

// A global is reported constant when its `constant` attribute is present.
// Only presence is tested (has_value()); the attribute's payload is not read.
bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
  auto globalOp = mlir::cast(op);
  return globalOp.getConstant().has_value();
}

// Returns the global's region only when hasInitializationBody() says it holds
// an initializer; returns nullptr otherwise.
mlir::Region *GlobalVariableModel::getInitRegion(mlir::Operation *op) const {
  auto globalOp = mlir::cast(op);
  return globalOp.hasInitializationBody() ? &globalOp.getRegion() : nullptr;
}

// True when the operation carries a CUF data attribute that
// cuf::isDeviceDataAttribute classifies as device data. Operations without a
// CUF data attribute are never device data.
bool GlobalVariableModel::isDeviceData(mlir::Operation *op) const {
  if (auto dataAttr = cuf::getDataAttr(op))
    return cuf::isDeviceDataAttribute(dataAttr.getValue());
  return false;
}

//===----------------------------------------------------------------------===//
// IndirectGlobalAccessModel helpers
//===----------------------------------------------------------------------===//

// Helper to recursively process address-of operations in derived type
// descriptors and collect all needed fir.globals.
//
// - addrOfOp:   address-of op found inside a type-descriptor global. Its
//               symbol is resolved in symTab by leaf-reference name.
//               Unresolvable symbols are skipped silently.
// - symTab:     symbol table used for the lookup.
// - globalsSet: globals already visited. Checking it before recursing keeps
//               each global from being processed twice and stops infinite
//               recursion when descriptors reference each other.
// - symbols:    output. Receives addrOfOp's own symbol attribute (as written
//               in the address-of op) for every newly visited global.
static void processAddrOfOpInDerivedTypeDescriptor(
    fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
    llvm::SmallSet &globalsSet,
    llvm::SmallVectorImpl &symbols) {
  if (auto globalOp = symTab.lookup(
          addrOfOp.getSymbol().getLeafReference().getValue())) {
    // Already visited: nothing new to collect.
    if (globalsSet.contains(globalOp))
      return;
    globalsSet.insert(globalOp);
    symbols.push_back(addrOfOp.getSymbolAttr());
    // The referenced global may itself take addresses of further globals.
    globalOp.walk([&](fir::AddrOfOp op) {
      processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols);
    });
  }
}

// Utility to collect referenced symbols for type descriptors of derived types.
// This is the common logic for operations that may require type descriptor
// globals.
//
// - ty:          type to inspect. A reference wrapper is stripped
//                (fir::unwrapRefType) and the derived type extracted
//                (fir::getDerivedType) before the record-type check.
//                Non-record types contribute nothing.
// - op:          operation whose MLIR context is used to build new
//                SymbolRefAttrs.
// - symbols:     output. Symbols are appended, never cleared. The visited set
//                is local to each call, so separate calls may append the same
//                symbol more than once.
// - symbolTable: optional.
//                * If null, only the type-descriptor name computed by
//                  fir::NameUniquer is appended. Nothing is looked up, so the
//                  symbol may not exist in the module.
//                * If non-null, the descriptor global is looked up, first by
//                  its type-descriptor name and then by its assembly name.
//                  When found, its name is appended and every global it
//                  transitively takes the address of is collected as well.
static void collectReferencedSymbolsForType(
    mlir::Type ty, mlir::Operation *op,
    llvm::SmallVectorImpl &symbols,
    mlir::SymbolTable *symbolTable) {
  ty = fir::getDerivedType(fir::unwrapRefType(ty));

  // Look for type descriptor globals only if it's a derived (record) type
  if (auto recTy = mlir::dyn_cast_if_present(ty)) {
    // If no symbol table provided, simply add the type descriptor name
    if (!symbolTable) {
      symbols.push_back(mlir::SymbolRefAttr::get(
          op->getContext(),
          fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
      return;
    }

    // Otherwise, do full lookup and recursive processing
    llvm::SmallSet globalsSet;

    fir::GlobalOp globalOp = symbolTable->lookup(
        fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
    // Fall back to the assembly-level spelling of the descriptor name.
    if (!globalOp)
      globalOp = symbolTable->lookup(
          fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));

    if (globalOp) {
      // Seed the visited set with the root descriptor, so recursive walks
      // never append it a second time.
      globalsSet.insert(globalOp);
      symbols.push_back(
          mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName()));
      globalOp.walk([&](fir::AddrOfOp addrOp) {
        processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable,
                                               globalsSet, symbols);
      });
    }
  }
}

//===----------------------------------------------------------------------===//
// IndirectGlobalAccessModel
//===----------------------------------------------------------------------===//

// Collects type-descriptor globals needed for the type returned by the
// operation's getType().
template <>
void IndirectGlobalAccessModel::getReferencedSymbols(
    mlir::Operation *op, llvm::SmallVectorImpl &symbols,
    mlir::SymbolTable *symbolTable) const {
  auto allocaOp = mlir::cast(op);
  collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
}

// Collects type-descriptor globals needed for the type of the boxed memref
// operand.
template <>
void IndirectGlobalAccessModel::getReferencedSymbols(
    mlir::Operation *op, llvm::SmallVectorImpl &symbols,
    mlir::SymbolTable *symbolTable) const {
  auto emboxOp = mlir::cast(op);
  collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
                                  symbolTable);
}

// Collects type-descriptor globals needed for the type of the input box
// operand.
template <>
void IndirectGlobalAccessModel::getReferencedSymbols(
    mlir::Operation *op, llvm::SmallVectorImpl &symbols,
    mlir::SymbolTable *symbolTable) const {
  auto reboxOp = mlir::cast(op);
  collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
                                  symbolTable);
}

// Collects type-descriptor globals needed for the operation's `in_type`
// (getInType()).
template <>
void IndirectGlobalAccessModel::getReferencedSymbols(
    mlir::Operation *op, llvm::SmallVectorImpl &symbols,
    mlir::SymbolTable *symbolTable) const {
  auto typeDescOp = mlir::cast(op);
  collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
                                  symbolTable);
}

//===----------------------------------------------------------------------===//
// OperationMoveModel (acc.loop)
//===----------------------------------------------------------------------===//

// Always permits moving `candidate` from `descendant` into the loop. The
// `descendant` and `candidate` arguments are intentionally not inspected.
template <>
bool OperationMoveModel::canMoveFromDescendant(
    mlir::Operation *op, mlir::Operation *descendant,
    mlir::Operation *candidate) const {
  // It should be always allowed to move operations from descendants
  // of acc.loop into the acc.loop.
  return true;
}

// Decides whether `candidate` may be moved out of the loop `op`.
// - A null candidate is a general "may anything move out?" query; the answer
//   is true.
// - Otherwise returns false if any operand of `candidate` is one of the
//   loop's data operands, and true in every other case.
template <>
bool OperationMoveModel::canMoveOutOf(
    mlir::Operation *op, mlir::Operation *candidate) const {
  // Disallow moving operations, which have operands that are referenced
  // in the data operands (e.g. in [first]private() etc.) of the acc.loop.
  // For example:
  //   %17 = acc.private var(%16 : !fir.box<...>)
  //   acc.loop private(%17 : !fir.box<...>) ... {
  //     %19 = fir.box_addr %17
  //   }
  // We cannot hoist %19 without violating assumptions that OpenACC
  // transformations rely on.

  // In general, some movement out of acc.loop is allowed,
  // so return true if candidate is nullptr.
  if (!candidate)
    return true;

  auto loopOp = mlir::cast(op);
  unsigned numDataOperands = loopOp.getNumDataOperands();
  // Pairwise comparison of every loop data operand against every candidate
  // operand (identity of SSA values, not structural equality).
  for (unsigned i = 0; i < numDataOperands; ++i) {
    mlir::Value dataOperand = loopOp.getDataOperand(i);
    if (llvm::any_of(candidate->getOperands(),
                     [&](mlir::Value candidateOperand) {
                       return dataOperand == candidateOperand;
                     }))
      return false;
  }
  return true;
}

} // namespace fir::acc