diff options
Diffstat (limited to 'mlir/lib/Dialect/OpenACC')
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 636 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp | 781 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp | 431 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp | 237 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp | 117 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt | 7 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp | 121 |
8 files changed, 2133 insertions, 204 deletions
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e..1d775fb 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 35eba72..47f1222 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" @@ -203,12 +204,91 @@ struct MemRefPointerLikeModel return false; } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // Load from a memref - only valid for scalar memrefs (rank 0). + // This is because the address computation for memrefs is part of the load + // (and not computed separately), but the API does not have arguments for + // indexing. 
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr); + if (!memrefValue) + return {}; + + auto memrefTy = memrefValue.getType(); + + // Only load from scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return {}; + + return memref::LoadOp::create(builder, loc, memrefValue); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + // Store to a memref - only valid for scalar memrefs (rank 0) + // This is because the address computation for memrefs is part of the store + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr); + if (!memrefValue) + return false; + + auto memrefTy = memrefValue.getType(); + + // Only store to scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return false; + + memref::StoreOp::create(builder, loc, valueToStore, memrefValue); + return true; + } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // For LLVM pointers, we need the valueType to determine what to load + if (!valueType) + return {}; + + return LLVM::LoadOp::create(builder, loc, valueType, srcPtr); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + LLVM::StoreOp::create(builder, loc, valueToStore, destPtr); + return true; + } +}; + +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return 
getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } + + Region *getInitRegion(Operation *op) const { + // GlobalOp uses attributes for initialization, not regions + return nullptr; + } }; /// Helper function for any of the times we need to modify an ArrayAttr based on @@ -302,6 +382,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); } //===----------------------------------------------------------------------===// @@ -467,6 +552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { return success(); } +template <typename OpT, typename RecipeOpT> +static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) { + // Mappable types do not need a recipe because it is possible to generate one + // from its API. Reject reductions though because no API is available for them + // at this time. 
+ if (mlir::acc::isMappableType(op.getVar().getType()) && + !std::is_same_v<OpT, acc::ReductionOp>) + return success(); + + mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr(); + if (!operandRecipe) + return op->emitOpError() << "recipe expected for " << operandName; + + auto decl = + SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe); + if (!decl) + return op->emitOpError() + << "expected symbol reference " << operandRecipe << " to point to a " + << operandName << " declaration"; + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, } } +static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, + mlir::SymbolRefAttr &recipeAttr) { + if (failed(parser.parseAttribute(recipeAttr))) + return failure(); + return success(); +} + +static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::SymbolRefAttr recipeAttr) { + p << recipeAttr; +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed( + checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private"))) + return failure(); return success(); } @@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>( + *this, "firstprivate"))) + return failure(); return success(); } @@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::ReductionOp, 
acc::ReductionRecipeOp>( + *this, "reduction"))) + return failure(); return success(); } @@ -1042,6 +1170,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +/// Remove empty acc.kernel_environment operations. If the operation has wait +/// operands, create a acc.wait operation to preserve synchronization. +struct RemoveEmptyKernelEnvironment + : public OpRewritePattern<acc::KernelEnvironmentOp> { + using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op, + PatternRewriter &rewriter) const override { + assert(op->getNumRegions() == 1 && "expected op to have one region"); + + Block &block = op.getRegion().front(); + if (!block.empty()) + return failure(); + + // Conservatively disable canonicalization of empty acc.kernel_environment + // operations if the wait operands in the kernel_environment cannot be fully + // represented by acc.wait operation. + + // Disable canonicalization if device type is not the default + if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) { + for (auto attr : deviceTypeAttr) { + if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) { + if (dtAttr.getValue() != mlir::acc::DeviceType::None) + return failure(); + } + } + } + + // Disable canonicalization if any wait segment has a devnum + if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) { + for (auto attr : hasDevnumAttr) { + if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) { + if (boolAttr.getValue()) + return failure(); + } + } + } + + // Disable canonicalization if there are multiple wait segments + if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) { + if (segmentsAttr.size() > 1) + return failure(); + } + + // Remove empty kernel environment. + // Preserve synchronization by creating acc.wait operation if needed. 
+ if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr()) + rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(), + /*asyncOperand=*/Value(), + /*waitDevnum=*/Value(), + /*async=*/nullptr, + /*ifCond=*/Value()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Recipe Region Helpers //===----------------------------------------------------------------------===// @@ -1263,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, return recipe; } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, + FirstprivateRecipeOp firstprivRecipe) { + // Create the private.recipe op with the same type as the firstprivate.recipe. + OpBuilder::InsertionGuard guard(builder); + auto varType = firstprivRecipe.getType(); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Clone the init region + IRMapping mapping; + firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping); + + // Clone destroy region if the firstprivate.recipe has one. 
+ if (!firstprivRecipe.getDestroyRegion().empty()) { + IRMapping mapping; + firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(), + mapping); + } + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1373,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() { } //===----------------------------------------------------------------------===// -// Custom parser and printer verifier for private clause -//===----------------------------------------------------------------------===// - -static ParseResult parseSymOperandList( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, - llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { - llvm::SmallVector<SymbolRefAttr> attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), - attributes.end()); - symbols = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional<mlir::ArrayAttr> attributes) { - llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { - p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " - << std::get<1>(it).getType(); - }); -} - -//===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -1425,45 +1600,19 @@ static LogicalResult 
checkDataOperands(Op op, return success(); } -template <typename Op> -static LogicalResult -checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, - mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName, bool checkOperandType = true) { - if (!operands.empty()) { - if (!attributes || attributes->size() != operands.size()) - return op->emitOpError() - << "expected as many " << symbolName << " symbol reference as " - << operandName << " operands"; - } else { - if (attributes) - return op->emitOpError() - << "unexpected " << symbolName << " symbol reference"; - return success(); - } - +template <typename OpT, typename RecipeOpT> +static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, + const mlir::ValueRange &operands, + llvm::StringRef operandName) { llvm::DenseSet<Value> set; - for (auto args : llvm::zip(operands, *attributes)) { - mlir::Value operand = std::get<0>(args); - + for (mlir::Value operand : operands) { + if (!mlir::isa<OpT>(operand.getDefiningOp())) + return accConstructOp->emitOpError() + << "expected " << operandName << " as defining op"; if (!set.insert(operand).second) - return op->emitOpError() + return accConstructOp->emitOpError() << operandName << " operand appears more than once"; - - mlir::Type varType = operand.getType(); - auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); - auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); - if (!decl) - return op->emitOpError() - << "expected symbol reference " << symbolRef << " to point to a " - << operandName << " declaration"; - - if (checkOperandType && decl.getType() && decl.getType() != varType) - return op->emitOpError() << "expected " << operandName << " (" << varType - << ") to be the same type as " << operandName - << " declaration (" << decl.getType() << ")"; } - return success(); } @@ -1520,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( } LogicalResult 
acc::ParallelOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -1661,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, mlir::ValueRange gangPrivateOperands, mlir::ValueRange gangFirstPrivateOperands, mlir::ValueRange dataClauseOperands) { - ParallelOp::build( odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, @@ -1670,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, /*numGangsDeviceType=*/nullptr, numWorkers, /*numWorkersDeviceType=*/nullptr, vectorLength, /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, - /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, - gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, - /*firstprivatizations=*/nullptr, dataClauseOperands, + /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands, + gangFirstPrivateOperands, dataClauseOperands, 
/*defaultAttr=*/nullptr, /*combined=*/nullptr); } @@ -1749,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands( void acc::ParallelOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } static ParseResult parseNumGangs( @@ -2356,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { } 
LogicalResult acc::SerialOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -2430,46 +2553,22 @@ void acc::SerialOp::addWaitOperands( void acc::SerialOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); 
getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -2599,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void 
acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { @@ -2691,6 +2811,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// KernelEnvironmentOp +//===----------------------------------------------------------------------===// + +void acc::KernelEnvironmentOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<RemoveEmptyKernelEnvironment>(context); +} + +//===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// @@ -2899,19 +3028,21 @@ bool hasDuplicateDeviceTypes( } /// Check for duplicates in the DeviceType array attribute. -LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { +/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found. 
+static std::optional<mlir::acc::DeviceType> +checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; if (!deviceTypes) - return success(); + return std::nullopt; for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); if (!deviceTypeAttr) - return failure(); + return mlir::acc::DeviceType::None; if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) - return failure(); + return deviceTypeAttr.getValue(); } - return success(); + return std::nullopt; } LogicalResult acc::LoopOp::verify() { @@ -2938,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() { getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; - if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) - return emitOpError() - << "duplicate device_type found in collapseDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { @@ -2953,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() { return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } - if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) - return emitOpError() << "duplicate device_type found in gang attribute"; + if (getGangAttr()) { + if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in gang attribute"; + } if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), @@ -2962,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() { return failure(); // Check worker - if 
(failed(checkDeviceTypes(getWorkerAttr()))) - return emitOpError() << "duplicate device_type found in worker attribute"; - if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "workerNumOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in worker attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector - if (failed(checkDeviceTypes(getVectorAttr()))) - return emitOpError() << "duplicate device_type found in vector attribute"; - if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "vectorOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vector attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getVectorOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) @@ -3042,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() { } } - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", false))) + if 
(failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (getCombined().has_value() && @@ -3068,8 +3212,12 @@ LogicalResult acc::LoopOp::verify() { if (getRegion().empty()) return emitError("expected non-empty body."); - // When it is container-like - it is expected to hold a loop-like operation. - if (isContainerLike()) { + if (getUnstructured()) { + if (!isContainerLike()) + return emitError( + "unstructured acc.loop must not have induction variables"); + } else if (isContainerLike()) { + // When it is container-like - it is expected to hold a loop-like operation. // Obtain the maximum collapse count - we use this to check that there // are enough loops contained. 
uint64_t collapseCount = getCollapseValue().value_or(1); @@ -3484,45 +3632,21 @@ void acc::LoopOp::addGangOperands( void acc::LoopOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -3987,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() { if 
(parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " - "be present at the same time"; + "be present at the same time for device_type `" + << acc::stringifyDeviceType(dtype) << "`"; } return success(); @@ -4284,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { return std::nullopt; } +void RoutineOp::addSeq(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addVector(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addWorker(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + uint64_t val) { + llvm::SmallVector<mlir::Attribute> dimValues; + llvm::SmallVector<mlir::Attribute> deviceTypes; + + if (getGangDimAttr()) + llvm::copy(getGangDimAttr(), std::back_inserter(dimValues)); + if (getGangDimDeviceTypeAttr()) + llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes)); + + assert(dimValues.size() == deviceTypes.size()); + + if (effectiveDeviceTypes.empty()) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back( + acc::DeviceTypeAttr::get(context, acc::DeviceType::None)); + } else { + for (DeviceType dt : effectiveDeviceTypes) { + dimValues.push_back( + 
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt)); + } + } + assert(dimValues.size() == deviceTypes.size()); + + setGangDimAttr(mlir::ArrayAttr::get(context, dimValues)); + setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes)); +} + +void RoutineOp::addBindStrName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::StringAttr val) { + unsigned before = getBindStrNameDeviceTypeAttr() + ? getBindStrNameDeviceTypeAttr().size() + : 0; + + setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindStrNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindStrNameAttr()) + llvm::copy(getBindStrNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindStrNameAttr(mlir::ArrayAttr::get(context, vals)); +} + +void RoutineOp::addBindIDName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::SymbolRefAttr val) { + unsigned before = + getBindIdNameDeviceTypeAttr() ? 
getBindIdNameDeviceTypeAttr().size() : 0; + + setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindIdNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindIdNameAttr()) + llvm::copy(getBindIdNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindIdNameAttr(mlir::ArrayAttr::get(context, vals)); +} + //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// @@ -4667,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; } + +mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) { + auto recipe{ + llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp) + .Case<ACC_DATA_ENTRY_OPS>( + [&](auto entry) { return entry.getRecipeAttr(); }) + .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })}; + return recipe; +} diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp new file mode 100644 index 0000000..67cdf10 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp @@ -0,0 +1,781 @@ +//===- ACCImplicitData.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the OpenACC specification for "Variables with +// Implicitly Determined Data Attributes" (OpenACC 3.4 spec, section 2.6.2). 
+// +// Overview: +// --------- +// The pass automatically generates data clause operations for variables used +// within OpenACC compute constructs (parallel, kernels, serial) that do not +// already have explicit data clauses. The semantics follow these rules: +// +// 1. If there is a default(none) clause visible, no implicit data actions +// apply. +// +// 2. An aggregate variable (arrays, derived types, etc.) will be treated as: +// - In a present clause when default(present) is visible. +// - In a copy clause otherwise. +// +// 3. A scalar variable will be treated as if it appears in: +// - A copy clause if the compute construct is a kernels construct. +// - A firstprivate clause otherwise (parallel, serial). +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Type Interface Implementation: Variables from the dialect being used +// must implement one or both of the following MLIR interfaces: +// `acc::MappableType` and/or `acc::PointerLikeType` +// +// These interfaces provide the necessary methods for the pass to: +// - Determine variable type categories (scalar vs. aggregate) +// - Generate appropriate bounds information +// - Generate privatization recipes +// +// 2. Operation Interface Implementation: Operations that access partial +// entities or create views should implement the following MLIR +// interfaces: `acc::PartialEntityAccess` and/or +// `mlir::ViewLikeOpInterface` +// +// These interfaces are used for proper data clause ordering, ensuring +// that base entities are mapped before derived entities (e.g., a +// struct is mapped before its fields, an array is mapped before +// subarray views). +// +// 3. Analysis Registration (Optional): If custom behavior is needed for +// variable name extraction or alias analysis, the dialect should +// pre-register the `acc::OpenACCSupport` and `mlir::AliasAnalysis` analyses. +// +// If not registered, default behavior will be used. 
+// +// Implementation Details: +// ----------------------- +// The pass performs the following operations: +// +// 1. Finds candidate variables which are live-in to the compute region and +// are not already in a data clause or private clause. +// +// 2. Generates both data "entry" and "exit" clause operations that match +// the intended action depending on variable type: +// - copy -> acc.copyin (entry) + acc.copyout (exit) +// - present -> acc.present (entry) + acc.delete (exit) +// - firstprivate -> acc.firstprivate (entry only, no exit) +// +// 3. Ensures that default clause is taken into consideration by looking +// through current construct and parent constructs to find the "visible +// default clause". +// +// 4. Fixes up SSA value links so that uses in the acc region reference the +// result of the newly created data clause operations. +// +// 5. When generating implicit data clause operations, it also adds variable +// name information and marks them with the implicit flag. +// +// 6. Recipes are generated by calling the appropriate entrypoints in the +// MappableType and PointerLikeType interfaces. +// +// 7. AliasAnalysis is used to determine if a variable is already covered by +// an existing data clause (e.g., an interior pointer covered by its parent). 
+// +// Examples: +// --------- +// +// Example 1: Scalar in parallel construct (implicit firstprivate) +// +// Before: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32> +// acc.parallel { +// %val = memref.load %scalar[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32> +// %firstpriv = acc.firstprivate varPtr(%scalar : memref<f32>) +// -> memref<f32> {implicit = true, name = "x"} +// acc.parallel firstprivate(@recipe -> %firstpriv : memref<f32>) { +// %val = memref.load %firstpriv[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Scalar in kernels construct (implicit copy) +// +// Before: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32> +// acc.kernels { +// %val = memref.load %scalar[] : memref<i32> +// acc.terminator +// } +// } +// +// After: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32> +// %copyin = acc.copyin varPtr(%scalar : memref<i32>) -> memref<i32> +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "n"} +// acc.kernels dataOperands(%copyin : memref<i32>) { +// %val = memref.load %copyin[] : memref<i32> +// acc.terminator +// } +// acc.copyout accPtr(%copyin : memref<i32>) +// to varPtr(%scalar : memref<i32>) +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "n"} +// } +// +// Example 3: Array (aggregate) in parallel (implicit copy) +// +// Before: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// acc.parallel { +// %c0 = arith.constant 0 : index +// %val = memref.load %array[%c0] : memref<100xf32> +// acc.yield +// } +// } +// +// After: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// %copyin = acc.copyin varPtr(%array : memref<100xf32>) +// -> memref<100xf32> +// {dataClause = 
#acc<data_clause acc_copy>, +// implicit = true, name = "arr"} +// acc.parallel dataOperands(%copyin : memref<100xf32>) { +// %c0 = arith.constant 0 : index +// %val = memref.load %copyin[%c0] : memref<100xf32> +// acc.yield +// } +// acc.copyout accPtr(%copyin : memref<100xf32>) +// to varPtr(%array : memref<100xf32>) +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "arr"} +// } +// +// Example 4: Array with default(present) +// +// Before: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// acc.parallel { +// %c0 = arith.constant 0 : index +// %val = memref.load %array[%c0] : memref<100xf32> +// acc.yield +// } attributes {defaultAttr = #acc<defaultvalue present>} +// } +// +// After: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// %present = acc.present varPtr(%array : memref<100xf32>) +// -> memref<100xf32> +// {implicit = true, name = "arr"} +// acc.parallel dataOperands(%present : memref<100xf32>) +// attributes {defaultAttr = #acc<defaultvalue present>} { +// %c0 = arith.constant 0 : index +// %val = memref.load %present[%c0] : memref<100xf32> +// acc.yield +// } +// acc.delete accPtr(%present : memref<100xf32>) +// {dataClause = #acc<data_clause acc_present>, +// implicit = true, name = "arr"} +// } +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Transforms/RegionUtils.h" +#include 
"llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include <type_traits> + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDATA +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-data" + +using namespace mlir; + +namespace { + +class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> { +public: + using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase; + + void runOnOperation() override; + +private: + /// Looks through the `dominatingDataClauses` to find the original data clause + /// op for an alias. Returns nullptr if no original data clause op is found. + template <typename OpT> + Operation *getOriginalDataClauseOpForAlias( + Value var, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses); + + /// Generates the appropriate `acc.copyin`, `acc.present`,`acc.firstprivate`, + /// etc. data clause op for a candidate variable. + template <typename OpT> + Operation *generateDataClauseOpForCandidate( + Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses, + const std::optional<acc::ClauseDefaultValue> &defaultClause); + + /// Generates the implicit data ops for a compute construct. + template <typename OpT> + void generateImplicitDataOps( + ModuleOp &module, OpT computeConstructOp, + std::optional<acc::ClauseDefaultValue> &defaultClause); + + /// Generates a private recipe for a variable. + acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport); + + /// Generates a firstprivate recipe for a variable. 
+ acc::FirstprivateRecipeOp + generateFirstprivateRecipe(ModuleOp &module, Value var, Location loc, + OpBuilder &builder, + acc::OpenACCSupport &accSupport); + + /// Generates recipes for a list of variables. + void generateRecipes(ModuleOp &module, OpBuilder &builder, + Operation *computeConstructOp, + const SmallVector<Value> &newOperands); +}; + +/// Determines if a variable is a candidate for implicit data mapping. +/// Returns true if the variable is a candidate, false otherwise. +static bool isCandidateForImplicitData(Value val, Region &accRegion) { + // Ensure the variable is an allowed type for data clause. + if (!acc::isPointerLikeType(val.getType()) && + !acc::isMappableType(val.getType())) + return false; + + // If this is already coming from a data clause, we do not need to generate + // another. + if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp())) + return false; + + // If this is only used by private clauses, it is not a real live-in. + if (acc::isOnlyUsedByPrivateClauses(val, accRegion)) + return false; + + return true; +} + +template <typename OpT> +Operation *ACCImplicitData::getOriginalDataClauseOpForAlias( + Value var, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses) { + auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>(); + for (auto dataClause : dominatingDataClauses) { + if (auto *dataClauseOp = dataClause.getDefiningOp()) { + // Only accept clauses that guarantee that the alias is present. + if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp, + acc::DevicePtrOp>(dataClauseOp)) + if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust()) + return dataClauseOp; + } + } + return nullptr; +} + +// Generates bounds for variables that have unknown dimensions +static void fillInBoundsForUnknownDimensions(Operation *dataClauseOp, + OpBuilder &builder) { + + if (!acc::getBounds(dataClauseOp).empty()) + // If bounds are already present, do not overwrite them. 
+ return; + + // For types that have unknown dimensions, attempt to generate bounds by + // relying on MappableType being able to extract it from the IR. + auto var = acc::getVar(dataClauseOp); + auto type = var.getType(); + if (auto mappableTy = dyn_cast<acc::MappableType>(type)) { + if (mappableTy.hasUnknownDimensions()) { + TypeSwitch<Operation *>(dataClauseOp) + .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClauseOp) { + if (std::is_same_v<decltype(dataClauseOp), acc::DevicePtrOp>) + return; + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(dataClauseOp); + auto bounds = mappableTy.generateAccBounds(var, builder); + if (!bounds.empty()) + dataClauseOp.getBoundsMutable().assign(bounds); + }); + } + } +} + +acc::PrivateRecipeOp +ACCImplicitData::generatePrivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + auto type = var.getType(); + std::string recipeName = + accSupport.getRecipeName(acc::RecipeKind::private_recipe, type, var); + + // Check if recipe already exists + auto existingRecipe = module.lookupSymbol<acc::PrivateRecipeOp>(recipeName); + if (existingRecipe) + return existingRecipe; + + // Set insertion point to module body in a scoped way + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto recipe = + acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type); + if (!recipe.has_value()) + return accSupport.emitNYI(loc, "implicit private"), nullptr; + return recipe.value(); +} + +acc::FirstprivateRecipeOp +ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + auto type = var.getType(); + std::string recipeName = + accSupport.getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var); + + // Check if recipe already exists + auto existingRecipe = + 
module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName); + if (existingRecipe) + return existingRecipe; + + // Set insertion point to module body in a scoped way + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc, + recipeName, type); + if (!recipe.has_value()) + return accSupport.emitNYI(loc, "implicit firstprivate"), nullptr; + return recipe.value(); +} + +void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder, + Operation *computeConstructOp, + const SmallVector<Value> &newOperands) { + auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); + for (auto var : newOperands) { + auto loc{var.getLoc()}; + if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) { + auto recipe = generatePrivateRecipe( + module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); + if (recipe) + privateOp.setRecipeAttr( + SymbolRefAttr::get(module->getContext(), recipe.getSymName())); + } else if (auto firstprivateOp = + dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) { + auto recipe = generateFirstprivateRecipe( + module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); + if (recipe) + firstprivateOp.setRecipeAttr(SymbolRefAttr::get( + module->getContext(), recipe.getSymName().str())); + } else { + accSupport.emitNYI(var.getLoc(), "implicit reduction"); + } + } +} + +// Generates the data entry data op clause so that it adheres to OpenACC +// rules as follows (line numbers and specification from OpenACC 3.4): +// 1388 An aggregate variable will be treated as if it appears either: +// 1389 - In a present clause if there is a default(present) clause visible at +// the compute construct. +// 1391 - In a copy clause otherwise. +// 1392 A scalar variable will be treated as if it appears either: +// 1393 - In a copy clause if the compute construct is a kernels construct. 
+// 1394 - In a firstprivate clause otherwise. +template <typename OpT> +Operation *ACCImplicitData::generateDataClauseOpForCandidate( + Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses, + const std::optional<acc::ClauseDefaultValue> &defaultClause) { + auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); + acc::VariableTypeCategory typeCategory = + acc::VariableTypeCategory::uncategorized; + if (auto mappableTy = dyn_cast<acc::MappableType>(var.getType())) { + typeCategory = mappableTy.getTypeCategory(var); + } else if (auto pointerLikeTy = + dyn_cast<acc::PointerLikeType>(var.getType())) { + typeCategory = pointerLikeTy.getPointeeTypeCategory( + cast<TypedValue<acc::PointerLikeType>>(var), + pointerLikeTy.getElementType()); + } + + bool isScalar = + acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar); + bool isAnyAggregate = acc::bitEnumContainsAny( + typeCategory, acc::VariableTypeCategory::aggregate); + Location loc = computeConstructOp->getLoc(); + + Operation *op = nullptr; + op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp, + dominatingDataClauses); + if (op) { + if (isa<acc::NoCreateOp>(op)) + return acc::NoCreateOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + + if (isa<acc::DevicePtrOp>(op)) + return acc::DevicePtrOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + + // The original data clause op is a PresentOp, CopyinOp, or CreateOp, + // hence guaranteed to be present. 
+ return acc::PresentOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + } else if (isScalar) { + if (enableImplicitReductionCopy && + acc::isOnlyUsedByReductionClauses(var, + computeConstructOp->getRegion(0))) { + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_reduction); + return copyinOp.getOperation(); + } + if constexpr (std::is_same_v<OpT, acc::KernelsOp> || + std::is_same_v<OpT, acc::KernelEnvironmentOp>) { + // Scalars are implicit copyin in kernels construct. + // We also do the same for acc.kernel_environment because semantics + // of user variable mappings should be applied while ACC construct exists + // and at this point we should only be dealing with unmapped variables + // that were made live-in by the compiler. + // TODO: This may be revisited. + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_copy); + return copyinOp.getOperation(); + } else { + // Scalars are implicit firstprivate in parallel and serial construct. + return acc::FirstprivateOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + } + } else if (isAnyAggregate) { + Operation *newDataOp = nullptr; + + // When default(present) is true, the implicit behavior is present. 
+ if (defaultClause.has_value() && + defaultClause.value() == acc::ClauseDefaultValue::Present) { + newDataOp = acc::PresentOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + newDataOp->setAttr(acc::getFromDefaultClauseAttrName(), + builder.getUnitAttr()); + } else { + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_copy); + newDataOp = copyinOp.getOperation(); + } + + return newDataOp; + } else { + // This is not a fatal error - for example when the element type is + // pointer type (aka we have a pointer of pointer), it is potentially a + // deep copy scenario which is not being handled here. + // Other types need to be canonicalized. Thus just log unhandled cases. + LLVM_DEBUG(llvm::dbgs() + << "Unhandled case for implicit data mapping " << var << "\n"); + } + return nullptr; +} + +// Ensures that result values from the acc data clause ops are used inside the +// acc region. ie: +// acc.kernels { +// use %val +// } +// => +// %dev = acc.dataop %val +// acc.kernels { +// use %dev +// } +static void legalizeValuesInRegion(Region &accRegion, + SmallVector<Value> &newPrivateOperands, + SmallVector<Value> &newDataClauseOperands) { + for (Value dataClause : + llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) { + Value var = acc::getVar(dataClause.getDefiningOp()); + replaceAllUsesInRegionWith(var, dataClause, accRegion); + } +} + +// Adds the private operands to the compute construct operation. 
+template <typename OpT> +static void addNewPrivateOperands(OpT &accOp, + const SmallVector<Value> &privateOperands) { + if (privateOperands.empty()) + return; + + for (auto priv : privateOperands) { + if (isa<acc::PrivateOp>(priv.getDefiningOp())) { + accOp.getPrivateOperandsMutable().append(priv); + } else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) { + accOp.getFirstprivateOperandsMutable().append(priv); + } else { + llvm_unreachable("unhandled reduction operand"); + } + } +} + +static Operation *findDataExitOp(Operation *dataEntryOp) { + auto res = acc::getAccVar(dataEntryOp); + for (auto *user : res.getUsers()) + if (isa<ACC_DATA_EXIT_OPS>(user)) + return user; + return nullptr; +} + +// Generates matching data exit operation as described in the acc dialect +// for how data clauses are decomposed: +// https://mlir.llvm.org/docs/Dialects/OpenACCDialect/#operation-categories +// Key ones used here: +// * acc {construct} copy -> acc.copyin (before region) + acc.copyout (after +// region) +// * acc {construct} present -> acc.present (before region) + acc.delete +// (after region) +static void +generateDataExitOperations(OpBuilder &builder, Operation *accOp, + const SmallVector<Value> &newDataClauseOperands, + const SmallVector<Value> &sortedDataClauseOperands) { + builder.setInsertionPointAfter(accOp); + Value lastDataClause = nullptr; + for (auto dataEntry : llvm::reverse(sortedDataClauseOperands)) { + if (llvm::find(newDataClauseOperands, dataEntry) == + newDataClauseOperands.end()) { + // If this is not a new data clause operand, we should not generate an + // exit operation for it. 
+ lastDataClause = dataEntry; + continue; + } + if (lastDataClause) + if (auto *dataExitOp = findDataExitOp(lastDataClause.getDefiningOp())) + builder.setInsertionPointAfter(dataExitOp); + Operation *dataEntryOp = dataEntry.getDefiningOp(); + if (isa<acc::CopyinOp>(dataEntryOp)) { + auto copyoutOp = acc::CopyoutOp::create( + builder, dataEntryOp->getLoc(), dataEntry, acc::getVar(dataEntryOp), + /*structured=*/true, /*implicit=*/true, + acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp)); + copyoutOp.setDataClause(acc::DataClause::acc_copy); + } else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) { + auto deleteOp = acc::DeleteOp::create( + builder, dataEntryOp->getLoc(), dataEntry, + /*structured=*/true, /*implicit=*/true, + acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp)); + deleteOp.setDataClause(acc::getDataClause(dataEntryOp).value()); + } else if (isa<acc::DevicePtrOp>(dataEntryOp)) { + // Do nothing. + } else { + llvm_unreachable("unhandled data exit"); + } + lastDataClause = dataEntry; + } +} + +/// Returns all base references of a value in order. +/// So for example, if we have a reference to a struct field like +/// s.f1.f2.f3, this will return <s, s.f1, s.f1.f2, s.f1.f2.f3>. +/// Any intermediate casts/view-like operations are included in the +/// chain as well. +static SmallVector<Value> getBaseRefsChain(Value val) { + SmallVector<Value> baseRefs; + baseRefs.push_back(val); + while (true) { + Value prevVal = val; + + val = acc::getBaseEntity(val); + if (val != baseRefs.front()) + baseRefs.insert(baseRefs.begin(), val); + + // If this is a view-like operation, it is effectively another + // view of the same entity so we should add it to the chain also. 
+ if (auto viewLikeOp = val.getDefiningOp<ViewLikeOpInterface>()) { + val = viewLikeOp.getViewSource(); + baseRefs.insert(baseRefs.begin(), val); + } + + // Continue loop if we made any progress + if (val == prevVal) + break; + } + + return baseRefs; +} + +static void insertInSortedOrder(SmallVector<Value> &sortedDataClauseOperands, + Operation *newClause) { + auto *insertPos = + std::find_if(sortedDataClauseOperands.begin(), + sortedDataClauseOperands.end(), [&](Value dataClauseVal) { + // Get the base refs for the current clause we are looking + // at. + auto var = acc::getVar(dataClauseVal.getDefiningOp()); + auto baseRefs = getBaseRefsChain(var); + + // If the newClause is of a base ref of an existing clause, + // we should insert it right before the current clause. + // Thus return true to stop iteration when this is the + // case. + return std::find(baseRefs.begin(), baseRefs.end(), + acc::getVar(newClause)) != baseRefs.end(); + }); + + if (insertPos != sortedDataClauseOperands.end()) { + newClause->moveBefore(insertPos->getDefiningOp()); + sortedDataClauseOperands.insert(insertPos, acc::getAccVar(newClause)); + } else { + sortedDataClauseOperands.push_back(acc::getAccVar(newClause)); + } +} + +template <typename OpT> +void ACCImplicitData::generateImplicitDataOps( + ModuleOp &module, OpT computeConstructOp, + std::optional<acc::ClauseDefaultValue> &defaultClause) { + // Implicit data attributes are only applied if "[t]here is no default(none) + // clause visible at the compute construct." + if (defaultClause.has_value() && + defaultClause.value() == acc::ClauseDefaultValue::None) + return; + assert(!defaultClause.has_value() || + defaultClause.value() == acc::ClauseDefaultValue::Present); + + // 1) Collect live-in values. + Region &accRegion = computeConstructOp->getRegion(0); + SetVector<Value> liveInValues; + getUsedValuesDefinedAbove(accRegion, liveInValues); + + // 2) Run the filtering to find relevant pointers that need copied. 
+ auto isCandidate{[&](Value val) -> bool { + return isCandidateForImplicitData(val, accRegion); + }}; + auto candidateVars( + llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate))); + if (candidateVars.empty()) + return; + + // 3) Generate data clauses for the variables. + SmallVector<Value> newPrivateOperands; + SmallVector<Value> newDataClauseOperands; + OpBuilder builder(computeConstructOp); + if (!candidateVars.empty()) { + LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n" + << computeConstructOp << "\n"); + } + auto &domInfo = this->getAnalysis<DominanceInfo>(); + auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); + auto dominatingDataClauses = + acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo); + for (auto var : candidateVars) { + auto newDataClauseOp = generateDataClauseOpForCandidate( + var, module, builder, computeConstructOp, dominatingDataClauses, + defaultClause); + fillInBoundsForUnknownDimensions(newDataClauseOp, builder); + LLVM_DEBUG(llvm::dbgs() << "Generated data clause for " << var << ":\n" + << "\t" << *newDataClauseOp << "\n"); + if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>( + newDataClauseOp)) { + newPrivateOperands.push_back(acc::getAccVar(newDataClauseOp)); + } else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) { + newDataClauseOperands.push_back(acc::getAccVar(newDataClauseOp)); + dominatingDataClauses.push_back(acc::getAccVar(newDataClauseOp)); + } + } + + // 4) Legalize values in region (aka the uses in the region are the result + // of the data clause ops) + legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands); + + // 5) Generate private recipes which are required for properly attaching + // private operands. 
+ if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && + !std::is_same_v<OpT, acc::KernelEnvironmentOp>) + generateRecipes(module, builder, computeConstructOp, newPrivateOperands); + + // 6) Figure out insertion order for the new data clause operands. + SmallVector<Value> sortedDataClauseOperands( + computeConstructOp.getDataClauseOperands()); + for (auto newClause : newDataClauseOperands) + insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp()); + + // 7) Generate the data exit operations. + generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands, + sortedDataClauseOperands); + // 8) Add all of the new operands to the compute construct op. + if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && + !std::is_same_v<OpT, acc::KernelEnvironmentOp>) + addNewPrivateOperands(computeConstructOp, newPrivateOperands); + computeConstructOp.getDataClauseOperandsMutable().assign( + sortedDataClauseOperands); +} + +void ACCImplicitData::runOnOperation() { + ModuleOp module = this->getOperation(); + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) { + assert(op->getNumRegions() == 1 && "must have 1 region"); + + auto defaultClause = acc::getDefaultAttr(op); + llvm::TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto op) { + generateImplicitDataOps(module, op, defaultClause); + }) + .Default([&](Operation *) {}); + } + }); +} + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp new file mode 100644 index 0000000..8cab223 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp @@ -0,0 +1,431 @@ +//===- ACCImplicitDeclare.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass applies implicit `acc declare` actions to global variables +// referenced in OpenACC compute regions and routine functions. +// +// Overview: +// --------- +// Global references in acc regions (for globals not marked with `acc +// declare` by the user) can be handled in one of two ways: +// - Mapped through data clauses +// - Implicitly marked as `acc declare` (this pass) +// +// Thus, the OpenACC specification focuses solely on implicit data mapping rules +// whose implementation is captured in the `ACCImplicitData` pass. +// +// However, it is both advantageous and required for certain cases to +// use implicit `acc declare` instead: +// - Any functions that are implicitly marked as `acc routine` through +// `ACCImplicitRoutine` may reference globals. Since data mapping +// is only possible for compute regions, such globals can only be +// made available on device through `acc declare`. +// - Compiler can generate and use globals for cases needed in IR +// representation such as type descriptors or various names needed for +// runtime calls and error reporting - such cases often are introduced +// after frontend semantic checking is done since they are related to +// implementation details. Thus, such compiler generated globals would +// not have been visible for a user to mark with `acc declare`. +// - Constant globals such as filename strings or data initialization values +// are values that do not get mutated but are still needed for appropriate +// runtime execution. If a kernel is launched 1000 times, it is not a +// good idea to map such a global 1000 times. Therefore, such globals +// benefit from being marked with `acc declare`.
+// +// This pass automatically +// marks global variables with the `acc.declare` attribute when they are +// referenced in OpenACC compute constructs or routine functions and meet +// the criteria noted above, ensuring +// they are properly handled for device execution. +// +// The pass performs two main optimizations: +// +// 1. Hoisting: For non-constant globals referenced in compute regions, the +// pass hoists the address-of operation out of the region when possible, +// allowing them to be implicitly mapped through normal data clause +// mechanisms rather than requiring declare marking. +// +// 2. Declaration: For globals that must be available on the device (constants, +// globals in routines, globals in recipe operations), the pass adds the +// `acc.declare` attribute with the copyin data clause. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that compute addresses +// of global variables must implement the `acc::AddressOfGlobalOpInterface` +// and those that represent globals must implement the +// `acc::GlobalVariableOpInterface`. Additionally, any operations that +// indirectly access globals must implement the +// `acc::IndirectGlobalAccessOpInterface`. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis.
+// +// Examples: +// --------- +// +// Example 1: Non-constant global in compute region (hoisted) +// +// Before: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_scalar : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// %addr = memref.get_global @g_scalar : memref<f32> +// acc.serial { +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Constant global in compute region (declared) +// +// Before: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 3: Global in acc routine (declared) +// +// Before: +// memref.global @g_data : memref<f32> = dense<0.0> +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// After: +// memref.global @g_data : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// Example 4: Global in private recipe (declared if recipe is used) +// +// Before: +// memref.global @g_init : memref<f32> = 
dense<0.0> +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... } +// } +// +// After: +// memref.global @g_init : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... 
} +// } +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDECLARE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-declare" + +using namespace mlir; + +namespace { + +using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>; + +/// Checks whether a use of the requested `globalOp` should be considered +/// for hoisting out of acc region due to avoid `acc declare`ing something +/// that instead should be implicitly mapped. +static bool isGlobalUseCandidateForHoisting(Operation *globalOp, + Operation *user, + SymbolRefAttr symbol, + acc::OpenACCSupport &accSupport) { + // This symbol is valid in GPU region. This means semantics + // would change if moved to host - therefore it is not a candidate. + if (accSupport.isValidSymbolUse(user, symbol)) + return false; + + bool isConstant = false; + bool isFunction = false; + + if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp)) + isConstant = globalVarOp.isConstant(); + + if (isa<FunctionOpInterface>(globalOp)) + isFunction = true; + + // Constants should be kept in device code to ensure they are duplicated. + // Function references should be kept in device code to ensure their device + // addresses are computed. Everything else should be hoisted since we already + // proved they are not valid symbols in GPU region. 
+ return !isConstant && !isFunction; +} + +/// Checks whether it is valid to use acc.declare marking on the global. +bool isValidForAccDeclare(Operation *globalOp) { + // For functions - we use acc.routine marking instead. + return !isa<FunctionOpInterface>(globalOp); +} + +/// Checks whether a recipe operation has meaningful use of its symbol that +/// justifies processing its regions for global references. Returns false if: +/// 1. The recipe has no symbol uses at all, or +/// 2. The only symbol use is the recipe's own symbol definition +template <typename RecipeOpT> +static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) { + std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod); + + // No recipe symbol uses. + if (!symbolUses.has_value() || symbolUses->empty()) + return false; + + // If more than one use, assume it's used. + auto begin = symbolUses->begin(); + auto end = symbolUses->end(); + if (begin != end && std::next(begin) != end) + return true; + + // If single use, check if the use is the recipe itself. + const SymbolTable::SymbolUse &use = *symbolUses->begin(); + return use.getUser() != recipeOp.getOperation(); +} + +// Hoists addr_of operations for non-constant globals out of OpenACC regions. +// This way - they are implicitly mapped instead of being considered for +// implicit declare. 
+template <typename AccConstructT> +static void hoistNonConstantDirectUses(AccConstructT accOp, + acc::OpenACCSupport &accSupport) { + accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + if (symRef) { + Operation *globalOp = + SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef); + if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef, + accSupport)) { + addrOfOp->moveBefore(accOp); + LLVM_DEBUG( + llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t"; + accOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + } + } + }); +} + +// Collects the globals referenced in a device region +static void collectGlobalsFromDeviceRegion(Region ®ion, + GlobalOpSetT &globals, + acc::OpenACCSupport &accSupport, + SymbolTable &symTab) { + region.walk([&](Operation *op) { + // 1) Only consider relevant operations which use symbols + auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op); + if (addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + // 2) Found an operation which uses the symbol. Next determine if it + // is a candidate for `acc declare`. Some of the criteria considered + // is whether this symbol is not already a device one (either because + // acc declare is already used or this is a CUF global). + Operation *globalOp = nullptr; + bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp); + // 3) Add the candidate to the set of globals to be `acc declare`d. 
+ if (isCandidate && globalOp && isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } else if (auto indirectAccessOp = + dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) { + // Process operations that indirectly access globals + llvm::SmallVector<SymbolRefAttr> symbols; + indirectAccessOp.getReferencedSymbols(symbols, &symTab); + for (SymbolRefAttr symRef : symbols) + if (Operation *globalOp = symTab.lookup(symRef.getLeafReference())) + if (isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } + }); +} + +// Adds the declare attribute to the operation `op`. +static void addDeclareAttr(MLIRContext *context, Operation *op, + acc::DataClause clause) { + op->setAttr(acc::getDeclareAttrName(), + acc::DeclareAttr::get(context, + acc::DataClauseAttr::get(context, clause))); +} + +// This pass applies implicit declare actions for globals referenced in +// OpenACC compute and routine regions. +class ACCImplicitDeclare + : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> { +public: + using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *context = &getContext(); + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // 1) Start off by hoisting any AddressOf operations out of acc region + // for any cases we do not want to `acc declare`. This is because we can + // rely on implicit data mapping in majority of cases without uselessly + // polluting the device globals. + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + hoistNonConstantDirectUses(accOp, accSupport); + }); + }); + + // 2) Collect global symbols which need to be `acc declare`d. Do it for + // compute regions, acc routine, and existing globals with the declare + // attribute. 
+ SymbolTable symTab(mod); + GlobalOpSetT globalsToAccDeclare; + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + collectGlobalsFromDeviceRegion( + accOp.getRegion(), globalsToAccDeclare, accSupport, symTab); + }) + .Case<FunctionOpInterface>([&](auto func) { + if ((acc::isAccRoutine(func) || + acc::isSpecializedAccRoutine(func)) && + !func.isExternal()) + collectGlobalsFromDeviceRegion(func.getFunctionBody(), + globalsToAccDeclare, accSupport, + symTab); + }) + .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) { + if (globalVarOp->getAttr(acc::getDeclareAttrName())) + if (Region *initRegion = globalVarOp.getInitRegion()) + collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare, + accSupport, symTab); + }) + .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) { + if (hasRelevantRecipeUse(privateRecipe, mod)) { + collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) { + if (hasRelevantRecipeUse(firstprivateRecipe, mod)) { + collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare, + accSupport, symTab); + collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) { + if (hasRelevantRecipeUse(reductionRecipe, mod)) { + collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + reductionRecipe.getCombinerRegion(), globalsToAccDeclare, + accSupport, symTab); + } + 
}); + }); + + // 3) Finally, generate the appropriate declare actions needed to ensure + // this is considered for device global. + for (Operation *globalOp : globalsToAccDeclare) { + LLVM_DEBUG( + llvm::dbgs() << "Global is being `acc declare copyin`d: "; + globalOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + + // Mark it as declare copyin. + addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin); + + // TODO: May need to create the global constructor which does the mapping + // action. It is not yet clear if this is needed yet (since the globals + // might just end up in the GPU image without requiring mapping via + // runtime). + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp new file mode 100644 index 0000000..12efaf4 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp @@ -0,0 +1,237 @@ +//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the implicit rules described in OpenACC specification +// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1). +// +// "If no explicit routine directive applies to a procedure whose definition +// appears in the program unit being compiled, then the implementation applies +// an implicit routine directive to that procedure if any of the following +// conditions holds: +// - The procedure is called or its address is accessed in a compute region." 
+// +// The specification further states: +// "When the implementation applies an implicit routine directive to a +// procedure, it must recursively apply implicit routine directives to other +// procedures for which the above rules specify relevant dependencies. Such +// dependencies can form a cycle, so the implementation must take care to avoid +// infinite recursion." +// +// This pass implements these requirements by: +// 1. Walking through all OpenACC compute constructs and functions already +// marked with `acc routine` in the module and identifying function calls +// within these regions. +// 2. Creating implicit `acc.routine` operations for functions that don't +// already have routine declarations. +// 3. Recursively walking through all existing `acc routine` operations and +// creating implicit routine operations for function calls within these +// routines, while avoiding infinite recursion through proper tracking. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that define functions +// or call functions should implement `mlir::FunctionOpInterface` and +// `mlir::CallOpInterface` respectively. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis.
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include <queue> + +#define DEBUG_TYPE "acc-implicit-routine" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITROUTINE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +namespace { + +using namespace mlir; + +class ACCImplicitRoutine + : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> { +private: + unsigned routineCounter = 0; + static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_"; + + // Count existing routine operations and update counter + void initRoutineCounter(ModuleOp module) { + module.walk([&](acc::RoutineOp routineOp) { routineCounter++; }); + } + + // Check if routine has a default bind clause or a device-type specific bind + // clause. Returns true if `acc routine` has a default bind clause or + // a device-type specific bind clause. + bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op, + acc::DeviceType deviceType) { + // Fast check to avoid device-type specific lookups. 
+ if (!op.getBindIdName() && !op.getBindStrName()) + return false; + return op.getBindNameValue().has_value() || + op.getBindNameValue(deviceType).has_value(); + } + + // Generate a unique name for the routine and create the routine operation + acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc, + FunctionOpInterface &callee) { + std::string routineName = + (accRoutinePrefix + std::to_string(routineCounter++)).str(); + auto routineOp = acc::RoutineOp::create( + builder, loc, + /* sym_name=*/builder.getStringAttr(routineName), + /* func_name=*/ + mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(callee.getName())), + /* bindIdName=*/nullptr, + /* bindStrName=*/nullptr, + /* bindIdNameDeviceType=*/nullptr, + /* bindStrNameDeviceType=*/nullptr, + /* worker=*/nullptr, + /* vector=*/nullptr, + /* seq=*/nullptr, + /* nohost=*/nullptr, + /* implicit=*/builder.getUnitAttr(), + /* gang=*/nullptr, + /* gangDim=*/nullptr, + /* gangDimDeviceType=*/nullptr); + + // Assert that the callee does not already have routine info attribute + assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) && + "function is already associated with a routine"); + + callee->setAttr( + acc::getRoutineInfoAttrName(), + mlir::acc::RoutineInfoAttr::get( + builder.getContext(), + {mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(routineName))})); + return routineOp; + } + + // Used to walk through a compute region looking for function calls. + void + implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + op->walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. 
+ if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + createRoutineOp(builder, callee.getLoc(), callee); + }); + } + + // Recursively handle calls within a routine operation + void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport, + acc::DeviceType targetDeviceType) { + // When bind clause is used, it means that the target is different than the + // function to which the `acc routine` is used with. Skip this case to + // avoid implicitly recursively marking calls that would not end up on + // device. + if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType)) + return; + + SymbolTable symTab(routineOp->getParentOfType<ModuleOp>()); + std::queue<acc::RoutineOp> routineQueue; + routineQueue.push(routineOp); + while (!routineQueue.empty()) { + auto currentRoutine = routineQueue.front(); + routineQueue.pop(); + auto func = symTab.lookup<FunctionOpInterface>( + currentRoutine.getFuncName().getLeafReference()); + func.walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. 
+ if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee); + routineQueue.push(newRoutineOp); + }); + } + } + +public: + using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase; + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module.getContext()); + SymbolTable symTab(module); + initRoutineCounter(module); + + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // Handle compute regions + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) + implicitRoutineForCallsInComputeRegions(op, symTab, builder, + accSupport); + }); + + // Use the device type option from the pass options. + acc::DeviceType targetDeviceType = deviceType; + + // Handle existing routines + module.walk([&](acc::RoutineOp routineOp) { + implicitRoutineForCallsInRoutine(routineOp, builder, accSupport, + targetDeviceType); + }); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000..f41ce276 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. 
+ llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. 
+ acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 7d93495..10a1796 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -1,4 +1,8 @@ add_mlir_dialect_library(MLIROpenACCTransforms + ACCImplicitData.cpp + ACCImplicitDeclare.cpp + ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS @@ -14,7 +18,10 @@ add_mlir_dialect_library(MLIROpenACCTransforms MLIROpenACCTypeInterfacesIncGen LINK_LIBS PUBLIC + MLIRAnalysis + 
MLIROpenACCAnalysis MLIROpenACCDialect + MLIROpenACCUtils MLIRFuncDialect MLIRIR MLIRPass diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 660c313..7f27b44 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,13 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -145,3 +150,119 @@ std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind, return recipeName; } + +mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { + if (auto partialEntityAccessOp = + dyn_cast<PartialEntityAccessOpInterface>(val.getDefiningOp())) { + if (!partialEntityAccessOp.isCompleteView()) + return partialEntityAccessOp.getBaseEntity(); + } + + return val; +} + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. 
+ if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp, + mlir::acc::FirstprivateRecipeOp>(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} + +llvm::SmallVector<mlir::Value> +mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp, + mlir::DominanceInfo &domInfo, + mlir::PostDominanceInfo &postDomInfo) { + llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses; + + llvm::TypeSwitch<mlir::Operation *>(computeConstructOp) + .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>( + [&](auto op) { + for (auto dataClause : op.getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + }) + .Default([](mlir::Operation *) {}); + + // Collect the data clauses from enclosing data constructs. 
+ mlir::Operation *currParentOp = computeConstructOp->getParentOp(); + while (currParentOp) { + if (mlir::isa<mlir::acc::DataOp>(currParentOp)) { + for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp) + .getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + } + currParentOp = currParentOp->getParentOp(); + } + + // Find the enclosing function/subroutine + auto funcOp = + computeConstructOp->getParentOfType<mlir::FunctionOpInterface>(); + if (!funcOp) + return dominatingDataClauses.takeVector(); + + // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that + // dominate and post-dominate the compute construct and add their data + // clauses to the list. + funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) { + if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { + // Collect all `acc.declare_exit` ops for this token. + llvm::SmallVector<mlir::acc::DeclareExitOp> exits; + for (auto *user : declareEnterOp.getToken().getUsers()) + if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user)) + exits.push_back(declareExit); + + // Only add clauses if every `acc.declare_exit` op post-dominates the + // compute construct. + if (!exits.empty() && + llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) { + return postDomInfo.postDominates(exitOp, computeConstructOp); + })) { + for (auto dataClause : declareEnterOp.getDataClauseOperands()) + dominatingDataClauses.insert(dataClause); + } + } + }); + + return dominatingDataClauses.takeVector(); +} |
