about | summary | refs | log | tree | commit | diff
path: root/mlir/lib/Dialect/OpenACC
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenACC')
-rw-r--r-- mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp | 7
-rw-r--r-- mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 636
-rw-r--r-- mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp | 781
-rw-r--r-- mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp | 431
-rw-r--r-- mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp | 237
-rw-r--r-- mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp | 117
-rw-r--r-- mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt | 7
-rw-r--r-- mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp | 121
8 files changed, 2133 insertions, 204 deletions
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
index 40e769e..1d775fb 100644
--- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) {
return mlir::emitError(loc, "not yet implemented: " + message);
}
+/// Reports whether `symbol` may be used by `user`. The query is forwarded to
+/// the attached implementation when one is present; otherwise the generic
+/// acc::isValidSymbolUse utility answers it. `definingOpPtr` is passed through
+/// unchanged to whichever callee handles the query.
+bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
+                                      Operation **definingOpPtr) {
+  return impl ? impl->isValidSymbolUse(user, symbol, definingOpPtr)
+              : acc::isValidSymbolUse(user, symbol, definingOpPtr);
+}
+
} // namespace acc
} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 35eba72..47f1222 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
@@ -203,12 +204,91 @@ struct MemRefPointerLikeModel
return false;
}
+
+  /// Emits a `memref.load` of the scalar referenced by `srcPtr`.
+  ///
+  /// Only rank-0 memrefs are handled: a memref load folds the address
+  /// computation into the operation itself, and this hook has no parameters
+  /// through which indices could be supplied. A null value is returned when
+  /// `srcPtr` is not a `MemRefType` value or when its rank is non-zero.
+  /// `pointer` and `valueType` are not consulted.
+  mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+                      TypedValue<PointerLikeType> srcPtr,
+                      Type valueType) const {
+    auto scalarRef = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
+    const bool isScalar = scalarRef && scalarRef.getType().getRank() == 0;
+    if (!isScalar)
+      return {};
+
+    return memref::LoadOp::create(builder, loc, scalarRef);
+  }
+
+  /// Emits a `memref.store` of `valueToStore` into the scalar referenced by
+  /// `destPtr` and returns true on success.
+  ///
+  /// Only rank-0 memrefs are handled: a memref store folds the address
+  /// computation into the operation itself, and this hook has no parameters
+  /// through which indices could be supplied. Returns false, emitting nothing,
+  /// when `destPtr` is not a `MemRefType` value or when its rank is non-zero.
+  bool genStore(Type pointer, OpBuilder &builder, Location loc,
+                Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+    auto scalarRef = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
+    const bool isScalar = scalarRef && scalarRef.getType().getRank() == 0;
+    if (!isScalar)
+      return false;
+
+    memref::StoreOp::create(builder, loc, valueToStore, scalarRef);
+    return true;
+  }
};
struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
Type getElementType(Type pointer) const { return Type(); }
+
+  /// Emits an `llvm.load` of type `valueType` from `srcPtr`.
+  ///
+  /// This model reports no element type for LLVM pointers, so the caller must
+  /// supply the type to load; a null value is returned when `valueType` is
+  /// null.
+  mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+                      TypedValue<PointerLikeType> srcPtr,
+                      Type valueType) const {
+    if (valueType)
+      return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
+    return {};
+  }
+
+ /// Emits an `llvm.store` of `valueToStore` through `destPtr`. No validation
+ /// is performed here, so this hook always reports success.
+ bool genStore(Type pointer, OpBuilder &builder, Location loc,
+ Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+ LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
+ return true;
+ }
+};
+
+/// External model that lets `memref.get_global` implement
+/// AddressOfGlobalOpInterface.
+struct MemrefAddressOfGlobalModel
+    : public AddressOfGlobalOpInterface::ExternalModel<
+          MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
+  /// Returns the symbol named by the `memref.get_global` operation `op`.
+  SymbolRefAttr getSymbol(Operation *op) const {
+    return cast<memref::GetGlobalOp>(op).getNameAttr();
+  }
+};
+
+/// External model that lets `memref.global` implement
+/// GlobalVariableOpInterface.
+struct MemrefGlobalVariableModel
+    : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
+                                                      memref::GlobalOp> {
+  /// Returns the `constant` flag of the `memref.global` operation `op`.
+  bool isConstant(Operation *op) const {
+    return cast<memref::GlobalOp>(op).getConstant();
+  }
+
+  /// `memref.global` is initialized through attributes rather than regions,
+  /// so there is never an init region to return.
+  Region *getInitRegion(Operation *op) const { return nullptr; }
};
/// Helper function for any of the times we need to modify an ArrayAttr based on
@@ -302,6 +382,11 @@ void OpenACCDialect::initialize() {
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
+
+ // Attach operation interfaces
+ memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
+ *getContext());
+ memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
}
//===----------------------------------------------------------------------===//
@@ -467,6 +552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
return success();
}
+/// Verifies the recipe symbol attached to the data-entry op `op`.
+///
+/// Ops whose variable has a mappable type are accepted without a recipe,
+/// because one can be generated from the type's API. Reductions are the
+/// exception: no such API exists for them yet, so they always go through the
+/// check. Otherwise the op must carry a recipe attribute that resolves to a
+/// `RecipeOpT` symbol; `operandName` is used in the emitted diagnostics.
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
+  constexpr bool isReduction = std::is_same_v<OpT, acc::ReductionOp>;
+  const bool isMappable = mlir::acc::isMappableType(op.getVar().getType());
+  if (isMappable && !isReduction)
+    return success();
+
+  mlir::SymbolRefAttr recipeRef = op.getRecipeAttr();
+  if (!recipeRef)
+    return op->emitOpError() << "recipe expected for " << operandName;
+
+  if (SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, recipeRef))
+    return success();
+
+  return op->emitOpError()
+         << "expected symbol reference " << recipeRef << " to point to a "
+         << operandName << " declaration";
+}
+
static ParseResult parseVar(mlir::OpAsmParser &parser,
OpAsmParser::UnresolvedOperand &var) {
// Either `var` or `varPtr` keyword is required.
@@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
}
}
+/// Custom-directive parser for a recipe symbol: parses one attribute into
+/// `recipeAttr` and reports whether parsing succeeded.
+static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
+                                  mlir::SymbolRefAttr &recipeAttr) {
+  return parser.parseAttribute(recipeAttr);
+}
+
+/// Custom-directive printer for a recipe symbol: writes `recipeAttr` to the
+/// printer. `op` is not consulted.
+static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                           mlir::SymbolRefAttr recipeAttr) {
+  p.printAttribute(recipeAttr);
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(
+ checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private")))
+ return failure();
return success();
}
@@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>(
+ *this, "firstprivate")))
+ return failure();
return success();
}
@@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>(
+ *this, "reduction")))
+ return failure();
return success();
}
@@ -1042,6 +1170,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+/// Remove empty acc.kernel_environment operations. If the operation has wait
+/// operands, create a acc.wait operation to preserve synchronization.
+/// Canonicalization pattern that drops an acc.kernel_environment whose body
+/// block is empty. When the op carries wait operands (or a wait-only
+/// attribute), an acc.wait is emitted in its place so synchronization is kept.
+struct RemoveEmptyKernelEnvironment
+    : public OpRewritePattern<acc::KernelEnvironmentOp> {
+  using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
+                                PatternRewriter &rewriter) const override {
+    assert(op->getNumRegions() == 1 && "expected op to have one region");
+
+    if (!op.getRegion().front().empty())
+      return failure();
+
+    // Stay conservative: the rewrite is skipped whenever the wait clause on
+    // the kernel_environment cannot be expressed exactly by a plain acc.wait.
+
+    // Any wait entry bound to a non-default device type blocks the rewrite.
+    auto waitDeviceTypes = op.getWaitOperandsDeviceTypeAttr();
+    if (waitDeviceTypes &&
+        llvm::any_of(waitDeviceTypes, [](mlir::Attribute attr) {
+          auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr);
+          return dtAttr && dtAttr.getValue() != mlir::acc::DeviceType::None;
+        }))
+      return failure();
+
+    // Any wait segment that carries a devnum blocks the rewrite.
+    auto devnumFlags = op.getHasWaitDevnumAttr();
+    if (devnumFlags && llvm::any_of(devnumFlags, [](mlir::Attribute attr) {
+          auto flag = mlir::dyn_cast<mlir::BoolAttr>(attr);
+          return flag && flag.getValue();
+        }))
+      return failure();
+
+    // More than one wait segment blocks the rewrite.
+    auto waitSegments = op.getWaitOperandsSegmentsAttr();
+    if (waitSegments && waitSegments.size() > 1)
+      return failure();
+
+    // Nothing to synchronize on: the op can simply go away.
+    const bool needsWait =
+        !op.getWaitOperands().empty() || op.getWaitOnlyAttr();
+    if (!needsWait) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Keep the synchronization by substituting an acc.wait.
+    rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
+                                             /*asyncOperand=*/Value(),
+                                             /*waitDevnum=*/Value(),
+                                             /*async=*/nullptr,
+                                             /*ifCond=*/Value());
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// Recipe Region Helpers
//===----------------------------------------------------------------------===//
@@ -1263,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
return recipe;
}
+/// Builds a private recipe named `recipeName` that mirrors an existing
+/// firstprivate recipe: it takes the firstprivate recipe's type and receives
+/// clones of its init region and, when non-empty, of its destroy region. The
+/// builder's insertion point is restored when this function returns.
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+                                   StringRef recipeName,
+                                   FirstprivateRecipeOp firstprivRecipe) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto recipe = PrivateRecipeOp::create(builder, loc, recipeName,
+                                        firstprivRecipe.getType());
+
+  // Each region is cloned with its own, independent value mapping.
+  IRMapping initMapping;
+  firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(),
+                                            initMapping);
+
+  // The destroy region is optional on the firstprivate recipe.
+  Region &sourceDestroy = firstprivRecipe.getDestroyRegion();
+  if (sourceDestroy.empty())
+    return recipe;
+
+  IRMapping destroyMapping;
+  sourceDestroy.cloneInto(&recipe.getDestroyRegion(), destroyMapping);
+  return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1373,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
}
//===----------------------------------------------------------------------===//
-// Custom parser and printer verifier for private clause
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseSymOperandList(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
- llvm::SmallVector<SymbolRefAttr> attributes;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
- return success();
-}
-
-static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands,
- mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> attributes) {
- llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
- p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
-}
-
-//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -1425,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
-template <typename Op>
-static LogicalResult
-checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
- mlir::OperandRange operands, llvm::StringRef operandName,
- llvm::StringRef symbolName, bool checkOperandType = true) {
- if (!operands.empty()) {
- if (!attributes || attributes->size() != operands.size())
- return op->emitOpError()
- << "expected as many " << symbolName << " symbol reference as "
- << operandName << " operands";
- } else {
- if (attributes)
- return op->emitOpError()
- << "unexpected " << symbolName << " symbol reference";
- return success();
- }
-
+/// Verifies the private/firstprivate/reduction operands of an acc construct:
+/// every operand must be produced by an `OpT` operation and may appear only
+/// once. `operandName` is used in the emitted diagnostics. `RecipeOpT` is not
+/// used by the body.
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
+ const mlir::ValueRange &operands,
+ llvm::StringRef operandName) {
llvm::DenseSet<Value> set;
- for (auto args : llvm::zip(operands, *attributes)) {
- mlir::Value operand = std::get<0>(args);
-
+ for (mlir::Value operand : operands) {
+ // getDefiningOp<OpT>() tolerates operands without a defining op (block
+ // arguments), so those are diagnosed instead of being passed as null to isa.
+ if (!operand.getDefiningOp<OpT>())
+ return accConstructOp->emitOpError()
+ << "expected " << operandName << " as defining op";
if (!set.insert(operand).second)
- return op->emitOpError()
+ return accConstructOp->emitOpError()
<< operandName << " operand appears more than once";
-
- mlir::Type varType = operand.getType();
- auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
- auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
- if (!decl)
- return op->emitOpError()
- << "expected symbol reference " << symbolRef << " to point to a "
- << operandName << " declaration";
-
- if (checkOperandType && decl.getType() && decl.getType() != varType)
- return op->emitOpError() << "expected " << operandName << " (" << varType
- << ") to be the same type as " << operandName
- << " declaration (" << decl.getType() << ")";
}
-
return success();
}
@@ -1520,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
}
LogicalResult acc::ParallelOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -1661,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
mlir::ValueRange gangPrivateOperands,
mlir::ValueRange gangFirstPrivateOperands,
mlir::ValueRange dataClauseOperands) {
-
ParallelOp::build(
odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
@@ -1670,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
/*numGangsDeviceType=*/nullptr, numWorkers,
/*numWorkersDeviceType=*/nullptr, vectorLength,
/*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
- /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
- gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
- /*firstprivatizations=*/nullptr, dataClauseOperands,
+ /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
+ gangFirstPrivateOperands, dataClauseOperands,
/*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
@@ -1749,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands(
void acc::ParallelOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
static ParseResult parseNumGangs(
@@ -2356,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
}
LogicalResult acc::SerialOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -2430,46 +2553,22 @@ void acc::SerialOp::addWaitOperands(
void acc::SerialOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -2599,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() {
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
+/// Points the private op `op` at `recipe` (by symbol name) and appends the
+/// op's result to this construct's private operands.
+void acc::KernelsOp::addPrivatization(MLIRContext *context,
+                                      mlir::acc::PrivateOp op,
+                                      mlir::acc::PrivateRecipeOp recipe) {
+  auto recipeRef = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeRef);
+  getPrivateOperandsMutable().append(op.getResult());
+}
+
+/// Points the firstprivate op `op` at `recipe` (by symbol name) and appends
+/// the op's result to this construct's firstprivate operands.
+void acc::KernelsOp::addFirstPrivatization(
+    MLIRContext *context, mlir::acc::FirstprivateOp op,
+    mlir::acc::FirstprivateRecipeOp recipe) {
+  auto recipeRef = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeRef);
+  getFirstprivateOperandsMutable().append(op.getResult());
+}
+
+/// Points the reduction op `op` at `recipe` (by symbol name) and appends the
+/// op's result to this construct's reduction operands.
+void acc::KernelsOp::addReduction(MLIRContext *context,
+                                  mlir::acc::ReductionOp op,
+                                  mlir::acc::ReductionRecipeOp recipe) {
+  auto recipeRef = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeRef);
+  getReductionOperandsMutable().append(op.getResult());
+}
+
void acc::KernelsOp::addNumWorkersOperand(
MLIRContext *context, mlir::Value newValue,
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
@@ -2691,6 +2811,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
+// KernelEnvironmentOp
+//===----------------------------------------------------------------------===//
+
+/// Registers the canonicalization patterns for acc.kernel_environment; at
+/// present this is only RemoveEmptyKernelEnvironment.
+void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<RemoveEmptyKernelEnvironment>(context);
+}
+
+//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
@@ -2899,19 +3028,21 @@ bool hasDuplicateDeviceTypes(
}
/// Check for duplicates in the DeviceType array attribute.
-LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
+/// Returns std::nullopt if the attribute is null or holds no duplicates, or
+/// the first duplicated DeviceType otherwise. An element that is null or is
+/// not a DeviceTypeAttr is reported as DeviceType::None, so callers cannot
+/// tell that case apart from a duplicated `none` entry.
+static std::optional<mlir::acc::DeviceType>
+checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
if (!deviceTypes)
- return success();
+ return std::nullopt;
for (auto attr : deviceTypes) {
auto deviceTypeAttr =
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
if (!deviceTypeAttr)
- return failure();
+ return mlir::acc::DeviceType::None;
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
- return failure();
+ return deviceTypeAttr.getValue();
}
- return success();
+ return std::nullopt;
}
LogicalResult acc::LoopOp::verify() {
@@ -2938,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() {
getCollapseDeviceTypeAttr().getValue().size())
return emitOpError() << "collapse attribute count must match collapse"
<< " device_type count";
- if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
- return emitOpError()
- << "duplicate device_type found in collapseDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in collapseDeviceType attribute";
// Check gang
if (!getGangOperands().empty()) {
@@ -2953,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() {
return emitOpError() << "gangOperandsArgType attribute count must match"
<< " gangOperands count";
}
- if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
- return emitOpError() << "duplicate device_type found in gang attribute";
+ if (getGangAttr()) {
+ if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in gang attribute";
+ }
if (failed(verifyDeviceTypeAndSegmentCountMatch(
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
@@ -2962,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() {
return failure();
// Check worker
- if (failed(checkDeviceTypes(getWorkerAttr())))
- return emitOpError() << "duplicate device_type found in worker attribute";
- if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "workerNumOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in worker attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in workerNumOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
getWorkerNumOperandsDeviceTypeAttr(),
"worker")))
return failure();
// Check vector
- if (failed(checkDeviceTypes(getVectorAttr())))
- return emitOpError() << "duplicate device_type found in vector attribute";
- if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "vectorOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vector attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vectorOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
getVectorOperandsDeviceTypeAttr(),
"vector")))
@@ -3042,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() {
}
}
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (getCombined().has_value() &&
@@ -3068,8 +3212,12 @@ LogicalResult acc::LoopOp::verify() {
if (getRegion().empty())
return emitError("expected non-empty body.");
- // When it is container-like - it is expected to hold a loop-like operation.
- if (isContainerLike()) {
+ if (getUnstructured()) {
+ if (!isContainerLike())
+ return emitError(
+ "unstructured acc.loop must not have induction variables");
+ } else if (isContainerLike()) {
+ // When it is container-like - it is expected to hold a loop-like operation.
// Obtain the maximum collapse count - we use this to check that there
// are enough loops contained.
uint64_t collapseCount = getCollapseValue().value_or(1);
@@ -3484,45 +3632,21 @@ void acc::LoopOp::addGangOperands(
void acc::LoopOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -3987,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() {
if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
- "be present at the same time";
+ "be present at the same time for device_type `"
+ << acc::stringifyDeviceType(dtype) << "`";
}
return success();
@@ -4284,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
return std::nullopt;
}
+/// Replaces the `seq` attribute with the result of running
+/// addDeviceTypeAffectedOperandHelper on its current value and
+/// `effectiveDeviceTypes`.
+void RoutineOp::addSeq(MLIRContext *context,
+                       llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedSeq = addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
+                                                       effectiveDeviceTypes);
+  setSeqAttr(updatedSeq);
+}
+
+/// Replaces the `vector` attribute with the result of running
+/// addDeviceTypeAffectedOperandHelper on its current value and
+/// `effectiveDeviceTypes`.
+void RoutineOp::addVector(MLIRContext *context,
+                          llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedVector = addDeviceTypeAffectedOperandHelper(
+      context, getVectorAttr(), effectiveDeviceTypes);
+  setVectorAttr(updatedVector);
+}
+
+/// Replaces the `worker` attribute with the result of running
+/// addDeviceTypeAffectedOperandHelper on its current value and
+/// `effectiveDeviceTypes`.
+void RoutineOp::addWorker(MLIRContext *context,
+                          llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedWorker = addDeviceTypeAffectedOperandHelper(
+      context, getWorkerAttr(), effectiveDeviceTypes);
+  setWorkerAttr(updatedWorker);
+}
+
+/// Replaces the `gang` attribute with the result of running
+/// addDeviceTypeAffectedOperandHelper on its current value and
+/// `effectiveDeviceTypes`.
+void RoutineOp::addGang(MLIRContext *context,
+                        llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedGang = addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
+                                                        effectiveDeviceTypes);
+  setGangAttr(updatedGang);
+}
+
+/// Appends the gang dimension `val` once per entry of `effectiveDeviceTypes`,
+/// or once for `DeviceType::None` when that list is empty. The dimension
+/// values and their device types are kept in two parallel array attributes.
+void RoutineOp::addGang(MLIRContext *context,
+                        llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                        uint64_t val) {
+  llvm::SmallVector<mlir::Attribute> dims;
+  llvm::SmallVector<mlir::Attribute> dimDeviceTypes;
+
+  // Start from whatever is already recorded on the op.
+  if (auto existingDims = getGangDimAttr())
+    llvm::copy(existingDims, std::back_inserter(dims));
+  if (auto existingDeviceTypes = getGangDimDeviceTypeAttr())
+    llvm::copy(existingDeviceTypes, std::back_inserter(dimDeviceTypes));
+
+  assert(dims.size() == dimDeviceTypes.size());
+
+  // Appends one (dimension value, device type) pair to the parallel lists.
+  auto appendEntry = [&](DeviceType deviceType) {
+    dims.push_back(
+        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+    dimDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
+  };
+
+  if (effectiveDeviceTypes.empty()) {
+    appendEntry(acc::DeviceType::None);
+  } else {
+    for (DeviceType deviceType : effectiveDeviceTypes)
+      appendEntry(deviceType);
+  }
+  assert(dims.size() == dimDeviceTypes.size());
+
+  setGangDimAttr(mlir::ArrayAttr::get(context, dims));
+  setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, dimDeviceTypes));
+}
+
+/// Updates the string bind-name device-type list for `effectiveDeviceTypes`
+/// and appends `val` to the bind-name list once per device-type entry that
+/// the update added, so that both lists grow by the same amount.
+void RoutineOp::addBindStrName(MLIRContext *context,
+                               llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                               mlir::StringAttr val) {
+  // Number of device-type entries before the update (0 when unset).
+  unsigned oldCount = 0;
+  if (auto oldDeviceTypes = getBindStrNameDeviceTypeAttr())
+    oldCount = oldDeviceTypes.size();
+
+  setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
+  unsigned newCount = getBindStrNameDeviceTypeAttr().size();
+
+  llvm::SmallVector<mlir::Attribute> names;
+  if (auto oldNames = getBindStrNameAttr())
+    llvm::copy(oldNames, std::back_inserter(names));
+  // One copy of `val` per newly added device-type entry.
+  for (unsigned remaining = newCount - oldCount; remaining != 0; --remaining)
+    names.push_back(val);
+
+  setBindStrNameAttr(mlir::ArrayAttr::get(context, names));
+}
+
+/// Updates the symbol bind-name device-type list for `effectiveDeviceTypes`
+/// and appends `val` to the bind-name list once per device-type entry that
+/// the update added, so that both lists grow by the same amount.
+void RoutineOp::addBindIDName(MLIRContext *context,
+                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                              mlir::SymbolRefAttr val) {
+  // Number of device-type entries before the update (0 when unset).
+  unsigned oldCount = 0;
+  if (auto oldDeviceTypes = getBindIdNameDeviceTypeAttr())
+    oldCount = oldDeviceTypes.size();
+
+  setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
+  unsigned newCount = getBindIdNameDeviceTypeAttr().size();
+
+  llvm::SmallVector<mlir::Attribute> names;
+  if (auto oldNames = getBindIdNameAttr())
+    llvm::copy(oldNames, std::back_inserter(names));
+  // One copy of `val` per newly added device-type entry.
+  for (unsigned remaining = newCount - oldCount; remaining != 0; --remaining)
+    names.push_back(val);
+
+  setBindIdNameAttr(mlir::ArrayAttr::get(context, names));
+}
+
//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//
@@ -4667,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
.Default([&](mlir::Operation *) { return nullptr; })};
return dataOperands;
}
+
+/// Returns the recipe attribute of `accOp` when it is one of the data entry
+/// operations, and a null SymbolRefAttr for any other operation.
+mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
+  return llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp)
+      .Case<ACC_DATA_ENTRY_OPS>(
+          [](auto entryOp) { return entryOp.getRecipeAttr(); })
+      .Default([](mlir::Operation *) { return mlir::SymbolRefAttr{}; });
+}
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
new file mode 100644
index 0000000..67cdf10
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -0,0 +1,781 @@
+//===- ACCImplicitData.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the OpenACC specification for "Variables with
+// Implicitly Determined Data Attributes" (OpenACC 3.4 spec, section 2.6.2).
+//
+// Overview:
+// ---------
+// The pass automatically generates data clause operations for variables used
+// within OpenACC compute constructs (parallel, kernels, serial) that do not
+// already have explicit data clauses. The semantics follow these rules:
+//
+// 1. If there is a default(none) clause visible, no implicit data actions
+// apply.
+//
+// 2. An aggregate variable (arrays, derived types, etc.) will be treated as:
+// - In a present clause when default(present) is visible.
+// - In a copy clause otherwise.
+//
+// 3. A scalar variable will be treated as if it appears in:
+// - A copy clause if the compute construct is a kernels construct.
+// - A firstprivate clause otherwise (parallel, serial).
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Type Interface Implementation: Variables from the dialect being used
+// must implement one or both of the following MLIR interfaces:
+// `acc::MappableType` and/or `acc::PointerLikeType`
+//
+// These interfaces provide the necessary methods for the pass to:
+// - Determine variable type categories (scalar vs. aggregate)
+// - Generate appropriate bounds information
+// - Generate privatization recipes
+//
+// 2. Operation Interface Implementation: Operations that access partial
+// entities or create views should implement the following MLIR
+// interfaces: `acc::PartialEntityAccess` and/or
+// `mlir::ViewLikeOpInterface`
+//
+// These interfaces are used for proper data clause ordering, ensuring
+// that base entities are mapped before derived entities (e.g., a
+// struct is mapped before its fields, an array is mapped before
+// subarray views).
+//
+// 3. Analysis Registration (Optional): If custom behavior is needed for
+// variable name extraction or alias analysis, the dialect should
+// pre-register the `acc::OpenACCSupport` and `mlir::AliasAnalysis` analyses.
+//
+// If not registered, default behavior will be used.
+//
+// Implementation Details:
+// -----------------------
+// The pass performs the following operations:
+//
+// 1. Finds candidate variables which are live-in to the compute region and
+// are not already in a data clause or private clause.
+//
+// 2. Generates both data "entry" and "exit" clause operations that match
+// the intended action depending on variable type:
+// - copy -> acc.copyin (entry) + acc.copyout (exit)
+// - present -> acc.present (entry) + acc.delete (exit)
+// - firstprivate -> acc.firstprivate (entry only, no exit)
+//
+// 3. Ensures that default clause is taken into consideration by looking
+// through current construct and parent constructs to find the "visible
+// default clause".
+//
+// 4. Fixes up SSA value links so that uses in the acc region reference the
+// result of the newly created data clause operations.
+//
+// 5. When generating implicit data clause operations, it also adds variable
+// name information and marks them with the implicit flag.
+//
+// 6. Recipes are generated by calling the appropriate entrypoints in the
+// MappableType and PointerLikeType interfaces.
+//
+// 7. AliasAnalysis is used to determine if a variable is already covered by
+// an existing data clause (e.g., an interior pointer covered by its parent).
+//
+// Examples:
+// ---------
+//
+// Example 1: Scalar in parallel construct (implicit firstprivate)
+//
+// Before:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32>
+// acc.parallel {
+// %val = memref.load %scalar[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32>
+// %firstpriv = acc.firstprivate varPtr(%scalar : memref<f32>)
+// -> memref<f32> {implicit = true, name = "x"}
+// acc.parallel firstprivate(@recipe -> %firstpriv : memref<f32>) {
+// %val = memref.load %firstpriv[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 2: Scalar in kernels construct (implicit copy)
+//
+// Before:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32>
+// acc.kernels {
+// %val = memref.load %scalar[] : memref<i32>
+// acc.terminator
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32>
+// %copyin = acc.copyin varPtr(%scalar : memref<i32>) -> memref<i32>
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "n"}
+// acc.kernels dataOperands(%copyin : memref<i32>) {
+// %val = memref.load %copyin[] : memref<i32>
+// acc.terminator
+// }
+// acc.copyout accPtr(%copyin : memref<i32>)
+// to varPtr(%scalar : memref<i32>)
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "n"}
+// }
+//
+// Example 3: Array (aggregate) in parallel (implicit copy)
+//
+// Before:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// acc.parallel {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %array[%c0] : memref<100xf32>
+// acc.yield
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// %copyin = acc.copyin varPtr(%array : memref<100xf32>)
+// -> memref<100xf32>
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "arr"}
+// acc.parallel dataOperands(%copyin : memref<100xf32>) {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %copyin[%c0] : memref<100xf32>
+// acc.yield
+// }
+// acc.copyout accPtr(%copyin : memref<100xf32>)
+// to varPtr(%array : memref<100xf32>)
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "arr"}
+// }
+//
+// Example 4: Array with default(present)
+//
+// Before:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// acc.parallel {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %array[%c0] : memref<100xf32>
+// acc.yield
+// } attributes {defaultAttr = #acc<defaultvalue present>}
+// }
+//
+// After:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// %present = acc.present varPtr(%array : memref<100xf32>)
+// -> memref<100xf32>
+// {implicit = true, name = "arr"}
+// acc.parallel dataOperands(%present : memref<100xf32>)
+// attributes {defaultAttr = #acc<defaultvalue present>} {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %present[%c0] : memref<100xf32>
+// acc.yield
+// }
+// acc.delete accPtr(%present : memref<100xf32>)
+// {dataClause = #acc<data_clause acc_present>,
+// implicit = true, name = "arr"}
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <type_traits>
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITDATA
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-implicit-data"
+
+using namespace mlir;
+
+namespace {
+
+/// Pass that generates implicit data clause operations for variables that are
+/// live-in to OpenACC compute constructs and have no explicit data clause.
+class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
+public:
+ using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase;
+
+ /// Walks the module and processes every compute construct and
+ /// `acc.kernel_environment` operation found in it.
+ void runOnOperation() override;
+
+private:
+ /// Looks through the `dominatingDataClauses` to find the original data clause
+ /// op for an alias. Only clauses whose variable must-aliases `var` are
+ /// accepted. Returns nullptr if no original data clause op is found.
+ template <typename OpT>
+ Operation *getOriginalDataClauseOpForAlias(
+ Value var, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses);
+
+ /// Generates the appropriate `acc.copyin`, `acc.present`, `acc.firstprivate`,
+ /// etc. data clause op for a candidate variable. Returns nullptr when the
+ /// variable is neither a scalar nor an aggregate and is not covered by a
+ /// dominating data clause.
+ template <typename OpT>
+ Operation *generateDataClauseOpForCandidate(
+ Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses,
+ const std::optional<acc::ClauseDefaultValue> &defaultClause);
+
+ /// Generates the implicit data ops for a compute construct. Does nothing
+ /// when `defaultClause` is `default(none)`.
+ template <typename OpT>
+ void generateImplicitDataOps(
+ ModuleOp &module, OpT computeConstructOp,
+ std::optional<acc::ClauseDefaultValue> &defaultClause);
+
+ /// Returns the private recipe for a variable, reusing a recipe of the same
+ /// name if the module already has one and otherwise creating it at the start
+ /// of the module body. Returns a null op after emitting a "not yet
+ /// implemented" diagnostic if the recipe cannot be created.
+ acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module, Value var,
+ Location loc, OpBuilder &builder,
+ acc::OpenACCSupport &accSupport);
+
+ /// Returns the firstprivate recipe for a variable, reusing a recipe of the
+ /// same name if the module already has one and otherwise creating it at the
+ /// start of the module body. Returns a null op after emitting a "not yet
+ /// implemented" diagnostic if the recipe cannot be created.
+ acc::FirstprivateRecipeOp
+ generateFirstprivateRecipe(ModuleOp &module, Value var, Location loc,
+ OpBuilder &builder,
+ acc::OpenACCSupport &accSupport);
+
+ /// Generates recipes for a list of private/firstprivate operands and sets
+ /// the recipe attribute on each defining op. Any other kind of operand is
+ /// reported as not yet implemented.
+ void generateRecipes(ModuleOp &module, OpBuilder &builder,
+ Operation *computeConstructOp,
+ const SmallVector<Value> &newOperands);
+};
+
+/// Returns true when `val` should receive an implicit data clause for the
+/// compute region `accRegion`: it must have a pointer-like or mappable type,
+/// must not already be the result of a data entry op, and must have at least
+/// one use in the region other than in private clauses.
+static bool isCandidateForImplicitData(Value val, Region &accRegion) {
+  // Only types usable in a data clause qualify.
+  Type valType = val.getType();
+  const bool isDataClauseType =
+      acc::isPointerLikeType(valType) || acc::isMappableType(valType);
+  if (!isDataClauseType)
+    return false;
+
+  // A value produced by a data clause is already mapped.
+  Operation *definingOp = val.getDefiningOp();
+  const bool isAlreadyMapped = isa_and_nonnull<ACC_DATA_ENTRY_OPS>(definingOp);
+
+  // A value used only by private clauses is not a real live-in.
+  return !isAlreadyMapped && !acc::isOnlyUsedByPrivateClauses(val, accRegion);
+}
+
+/// Scans `dominatingDataClauses` in order and returns the defining op of the
+/// first copyin/create/present/no_create/deviceptr clause whose variable
+/// must-aliases `var`, or nullptr when there is none.
+template <typename OpT>
+Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
+    Value var, OpBuilder &builder, OpT computeConstructOp,
+    const SmallVector<Value> &dominatingDataClauses) {
+  auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
+  for (Value clauseResult : dominatingDataClauses) {
+    Operation *clauseOp = clauseResult.getDefiningOp();
+    if (!clauseOp)
+      continue;
+
+    // Only accept clauses that guarantee that the alias is present.
+    if (!isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
+             acc::DevicePtrOp>(clauseOp))
+      continue;
+
+    if (aliasAnalysis.alias(acc::getVar(clauseOp), var).isMust())
+      return clauseOp;
+  }
+  return nullptr;
+}
+
+// Attaches bounds to a data clause op whose variable has a mappable type with
+// unknown dimensions. Existing bounds are never overwritten; the bounds come
+// from MappableType::generateAccBounds, which extracts them from the IR.
+static void fillInBoundsForUnknownDimensions(Operation *dataClauseOp,
+                                             OpBuilder &builder) {
+  // If bounds are already present, do not overwrite them.
+  if (!acc::getBounds(dataClauseOp).empty())
+    return;
+
+  auto var = acc::getVar(dataClauseOp);
+  auto mappableTy = dyn_cast<acc::MappableType>(var.getType());
+  if (!mappableTy || !mappableTy.hasUnknownDimensions())
+    return;
+
+  TypeSwitch<Operation *>(dataClauseOp)
+      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto typedClauseOp) {
+        // acc.deviceptr is deliberately left without generated bounds.
+        if (std::is_same_v<decltype(typedClauseOp), acc::DevicePtrOp>)
+          return;
+        // Emit the bounds right before the clause op; the guard restores the
+        // caller's insertion point afterwards.
+        OpBuilder::InsertionGuard restoreInsertionPoint(builder);
+        builder.setInsertionPoint(typedClauseOp);
+        auto generatedBounds = mappableTy.generateAccBounds(var, builder);
+        if (!generatedBounds.empty())
+          typedClauseOp.getBoundsMutable().assign(generatedBounds);
+      });
+}
+
+/// Returns the private recipe for `var`, looked up by the name that
+/// `accSupport` assigns to it. A missing recipe is created at the start of the
+/// module body; if that fails, a "not yet implemented" diagnostic is emitted
+/// at `loc` and a null op is returned.
+acc::PrivateRecipeOp
+ACCImplicitData::generatePrivateRecipe(ModuleOp &module, Value var,
+                                       Location loc, OpBuilder &builder,
+                                       acc::OpenACCSupport &accSupport) {
+  Type varType = var.getType();
+  std::string recipeName =
+      accSupport.getRecipeName(acc::RecipeKind::private_recipe, varType, var);
+
+  // Reuse the recipe if the module already contains one with this name.
+  if (auto knownRecipe = module.lookupSymbol<acc::PrivateRecipeOp>(recipeName))
+    return knownRecipe;
+
+  // Create the recipe at the top of the module; the guard restores the
+  // caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto createdRecipe = acc::PrivateRecipeOp::createAndPopulate(
+      builder, loc, recipeName, varType);
+  if (createdRecipe.has_value())
+    return createdRecipe.value();
+
+  accSupport.emitNYI(loc, "implicit private");
+  return nullptr;
+}
+
+/// Returns the firstprivate recipe for `var`, looked up by the name that
+/// `accSupport` assigns to it. A missing recipe is created at the start of the
+/// module body; if that fails, a "not yet implemented" diagnostic is emitted
+/// at `loc` and a null op is returned.
+acc::FirstprivateRecipeOp
+ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var,
+                                            Location loc, OpBuilder &builder,
+                                            acc::OpenACCSupport &accSupport) {
+  Type varType = var.getType();
+  std::string recipeName = accSupport.getRecipeName(
+      acc::RecipeKind::firstprivate_recipe, varType, var);
+
+  // Reuse the recipe if the module already contains one with this name.
+  if (auto knownRecipe =
+          module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName))
+    return knownRecipe;
+
+  // Create the recipe at the top of the module; the guard restores the
+  // caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto createdRecipe = acc::FirstprivateRecipeOp::createAndPopulate(
+      builder, loc, recipeName, varType);
+  if (createdRecipe.has_value())
+    return createdRecipe.value();
+
+  accSupport.emitNYI(loc, "implicit firstprivate");
+  return nullptr;
+}
+
+/// For every operand in `newOperands`, generates the matching private or
+/// firstprivate recipe and records its symbol on the defining op. Operands
+/// defined by any other op are reported as "not yet implemented".
+void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder,
+                                      Operation *computeConstructOp,
+                                      const SmallVector<Value> &newOperands) {
+  auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
+  MLIRContext *ctx = module->getContext();
+
+  for (Value operand : newOperands) {
+    Location loc = operand.getLoc();
+    Operation *clauseOp = operand.getDefiningOp();
+
+    if (auto privateOp = dyn_cast<acc::PrivateOp>(clauseOp)) {
+      // A null recipe means creation failed and was already diagnosed.
+      if (auto recipe = generatePrivateRecipe(module, acc::getVar(clauseOp),
+                                              loc, builder, accSupport))
+        privateOp.setRecipeAttr(SymbolRefAttr::get(ctx, recipe.getSymName()));
+      continue;
+    }
+
+    if (auto firstprivateOp = dyn_cast<acc::FirstprivateOp>(clauseOp)) {
+      if (auto recipe = generateFirstprivateRecipe(
+              module, acc::getVar(clauseOp), loc, builder, accSupport))
+        firstprivateOp.setRecipeAttr(
+            SymbolRefAttr::get(ctx, recipe.getSymName()));
+      continue;
+    }
+
+    accSupport.emitNYI(loc, "implicit reduction");
+  }
+}
+
+// Creates the implicit data entry op for `var` following the OpenACC rules
+// below (line numbers and specification from OpenACC 3.4):
+// 1388 An aggregate variable will be treated as if it appears either:
+// 1389 - In a present clause if there is a default(present) clause visible at
+// the compute construct.
+// 1391 - In a copy clause otherwise.
+// 1392 A scalar variable will be treated as if it appears either:
+// 1393 - In a copy clause if the compute construct is a kernels construct.
+// 1394 - In a firstprivate clause otherwise.
+// A variable that must-aliases the variable of a dominating data clause gets
+// a no_create/deviceptr/present op instead. Returns nullptr for variables
+// that are neither scalar nor aggregate.
+template <typename OpT>
+Operation *ACCImplicitData::generateDataClauseOpForCandidate(
+    Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp,
+    const SmallVector<Value> &dominatingDataClauses,
+    const std::optional<acc::ClauseDefaultValue> &defaultClause) {
+  auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
+
+  // Classify the variable through whichever type interface it implements.
+  acc::VariableTypeCategory typeCategory =
+      acc::VariableTypeCategory::uncategorized;
+  if (auto mappableTy = dyn_cast<acc::MappableType>(var.getType())) {
+    typeCategory = mappableTy.getTypeCategory(var);
+  } else if (auto pointerLikeTy =
+                 dyn_cast<acc::PointerLikeType>(var.getType())) {
+    typeCategory = pointerLikeTy.getPointeeTypeCategory(
+        cast<TypedValue<acc::PointerLikeType>>(var),
+        pointerLikeTy.getElementType());
+  }
+
+  const bool isScalar =
+      acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
+  const bool isAnyAggregate = acc::bitEnumContainsAny(
+      typeCategory, acc::VariableTypeCategory::aggregate);
+  Location loc = computeConstructOp->getLoc();
+
+  // Builds an implicit, structured acc.copyin for `var` tagged with `clause`.
+  auto createImplicitCopyin = [&](acc::DataClause clause) -> Operation * {
+    auto copyinOp =
+        acc::CopyinOp::create(builder, loc, var,
+                              /*structured=*/true, /*implicit=*/true,
+                              accSupport.getVariableName(var));
+    copyinOp.setDataClause(clause);
+    return copyinOp.getOperation();
+  };
+
+  // Case 1: the variable aliases something a dominating clause already maps.
+  if (Operation *aliasedClauseOp = getOriginalDataClauseOpForAlias(
+          var, builder, computeConstructOp, dominatingDataClauses)) {
+    if (isa<acc::NoCreateOp>(aliasedClauseOp))
+      return acc::NoCreateOp::create(builder, loc, var,
+                                     /*structured=*/true, /*implicit=*/true,
+                                     accSupport.getVariableName(var),
+                                     acc::getBounds(aliasedClauseOp));
+
+    if (isa<acc::DevicePtrOp>(aliasedClauseOp))
+      return acc::DevicePtrOp::create(builder, loc, var,
+                                      /*structured=*/true, /*implicit=*/true,
+                                      accSupport.getVariableName(var),
+                                      acc::getBounds(aliasedClauseOp));
+
+    // The original data clause op is a PresentOp, CopyinOp, or CreateOp,
+    // hence guaranteed to be present.
+    return acc::PresentOp::create(builder, loc, var,
+                                  /*structured=*/true, /*implicit=*/true,
+                                  accSupport.getVariableName(var),
+                                  acc::getBounds(aliasedClauseOp));
+  }
+
+  // Case 2: scalars.
+  if (isScalar) {
+    if (enableImplicitReductionCopy &&
+        acc::isOnlyUsedByReductionClauses(var,
+                                          computeConstructOp->getRegion(0)))
+      return createImplicitCopyin(acc::DataClause::acc_reduction);
+
+    if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
+                  std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
+      // Scalars are implicit copyin in kernels construct.
+      // We also do the same for acc.kernel_environment because semantics
+      // of user variable mappings should be applied while ACC construct exists
+      // and at this point we should only be dealing with unmapped variables
+      // that were made live-in by the compiler.
+      // TODO: This may be revisited.
+      return createImplicitCopyin(acc::DataClause::acc_copy);
+    } else {
+      // Scalars are implicit firstprivate in parallel and serial construct.
+      return acc::FirstprivateOp::create(builder, loc, var,
+                                         /*structured=*/true, /*implicit=*/true,
+                                         accSupport.getVariableName(var));
+    }
+  }
+
+  // Case 3: aggregates.
+  if (isAnyAggregate) {
+    const bool defaultIsPresent =
+        defaultClause.has_value() &&
+        defaultClause.value() == acc::ClauseDefaultValue::Present;
+    if (!defaultIsPresent)
+      return createImplicitCopyin(acc::DataClause::acc_copy);
+
+    // When default(present) is true, the implicit behavior is present.
+    Operation *presentOp =
+        acc::PresentOp::create(builder, loc, var,
+                               /*structured=*/true, /*implicit=*/true,
+                               accSupport.getVariableName(var));
+    presentOp->setAttr(acc::getFromDefaultClauseAttrName(),
+                       builder.getUnitAttr());
+    return presentOp;
+  }
+
+  // This is not a fatal error - for example when the element type is
+  // pointer type (aka we have a pointer of pointer), it is potentially a
+  // deep copy scenario which is not being handled here.
+  // Other types need to be canonicalized. Thus just log unhandled cases.
+  LLVM_DEBUG(llvm::dbgs()
+             << "Unhandled case for implicit data mapping " << var << "\n");
+  return nullptr;
+}
+
+// Rewrites uses inside the acc region so that they refer to the results of
+// the new data clause ops rather than to the original variables, i.e.:
+// acc.kernels {
+// use %val
+// }
+// =>
+// %dev = acc.dataop %val
+// acc.kernels {
+// use %dev
+// }
+static void legalizeValuesInRegion(Region &accRegion,
+                                   SmallVector<Value> &newPrivateOperands,
+                                   SmallVector<Value> &newDataClauseOperands) {
+  // Redirects in-region uses of a clause's variable to the clause result.
+  auto redirectUses = [&](Value clauseResult) {
+    Value originalVar = acc::getVar(clauseResult.getDefiningOp());
+    replaceAllUsesInRegionWith(originalVar, clauseResult, accRegion);
+  };
+
+  // Data clause results are handled first, then private ones.
+  for (Value clauseResult : newDataClauseOperands)
+    redirectUses(clauseResult);
+  for (Value clauseResult : newPrivateOperands)
+    redirectUses(clauseResult);
+}
+
+// Appends each new operand to the compute construct's private or
+// firstprivate operand list, depending on which op defines it. Any other
+// defining op is a programming error.
+template <typename OpT>
+static void addNewPrivateOperands(OpT &accOp,
+                                  const SmallVector<Value> &privateOperands) {
+  for (Value operand : privateOperands) {
+    Operation *clauseOp = operand.getDefiningOp();
+    if (isa<acc::PrivateOp>(clauseOp)) {
+      accOp.getPrivateOperandsMutable().append(operand);
+      continue;
+    }
+    if (isa<acc::FirstprivateOp>(clauseOp)) {
+      accOp.getFirstprivateOperandsMutable().append(operand);
+      continue;
+    }
+    llvm_unreachable("unhandled reduction operand");
+  }
+}
+
+// Returns the first user of the data entry op's acc variable that is a data
+// exit op, or nullptr if no such user exists.
+static Operation *findDataExitOp(Operation *dataEntryOp) {
+  auto accVar = acc::getAccVar(dataEntryOp);
+  auto users = accVar.getUsers();
+  auto exitIt = llvm::find_if(
+      users, [](Operation *user) { return isa<ACC_DATA_EXIT_OPS>(user); });
+  return exitIt == users.end() ? nullptr : *exitIt;
+}
+
+// Creates the data exit op matching each newly generated data entry op,
+// following the acc dialect's decomposition of data clauses:
+// https://mlir.llvm.org/docs/Dialects/OpenACCDialect/#operation-categories
+// Key ones used here:
+// * acc {construct} copy -> acc.copyin (before region) + acc.copyout (after
+// region)
+// * acc {construct} present -> acc.present (before region) + acc.delete
+// (after region)
+// The entry operands are visited in reverse of their sorted order.
+static void
+generateDataExitOperations(OpBuilder &builder, Operation *accOp,
+                           const SmallVector<Value> &newDataClauseOperands,
+                           const SmallVector<Value> &sortedDataClauseOperands) {
+  builder.setInsertionPointAfter(accOp);
+  Value previousEntry = nullptr;
+  for (Value entryResult : llvm::reverse(sortedDataClauseOperands)) {
+    // If this is not a new data clause operand, we should not generate an
+    // exit operation for it; it only serves as the next insertion anchor.
+    if (!llvm::is_contained(newDataClauseOperands, entryResult)) {
+      previousEntry = entryResult;
+      continue;
+    }
+
+    // Place the new exit op after the exit op of the previously visited
+    // entry, when that entry has one.
+    if (previousEntry) {
+      Operation *previousExitOp =
+          findDataExitOp(previousEntry.getDefiningOp());
+      if (previousExitOp)
+        builder.setInsertionPointAfter(previousExitOp);
+    }
+
+    Operation *entryOp = entryResult.getDefiningOp();
+    Location entryLoc = entryOp->getLoc();
+    if (isa<acc::CopyinOp>(entryOp)) {
+      auto copyoutOp = acc::CopyoutOp::create(
+          builder, entryLoc, entryResult, acc::getVar(entryOp),
+          /*structured=*/true, /*implicit=*/true,
+          acc::getVarName(entryOp).value(), acc::getBounds(entryOp));
+      copyoutOp.setDataClause(acc::DataClause::acc_copy);
+    } else if (isa<acc::PresentOp, acc::NoCreateOp>(entryOp)) {
+      auto deleteOp = acc::DeleteOp::create(
+          builder, entryLoc, entryResult,
+          /*structured=*/true, /*implicit=*/true,
+          acc::getVarName(entryOp).value(), acc::getBounds(entryOp));
+      deleteOp.setDataClause(acc::getDataClause(entryOp).value());
+    } else if (isa<acc::DevicePtrOp>(entryOp)) {
+      // acc.deviceptr has no exit operation.
+    } else {
+      llvm_unreachable("unhandled data exit");
+    }
+    previousEntry = entryResult;
+  }
+}
+
+/// Returns all base references of a value in order, outermost base first and
+/// `val` itself last.
+/// So for example, if we have a reference to a struct field like
+/// s.f1.f2.f3, this will return <s, s.f1, s.f1.f2, s.f1.f2.f3>.
+/// Any intermediate casts/view-like operations are included in the
+/// chain as well.
+static SmallVector<Value> getBaseRefsChain(Value val) {
+ SmallVector<Value> baseRefs;
+ baseRefs.push_back(val);
+ while (true) {
+ Value prevVal = val;
+
+ // Step to the base entity and prepend it unless it is already the current
+ // head of the chain.
+ val = acc::getBaseEntity(val);
+ if (val != baseRefs.front())
+ baseRefs.insert(baseRefs.begin(), val);
+
+ // If this is a view-like operation, it is effectively another
+ // view of the same entity so we should add it to the chain also.
+ if (auto viewLikeOp = val.getDefiningOp<ViewLikeOpInterface>()) {
+ val = viewLikeOp.getViewSource();
+ baseRefs.insert(baseRefs.begin(), val);
+ }
+
+ // Stop once an iteration makes no progress, i.e. neither step above
+ // changed the value being followed.
+ if (val == prevVal)
+ break;
+ }
+
+ return baseRefs;
+}
+
+// Inserts the acc variable of `newClause` into `sortedDataClauseOperands` so
+// that a clause on a base entity precedes clauses on entities derived from
+// it. When an insertion point is found, the op itself is also moved before
+// the op defining the clause at that position; otherwise the new clause is
+// appended at the end.
+static void insertInSortedOrder(SmallVector<Value> &sortedDataClauseOperands,
+                                Operation *newClause) {
+  auto newVar = acc::getVar(newClause);
+
+  // Find the first existing clause whose variable has the new clause's
+  // variable in its chain of base references.
+  auto *insertPos =
+      llvm::find_if(sortedDataClauseOperands, [&](Value existingClause) {
+        auto existingVar = acc::getVar(existingClause.getDefiningOp());
+        return llvm::is_contained(getBaseRefsChain(existingVar), newVar);
+      });
+
+  if (insertPos == sortedDataClauseOperands.end()) {
+    sortedDataClauseOperands.push_back(acc::getAccVar(newClause));
+    return;
+  }
+
+  // The new clause maps a base of the clause at `insertPos`, so it has to
+  // come right before it, both in the IR and in the operand list.
+  newClause->moveBefore(insertPos->getDefiningOp());
+  sortedDataClauseOperands.insert(insertPos, acc::getAccVar(newClause));
+}
+
+/// Generates the implicit data clause operations for one compute construct.
+///
+/// Collects the values that are live-in to the construct's region, creates a
+/// data entry op for each candidate, rewrites the in-region uses to the new
+/// results, generates the recipes and matching data exit ops, and finally
+/// attaches all new operands to the construct. Nothing is done when
+/// `defaultClause` is `default(none)`. Candidates for which no data clause op
+/// can be generated are skipped.
+template <typename OpT>
+void ACCImplicitData::generateImplicitDataOps(
+    ModuleOp &module, OpT computeConstructOp,
+    std::optional<acc::ClauseDefaultValue> &defaultClause) {
+  // Implicit data attributes are only applied if "[t]here is no default(none)
+  // clause visible at the compute construct."
+  if (defaultClause.has_value() &&
+      defaultClause.value() == acc::ClauseDefaultValue::None)
+    return;
+  assert(!defaultClause.has_value() ||
+         defaultClause.value() == acc::ClauseDefaultValue::Present);
+
+  // 1) Collect live-in values.
+  Region &accRegion = computeConstructOp->getRegion(0);
+  SetVector<Value> liveInValues;
+  getUsedValuesDefinedAbove(accRegion, liveInValues);
+
+  // 2) Run the filtering to find relevant pointers that need to be copied.
+  auto isCandidate{[&](Value val) -> bool {
+    return isCandidateForImplicitData(val, accRegion);
+  }};
+  auto candidateVars(
+      llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
+  if (candidateVars.empty())
+    return;
+
+  // 3) Generate data clauses for the variables.
+  SmallVector<Value> newPrivateOperands;
+  SmallVector<Value> newDataClauseOperands;
+  OpBuilder builder(computeConstructOp);
+  LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n"
+                          << computeConstructOp << "\n");
+  auto &domInfo = this->getAnalysis<DominanceInfo>();
+  auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
+  auto dominatingDataClauses =
+      acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo);
+  for (auto var : candidateVars) {
+    Operation *newDataClauseOp = generateDataClauseOpForCandidate(
+        var, module, builder, computeConstructOp, dominatingDataClauses,
+        defaultClause);
+    // No op is generated for unhandled variable categories (this has already
+    // been logged); there is nothing to fill in or attach for them, and the
+    // null pointer must not be dereferenced below.
+    if (!newDataClauseOp)
+      continue;
+    fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
+    LLVM_DEBUG(llvm::dbgs() << "Generated data clause for " << var << ":\n"
+                            << "\t" << *newDataClauseOp << "\n");
+    if (isa<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
+            newDataClauseOp)) {
+      newPrivateOperands.push_back(acc::getAccVar(newDataClauseOp));
+    } else if (isa<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
+      newDataClauseOperands.push_back(acc::getAccVar(newDataClauseOp));
+      // Later candidates may alias this variable, so make the new clause
+      // visible to the alias lookup as well.
+      dominatingDataClauses.push_back(acc::getAccVar(newDataClauseOp));
+    }
+  }
+
+  // 4) Legalize values in region (aka the uses in the region are the result
+  // of the data clause ops)
+  legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
+
+  // 5) Generate private recipes which are required for properly attaching
+  // private operands.
+  if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
+                !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
+    generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
+
+  // 6) Figure out insertion order for the new data clause operands.
+  SmallVector<Value> sortedDataClauseOperands(
+      computeConstructOp.getDataClauseOperands());
+  for (auto newClause : newDataClauseOperands)
+    insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
+
+  // 7) Generate the data exit operations.
+  generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
+                             sortedDataClauseOperands);
+  // 8) Add all of the new operands to the compute construct op.
+  if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
+                !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
+    addNewPrivateOperands(computeConstructOp, newPrivateOperands);
+  computeConstructOp.getDataClauseOperandsMutable().assign(
+      sortedDataClauseOperands);
+}
+
+/// Walks the module and generates implicit data ops for every compute
+/// construct and `acc.kernel_environment` operation, using the default
+/// attribute obtained from `acc::getDefaultAttr` for that operation.
+void ACCImplicitData::runOnOperation() {
+  ModuleOp module = this->getOperation();
+  module.walk([&](Operation *op) {
+    // Everything other than the handled constructs is left untouched.
+    if (!isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op))
+      return;
+    assert(op->getNumRegions() == 1 && "must have 1 region");
+
+    auto defaultClause = acc::getDefaultAttr(op);
+    // Dispatch on the concrete op type so the templated implementation is
+    // instantiated for the right construct.
+    llvm::TypeSwitch<Operation *, void>(op)
+        .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+            [&](auto constructOp) {
+              generateImplicitDataOps(module, constructOp, defaultClause);
+            })
+        .Default([&](Operation *) {});
+  });
+}
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
new file mode 100644
index 0000000..8cab223
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
@@ -0,0 +1,431 @@
+//===- ACCImplicitDeclare.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass applies implicit `acc declare` actions to global variables
+// referenced in OpenACC compute regions and routine functions.
+//
+// Overview:
+// ---------
+// Global references in acc regions (for globals not marked with `acc
+// declare` by the user) can be handled in one of two ways:
+// - Mapped through data clauses
+// - Implicitly marked as `acc declare` (this pass)
+//
+// Thus, the OpenACC specification focuses solely on implicit data mapping rules
+// whose implementation is captured in `ACCImplicitData` pass.
+//
+// However, it is both advantageous and required for certain cases to
+// use implicit `acc declare` instead:
+// - Any functions that are implicitly marked as `acc routine` through
+// `ACCImplicitRoutine` may reference globals. Since data mapping
+// is only possible for compute regions, such globals can only be
+// made available on device through `acc declare`.
+// - Compiler can generate and use globals for cases needed in IR
+// representation such as type descriptors or various names needed for
+// runtime calls and error reporting - such cases often are introduced
+// after a frontend semantic checking is done since it is related to
+// implementation detail. Thus, such compiler generated globals would
+// not have been visible for a user to mark with `acc declare`.
+// - Constant globals such as filename strings or data initialization values
+// are values that do not get mutated but are still needed for appropriate
+// runtime execution. If a kernel is launched 1000 times, it is not a
+// good idea to map such a global 1000 times. Therefore, such globals
+// benefit from being marked with `acc declare`.
+//
+// This pass automatically marks global variables with the `acc.declare`
+// attribute when they are referenced in OpenACC compute constructs or routine
+// functions and meet the criteria noted above, ensuring they are properly
+// handled for device execution.
+//
+// The pass performs two main optimizations:
+//
+// 1. Hoisting: For non-constant globals referenced in compute regions, the
+// pass hoists the address-of operation out of the region when possible,
+// allowing them to be implicitly mapped through normal data clause
+// mechanisms rather than requiring declare marking.
+//
+// 2. Declaration: For globals that must be available on the device (constants,
+// globals in routines, globals in recipe operations), the pass adds the
+// `acc.declare` attribute with the copyin data clause.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that compute addresses
+// of global variables must implement the `acc::AddressOfGlobalOpInterface`
+// and those that represent globals must implement the
+// `acc::GlobalOpInterface`. Additionally, any operations that indirectly
+// access globals must implement the `acc::IndirectGlobalAccessOpInterface`.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
+//
+// Examples:
+// ---------
+//
+// Example 1: Non-constant global in compute region (hoisted)
+//
+// Before:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// acc.serial {
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 2: Constant global in compute region (declared)
+//
+// Before:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 3: Global in acc routine (declared)
+//
+// Before:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// After:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// Example 4: Global in private recipe (declared if recipe is used)
+//
+// Before:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+// After:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITDECLARE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-implicit-declare"
+
+using namespace mlir;
+
+namespace {
+
+using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>;
+
+/// Decides whether a use of `globalOp` inside an acc region should be hoisted
+/// out of that region, so that the global is implicitly mapped instead of
+/// being marked with `acc declare`.
+static bool isGlobalUseCandidateForHoisting(Operation *globalOp,
+                                            Operation *user,
+                                            SymbolRefAttr symbol,
+                                            acc::OpenACCSupport &accSupport) {
+  // A symbol that is already valid in a GPU region must stay where it is:
+  // moving the reference to the host would change semantics.
+  if (accSupport.isValidSymbolUse(user, symbol))
+    return false;
+
+  // Constants are kept in device code to ensure they are duplicated.
+  if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp))
+    if (globalVarOp.isConstant())
+      return false;
+
+  // Function references are kept in device code to ensure their device
+  // addresses are computed.
+  if (isa<FunctionOpInterface>(globalOp))
+    return false;
+
+  // Everything else was shown above not to be a valid symbol in a GPU region,
+  // so it is hoisted.
+  return true;
+}
+
+/// Returns true when `globalOp` may carry the acc.declare marking. Functions
+/// are excluded because they are marked through acc.routine instead.
+bool isValidForAccDeclare(Operation *globalOp) {
+  const bool isFunction = isa<FunctionOpInterface>(globalOp);
+  return !isFunction;
+}
+
+/// Tells whether a recipe operation is referenced in a way that justifies
+/// scanning its regions for global references. The answer is false when:
+///   1. no symbol uses of the recipe are found at all, or
+///   2. the single symbol use found belongs to the recipe operation itself.
+template <typename RecipeOpT>
+static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) {
+  std::optional<SymbolTable::UseRange> uses = recipeOp.getSymbolUses(mod);
+  if (!uses)
+    return false;
+
+  auto first = uses->begin();
+  auto last = uses->end();
+
+  // No recipe symbol uses.
+  if (first == last)
+    return false;
+
+  // Two or more uses: assume the recipe is used.
+  if (std::next(first) != last)
+    return true;
+
+  // Exactly one use: it only counts when the user is not the recipe itself.
+  return first->getUser() != recipeOp.getOperation();
+}
+
+// Hoists address-of operations for non-constant globals out of OpenACC
+// regions. This way they are implicitly mapped instead of being considered
+// for implicit declare.
+//
+// `accOp` is the OpenACC construct whose nested operations are scanned; each
+// hoisted operation is moved to just before `accOp`. `accSupport` is consulted
+// (through isGlobalUseCandidateForHoisting) to decide whether a symbol use is
+// already valid in a GPU region.
+template <typename AccConstructT>
+static void hoistNonConstantDirectUses(AccConstructT accOp,
+                                       acc::OpenACCSupport &accSupport) {
+  accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) {
+    SymbolRefAttr symRef = addrOfOp.getSymbol();
+    if (!symRef)
+      return;
+
+    // The lookup yields null when the symbol cannot be resolved. The
+    // candidate check below casts the defining operation, so an unresolved
+    // symbol is conservatively left in place rather than inspected.
+    Operation *globalOp =
+        SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef);
+    if (!globalOp)
+      return;
+
+    if (!isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
+                                         accSupport))
+      return;
+
+    addrOfOp->moveBefore(accOp);
+    LLVM_DEBUG(
+        llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t";
+        accOp->print(llvm::dbgs(),
+                     OpPrintingFlags{}.skipRegions().enableDebugInfo());
+        llvm::dbgs() << "\n");
+  });
+}
+
+// Collects the globals referenced in a device region.
+//
+// Every operation nested in `region` is inspected. Globals reached through an
+// address-of operation are added to `globals` when `accSupport` reports the
+// symbol use as not already valid and the global may be `acc declare`d.
+// Globals reported by operations that access them indirectly are resolved
+// through `symTab` and added under the same `acc declare` validity check.
+static void collectGlobalsFromDeviceRegion(Region &region,
+                                           GlobalOpSetT &globals,
+                                           acc::OpenACCSupport &accSupport,
+                                           SymbolTable &symTab) {
+  region.walk([&](Operation *op) {
+    // 1) Only consider relevant operations which use symbols.
+    if (auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op)) {
+      SymbolRefAttr symRef = addrOfOp.getSymbol();
+      // Without a symbol there is nothing to look up; the hoisting step
+      // guards against this case as well.
+      if (!symRef)
+        return;
+
+      // 2) Found an operation which uses the symbol. Next determine if it
+      // is a candidate for `acc declare`. Some of the criteria considered
+      // is whether this symbol is not already a device one (either because
+      // acc declare is already used or this is a CUF global).
+      Operation *globalOp = nullptr;
+      bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp);
+
+      // 3) Add the candidate to the set of globals to be `acc declare`d.
+      if (isCandidate && globalOp && isValidForAccDeclare(globalOp))
+        globals.insert(globalOp);
+      return;
+    }
+
+    // Process operations that indirectly access globals.
+    if (auto indirectAccessOp =
+            dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) {
+      llvm::SmallVector<SymbolRefAttr> symbols;
+      indirectAccessOp.getReferencedSymbols(symbols, &symTab);
+      for (SymbolRefAttr symRef : symbols)
+        if (Operation *globalOp = symTab.lookup(symRef.getLeafReference()))
+          if (isValidForAccDeclare(globalOp))
+            globals.insert(globalOp);
+    }
+  });
+}
+
+// Attaches the declare attribute, built from the data clause `clause`, to the
+// operation `op`.
+static void addDeclareAttr(MLIRContext *context, Operation *op,
+                           acc::DataClause clause) {
+  auto clauseAttr = acc::DataClauseAttr::get(context, clause);
+  auto declareAttr = acc::DeclareAttr::get(context, clauseAttr);
+  op->setAttr(acc::getDeclareAttrName(), declareAttr);
+}
+
+// Pass that applies implicit declare actions to the globals referenced in
+// OpenACC compute and routine regions.
+class ACCImplicitDeclare
+    : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> {
+public:
+  using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase;
+
+  void runOnOperation() override {
+    ModuleOp mod = getOperation();
+    MLIRContext *context = &getContext();
+    acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+    // 1) Hoist AddressOf operations out of acc regions for the cases that
+    // should not be `acc declare`d. Implicit data mapping covers the majority
+    // of cases without uselessly polluting the device globals.
+    mod.walk([&](Operation *op) {
+      TypeSwitch<Operation *, void>(op)
+          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+              [&](auto accOp) {
+                hoistNonConstantDirectUses(accOp, accSupport);
+              });
+    });
+
+    // 2) Collect the global symbols which need to be `acc declare`d. This is
+    // done for compute regions, acc routines, existing globals carrying the
+    // declare attribute, and recipes that are in use.
+    SymbolTable symTab(mod);
+    GlobalOpSetT globalsToAccDeclare;
+
+    // Shorthand that scans one device region into the shared result set.
+    auto collectFrom = [&](Region &deviceRegion) {
+      collectGlobalsFromDeviceRegion(deviceRegion, globalsToAccDeclare,
+                                     accSupport, symTab);
+    };
+
+    mod.walk([&](Operation *op) {
+      TypeSwitch<Operation *, void>(op)
+          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+              [&](auto accOp) { collectFrom(accOp.getRegion()); })
+          .Case<FunctionOpInterface>([&](auto func) {
+            // Only routine functions that have a body are scanned.
+            bool isRoutine =
+                acc::isAccRoutine(func) || acc::isSpecializedAccRoutine(func);
+            if (isRoutine && !func.isExternal())
+              collectFrom(func.getFunctionBody());
+          })
+          .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) {
+            // Globals without the declare attribute are not scanned.
+            if (!globalVarOp->getAttr(acc::getDeclareAttrName()))
+              return;
+            if (Region *initRegion = globalVarOp.getInitRegion())
+              collectFrom(*initRegion);
+          })
+          .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) {
+            if (!hasRelevantRecipeUse(privateRecipe, mod))
+              return;
+            collectFrom(privateRecipe.getInitRegion());
+            collectFrom(privateRecipe.getDestroyRegion());
+          })
+          .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) {
+            if (!hasRelevantRecipeUse(firstprivateRecipe, mod))
+              return;
+            collectFrom(firstprivateRecipe.getInitRegion());
+            collectFrom(firstprivateRecipe.getDestroyRegion());
+            collectFrom(firstprivateRecipe.getCopyRegion());
+          })
+          .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) {
+            if (!hasRelevantRecipeUse(reductionRecipe, mod))
+              return;
+            collectFrom(reductionRecipe.getInitRegion());
+            collectFrom(reductionRecipe.getCombinerRegion());
+          });
+    });
+
+    // 3) Finally, generate the declare actions needed so that each collected
+    // global is considered a device global.
+    for (Operation *globalOp : globalsToAccDeclare) {
+      LLVM_DEBUG(
+          llvm::dbgs() << "Global is being `acc declare copyin`d: ";
+          globalOp->print(llvm::dbgs(),
+                          OpPrintingFlags{}.skipRegions().enableDebugInfo());
+          llvm::dbgs() << "\n");
+
+      // Mark it as declare copyin.
+      addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin);
+
+      // TODO: May need to create the global constructor which does the mapping
+      // action. It is not yet clear if this is needed yet (since the globals
+      // might just end up in the GPU image without requiring mapping via
+      // runtime).
+    }
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
new file mode 100644
index 0000000..12efaf4
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
@@ -0,0 +1,237 @@
+//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the implicit rules described in OpenACC specification
+// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1).
+//
+// "If no explicit routine directive applies to a procedure whose definition
+// appears in the program unit being compiled, then the implementation applies
+// an implicit routine directive to that procedure if any of the following
+// conditions holds:
+// - The procedure is called or its address is accessed in a compute region."
+//
+// The specification further states:
+// "When the implementation applies an implicit routine directive to a
+// procedure, it must recursively apply implicit routine directives to other
+// procedures for which the above rules specify relevant dependencies. Such
+// dependencies can form a cycle, so the implementation must take care to avoid
+// infinite recursion."
+//
+// This pass implements these requirements by:
+// 1. Walking through all OpenACC compute constructs and functions already
+// marked with `acc routine` in the module and identifying function calls
+// within these regions.
+// 2. Creating implicit `acc.routine` operations for functions that don't
+// already have routine declarations.
+// 3. Recursively walking through all existing `acc routine` and creating
+// implicit routine operations for function calls within these routines,
+// while avoiding infinite recursion through proper tracking.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that define functions
+// or call functions should implement `mlir::FunctionOpInterface` and
+// `mlir::CallOpInterface` respectively.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include <queue>
+
+#define DEBUG_TYPE "acc-implicit-routine"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITROUTINE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+namespace {
+
+using namespace mlir;
+
+// Pass that creates implicit `acc.routine` operations for functions called
+// from OpenACC compute constructs and, transitively, from functions that
+// already have an `acc.routine`.
+class ACCImplicitRoutine
+    : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> {
+private:
+  // Numeric suffix used for the next generated routine name.
+  unsigned routineCounter = 0;
+  static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
+
+  // Count existing routine operations and update counter
+  void initRoutineCounter(ModuleOp module) {
+    module.walk([&](acc::RoutineOp) { routineCounter++; });
+  }
+
+  // Check if routine has a default bind clause or a device-type specific bind
+  // clause. Returns true if `acc routine` has a default bind clause or
+  // a device-type specific bind clause.
+  bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
+                                           acc::DeviceType deviceType) {
+    // Fast check to avoid device-type specific lookups.
+    if (!op.getBindIdName() && !op.getBindStrName())
+      return false;
+    return op.getBindNameValue().has_value() ||
+           op.getBindNameValue(deviceType).has_value();
+  }
+
+  // Generate a unique name for the routine and create the routine operation.
+  // The new operation is created at the builder's current insertion point,
+  // and `callee` is tagged with routine info that refers to it.
+  acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc,
+                                 FunctionOpInterface &callee) {
+    std::string routineName =
+        (accRoutinePrefix + std::to_string(routineCounter++)).str();
+    auto routineOp = acc::RoutineOp::create(
+        builder, loc,
+        /* sym_name=*/builder.getStringAttr(routineName),
+        /* func_name=*/
+        mlir::SymbolRefAttr::get(builder.getContext(),
+                                 builder.getStringAttr(callee.getName())),
+        /* bindIdName=*/nullptr,
+        /* bindStrName=*/nullptr,
+        /* bindIdNameDeviceType=*/nullptr,
+        /* bindStrNameDeviceType=*/nullptr,
+        /* worker=*/nullptr,
+        /* vector=*/nullptr,
+        /* seq=*/nullptr,
+        /* nohost=*/nullptr,
+        /* implicit=*/builder.getUnitAttr(),
+        /* gang=*/nullptr,
+        /* gangDim=*/nullptr,
+        /* gangDimDeviceType=*/nullptr);
+
+    // Assert that the callee does not already have routine info attribute
+    assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) &&
+           "function is already associated with a routine");
+
+    callee->setAttr(
+        acc::getRoutineInfoAttrName(),
+        mlir::acc::RoutineInfoAttr::get(
+            builder.getContext(),
+            {mlir::SymbolRefAttr::get(builder.getContext(),
+                                      builder.getStringAttr(routineName))}));
+    return routineOp;
+  }
+
+  // Creates an implicit routine for the function called by `callOp` when one
+  // is needed. Returns the new routine operation, or a null operation when
+  // the call is skipped.
+  acc::RoutineOp createImplicitRoutineForCall(CallOpInterface callOp,
+                                              SymbolTable &symTab,
+                                              mlir::OpBuilder &builder,
+                                              acc::OpenACCSupport &accSupport) {
+    auto callable = callOp.getCallableForCallee();
+    if (!callable)
+      return nullptr;
+
+    // When call is done through ssa value, the callee is not a symbol.
+    // Skip it because we don't know the call target.
+    auto calleeSymbolRef = dyn_cast<SymbolRefAttr>(callable);
+    if (!calleeSymbolRef)
+      return nullptr;
+
+    // If the callee does not exist or is already a valid symbol for GPU
+    // regions, skip it.
+    auto callee =
+        symTab.lookup<FunctionOpInterface>(calleeSymbolRef.getLeafReference());
+    if (!callee)
+      return nullptr;
+    if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
+      return nullptr;
+
+    builder.setInsertionPoint(callee);
+    return createRoutineOp(builder, callee.getLoc(), callee);
+  }
+
+  // Used to walk through a compute region looking for function calls.
+  void
+  implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab,
+                                          mlir::OpBuilder &builder,
+                                          acc::OpenACCSupport &accSupport) {
+    op->walk([&](CallOpInterface callOp) {
+      createImplicitRoutineForCall(callOp, symTab, builder, accSupport);
+    });
+  }
+
+  // Recursively handle calls within a routine operation
+  void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
+                                        mlir::OpBuilder &builder,
+                                        acc::OpenACCSupport &accSupport,
+                                        acc::DeviceType targetDeviceType) {
+    // When bind clause is used, it means that the target is different than the
+    // function to which the `acc routine` is used with. Skip this case to
+    // avoid implicitly recursively marking calls that would not end up on
+    // device.
+    if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
+      return;
+
+    SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
+    // Worklist of routines whose functions still need scanning. Newly created
+    // routines are appended so the calls they make are handled too.
+    std::queue<acc::RoutineOp> routineQueue;
+    routineQueue.push(routineOp);
+    while (!routineQueue.empty()) {
+      auto currentRoutine = routineQueue.front();
+      routineQueue.pop();
+      auto func = symTab.lookup<FunctionOpInterface>(
+          currentRoutine.getFuncName().getLeafReference());
+      // A routine whose function is not found in the symbol table has no
+      // calls to inspect.
+      if (!func)
+        continue;
+      func.walk([&](CallOpInterface callOp) {
+        if (acc::RoutineOp newRoutineOp = createImplicitRoutineForCall(
+                callOp, symTab, builder, accSupport))
+          routineQueue.push(newRoutineOp);
+      });
+    }
+  }
+
+public:
+  using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
+
+  void runOnOperation() override {
+    auto module = getOperation();
+    mlir::OpBuilder builder(module.getContext());
+    SymbolTable symTab(module);
+    initRoutineCounter(module);
+
+    acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+    // Handle compute regions
+    module.walk([&](Operation *op) {
+      if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
+        implicitRoutineForCallsInComputeRegions(op, symTab, builder,
+                                                accSupport);
+    });
+
+    // Use the device type option from the pass options.
+    acc::DeviceType targetDeviceType = deviceType;
+
+    // Handle existing routines
+    module.walk([&](acc::RoutineOp routineOp) {
+      implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
+                                       targetDeviceType);
+    });
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
new file mode 100644
index 0000000..f41ce276
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
@@ -0,0 +1,117 @@
+//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts acc.serial into acc.parallel with num_gangs(1)
+// num_workers(1) vector_length(1).
+//
+// This transformation simplifies processing of acc regions by unifying the
+// handling of serial and parallel constructs. Since an OpenACC serial region
+// executes sequentially (like a parallel region with a single gang, worker, and
+// vector), this conversion is semantically equivalent while enabling code reuse
+// in later compilation stages.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCLEGALIZESERIAL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-legalize-serial"
+
+namespace {
+using namespace mlir;
+
+// Rewrites an acc.serial operation into an acc.parallel operation whose
+// num_gangs, num_workers, and vector_length are all the constant 1. Every
+// other operand and attribute is forwarded from the serial operation, and the
+// serial region is moved into the new parallel operation.
+struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
+  using OpRewritePattern<acc::SerialOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(acc::SerialOp serialOp,
+                                PatternRewriter &rewriter) const override {
+
+    const Location loc = serialOp.getLoc();
+
+    // Create a container holding the constant value of 1 for use as the
+    // num_gangs, num_workers, and vector_length operands.
+    llvm::SmallVector<mlir::Value> numValues;
+    auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
+    numValues.push_back(value);
+
+    // num_gangs takes a segments attribute alongside its values; there is a
+    // single segment covering all of the values created above.
+    llvm::SmallVector<int32_t> numGangsSegments;
+    numGangsSegments.push_back(numValues.size());
+    auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);
+
+    // Create a device_type attribute set to `none` which ensures that
+    // the parallel dimensions specification applies to the default clauses.
+    llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+    auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+        rewriter.getContext(), mlir::acc::DeviceType::None);
+    crtDeviceTypes.push_back(crtDeviceTypeAttr);
+    auto devTypeAttr =
+        mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);
+
+    LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");
+
+    // Create a new acc.parallel op with the same operands - except include the
+    // num_gangs, num_workers, and vector_length operands.
+    acc::ParallelOp parOp = acc::ParallelOp::create(
+        rewriter, loc, serialOp.getAsyncOperands(),
+        serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
+        serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
+        serialOp.getWaitOperandsDeviceTypeAttr(),
+        serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
+        gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
+        devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
+        serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
+        serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
+        serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
+        serialOp.getCombinedAttr());
+
+    // Move the body through the rewriter so that the pattern driver is
+    // notified of the change instead of the IR being mutated behind its back.
+    rewriter.inlineRegionBefore(serialOp.getRegion(), parOp.getRegion(),
+                                parOp.getRegion().end());
+
+    LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
+    rewriter.replaceOp(serialOp, parOp);
+
+    return success();
+  }
+};
+
+// Function-level pass that applies the acc.serial conversion pattern greedily
+// to the function it runs on.
+class ACCLegalizeSerial
+    : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
+public:
+  using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    RewritePatternSet patterns(funcOp.getContext());
+    patterns.add<ACCSerialOpConversion>(patterns.getContext());
+    // The result of the greedy driver is deliberately ignored.
+    (void)applyPatternsGreedily(funcOp, std::move(patterns));
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 7d93495..10a1796 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,8 @@
add_mlir_dialect_library(MLIROpenACCTransforms
+ ACCImplicitData.cpp
+ ACCImplicitDeclare.cpp
+ ACCImplicitRoutine.cpp
+ ACCLegalizeSerial.cpp
LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
@@ -14,7 +18,10 @@ add_mlir_dialect_library(MLIROpenACCTransforms
MLIROpenACCTypeInterfacesIncGen
LINK_LIBS PUBLIC
+ MLIRAnalysis
+ MLIROpenACCAnalysis
MLIROpenACCDialect
+ MLIROpenACCUtils
MLIRFuncDialect
MLIRIR
MLIRPass
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 660c313..7f27b44 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,8 +9,13 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -145,3 +150,119 @@ std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind,
return recipeName;
}
+
+/// Returns the base entity that `val` is a partial view of.
+///
+/// If `val` is produced by an op implementing `PartialEntityAccessOpInterface`
+/// and that op does not report a complete view, the op's base entity is
+/// returned. In every other case (including values without a defining op,
+/// such as block arguments) `val` itself is returned.
+mlir::Value mlir::acc::getBaseEntity(mlir::Value val) {
+  // `Value::getDefiningOp()` is null for block arguments and a plain
+  // `dyn_cast` asserts on null input. The templated accessor tolerates the
+  // null case and simply yields a null interface handle.
+  if (auto partialEntityAccessOp =
+          val.getDefiningOp<PartialEntityAccessOpInterface>()) {
+    if (!partialEntityAccessOp.isCompleteView())
+      return partialEntityAccessOp.getBaseEntity();
+  }
+
+  return val;
+}
+
+/// Decides whether `symbol`, resolved from `user` through the nearest symbol
+/// table, may be referenced at that point.
+///
+/// The reference is accepted when the defining op is one of:
+///   - an acc private, reduction or firstprivate recipe;
+///   - a function carrying the acc routine-info attribute;
+///   - a private, body-less function whose name is a known LLVM intrinsic;
+///   - any op carrying the acc declare attribute.
+///
+/// When `definingOpPtr` is non-null and the lookup succeeds, the defining op
+/// is written through it. On a failed lookup it is left untouched and the
+/// function returns false.
+bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
+                                 mlir::SymbolRefAttr symbol,
+                                 mlir::Operation **definingOpPtr) {
+  mlir::Operation *symbolDef =
+      mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol);
+
+  // Without a definition there are no attributes to inspect, so validity
+  // cannot be established.
+  if (symbolDef == nullptr)
+    return false;
+
+  if (definingOpPtr != nullptr)
+    *definingOpPtr = symbolDef;
+
+  // Recipes (private, reduction, firstprivate) are only instructions for how
+  // to materialize; they get materialized before offloading, so referencing
+  // them is always fine.
+  if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp,
+                mlir::acc::FirstprivateRecipeOp>(symbolDef))
+    return true;
+
+  if (auto funcDef =
+          mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(symbolDef)) {
+    // An acc routine is expected to be offloaded, hence valid.
+    if (funcDef->hasAttr(mlir::acc::getRoutineInfoAttrName()))
+      return true;
+
+    // True when the function name has the "llvm." prefix and is recognized
+    // as an LLVM intrinsic.
+    auto namesLLVMIntrinsic = [&]() {
+      llvm::StringRef fnName = funcDef.getName();
+      return fnName.starts_with("llvm.") &&
+             llvm::Intrinsic::lookupIntrinsicID(fnName) !=
+                 llvm::Intrinsic::not_intrinsic;
+    };
+
+    // A private declaration (no body) that names an LLVM intrinsic is
+    // treated as a likely-valid intrinsic call target.
+    if (funcDef.getVisibility() == mlir::SymbolTable::Visibility::Private &&
+        funcDef.getFunctionBody().empty() && namesLLVMIntrinsic())
+      return true;
+  }
+
+  // Any remaining symbol reference needs a declare attribute on its
+  // definition.
+  return symbolDef->hasAttr(mlir::acc::getDeclareAttrName());
+}
+
+/// Gathers the data clause operands that apply to `computeConstructOp`.
+///
+/// Values are accumulated in a set-vector, so each one appears once, in the
+/// order it was first seen. They come from three places:
+///   1. the construct itself, when it is an acc parallel, kernels or serial
+///      op;
+///   2. every `acc.data` op among its ancestors;
+///   3. every `acc.declare_enter` in the enclosing function that dominates the
+///      construct and whose token has at least one `acc.declare_exit` user,
+///      all of which post-dominate the construct.
+///
+/// If the construct is not nested in a function, only (1) and (2) are
+/// returned.
+llvm::SmallVector<mlir::Value>
+mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp,
+                                    mlir::DominanceInfo &domInfo,
+                                    mlir::PostDominanceInfo &postDomInfo) {
+  llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses;
+
+  // Shared helper: add every value of an operand range to the result set.
+  auto appendClauses = [&dominatingDataClauses](auto &&operands) {
+    for (mlir::Value clause : operands)
+      dominatingDataClauses.insert(clause);
+  };
+
+  // (1) Clauses attached directly to the compute construct.
+  llvm::TypeSwitch<mlir::Operation *>(computeConstructOp)
+      .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>(
+          [&](auto construct) {
+            appendClauses(construct.getDataClauseOperands());
+          })
+      .Default([](mlir::Operation *) {});
+
+  // (2) Clauses of every enclosing data construct, innermost first.
+  for (mlir::Operation *ancestor = computeConstructOp->getParentOp(); ancestor;
+       ancestor = ancestor->getParentOp()) {
+    if (auto dataOp = mlir::dyn_cast<mlir::acc::DataOp>(ancestor))
+      appendClauses(dataOp.getDataClauseOperands());
+  }
+
+  // (3) Needs the enclosing function/subroutine; stop here if there is none.
+  auto funcOp =
+      computeConstructOp->getParentOfType<mlir::FunctionOpInterface>();
+  if (!funcOp)
+    return dominatingDataClauses.takeVector();
+
+  // Look for `acc.declare_enter`/`acc.declare_exit` pairs that bracket the
+  // compute construct in the dominance sense.
+  funcOp->walk([&](mlir::acc::DeclareEnterOp enterOp) {
+    if (!domInfo.dominates(enterOp.getOperation(), computeConstructOp))
+      return;
+
+    // Every `acc.declare_exit` consuming this enter's token.
+    llvm::SmallVector<mlir::acc::DeclareExitOp> exitOps;
+    for (auto *tokenUser : enterOp.getToken().getUsers()) {
+      if (auto exitOp = mlir::dyn_cast<mlir::acc::DeclareExitOp>(tokenUser))
+        exitOps.push_back(exitOp);
+    }
+
+    // An enter with no matching exit contributes nothing.
+    if (exitOps.empty())
+      return;
+
+    // All exits must post-dominate the construct for the clauses to count.
+    auto exitPostDominates = [&](mlir::acc::DeclareExitOp exitOp) {
+      return postDomInfo.postDominates(exitOp, computeConstructOp);
+    };
+    if (!llvm::all_of(exitOps, exitPostDominates))
+      return;
+
+    appendClauses(enterOp.getDataClauseOperands());
+  });
+
+  return dominatingDataClauses.takeVector();
+}