aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/OpenACC/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenACC/IR')
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp1355
2 files changed, 1124 insertions, 232 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
index ed7425b..2bd41d9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIROpenACCDialect
LINK_LIBS PUBLIC
MLIRIR
+ MLIRGPUDialect
MLIRLLVMDialect
MLIRMemRefDialect
MLIROpenACCMPCommon
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6564a4e..460314f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -4,10 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -15,8 +16,10 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
@@ -39,11 +42,21 @@ static bool isScalarLikeType(Type type) {
return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
}
+/// Attaches the OpenACC `VarName` attribute to `op` when `varName` is
+/// non-empty. An empty name leaves `op` untouched.
+static void attachVarNameAttr(Operation *op, OpBuilder &builder,
+                              StringRef varName) {
+  if (varName.empty())
+    return;
+  op->setAttr(acc::getVarNameAttrName(),
+              acc::VarNameAttr::get(builder.getContext(), varName));
+}
+
+template <typename T>
struct MemRefPointerLikeModel
- : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
- MemRefType> {
+ : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
Type getElementType(Type pointer) const {
- return cast<MemRefType>(pointer).getElementType();
+ return cast<T>(pointer).getElementType();
}
mlir::acc::VariableTypeCategory
@@ -52,7 +65,7 @@ struct MemRefPointerLikeModel
if (auto mappableTy = dyn_cast<MappableType>(varType)) {
return mappableTy.getTypeCategory(varPtr);
}
- auto memrefTy = cast<MemRefType>(pointer);
+ auto memrefTy = cast<T>(pointer);
if (!memrefTy.hasRank()) {
// This memref is unranked - aka it could have any rank, including a
// rank of 0 which could mean scalar. For now, return uncategorized.
@@ -74,14 +87,18 @@ struct MemRefPointerLikeModel
}
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
- StringRef varName, Type varType,
- Value originalVar) const {
+ StringRef varName, Type varType, Value originalVar,
+ bool &needsFree) const {
auto memrefTy = cast<MemRefType>(pointer);
// Check if this is a static memref (all dimensions are known) - if yes
// then we can generate an alloca operation.
- if (memrefTy.hasStaticShape())
- return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ if (memrefTy.hasStaticShape()) {
+ needsFree = false; // alloca doesn't need deallocation
+ auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
+ attachVarNameAttr(allocaOp, builder, varName);
+ return allocaOp.getResult();
+ }
// For dynamic memrefs, extract sizes from the original variable if
// provided. Otherwise they cannot be handled.
@@ -99,8 +116,11 @@ struct MemRefPointerLikeModel
// Note: We only add dynamic sizes to the dynamicSizes array
// Static dimensions are handled automatically by AllocOp
}
- return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
- .getResult();
+ needsFree = true; // alloc needs deallocation
+ auto allocOp =
+ memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
+ attachVarNameAttr(allocOp, builder, varName);
+ return allocOp.getResult();
}
// TODO: Unranked not yet supported.
@@ -108,10 +128,14 @@ struct MemRefPointerLikeModel
}
bool genFree(Type pointer, OpBuilder &builder, Location loc,
- TypedValue<PointerLikeType> varPtr, Type varType) const {
- if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ TypedValue<PointerLikeType> varToFree, Value allocRes,
+ Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
+ // Use allocRes if provided to determine the allocation type
+ Value valueToInspect = allocRes ? allocRes : memrefValue;
+
// Walk through casts to find the original allocation
- Value currentValue = memrefValue;
+ Value currentValue = valueToInspect;
Operation *originalAlloc = nullptr;
// Follow the chain of operations to find the original allocation
@@ -150,7 +174,7 @@ struct MemRefPointerLikeModel
return true;
}
if (isa<memref::AllocOp>(originalAlloc)) {
- // This is an alloc - generate dealloc
+ // This is an alloc - generate dealloc on varToFree
memref::DeallocOp::create(builder, loc, memrefValue);
return true;
}
@@ -181,12 +205,111 @@ struct MemRefPointerLikeModel
return false;
}
+
+  /// Emits a `memref.load` from `srcPtr` and returns the loaded value.
+  /// Returns a null value unless `srcPtr` is a ranked memref of rank 0:
+  /// memref address computation is part of the load itself and this API has
+  /// no arguments for indices, so only scalar memrefs can be loaded.
+  mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+                      TypedValue<PointerLikeType> srcPtr,
+                      Type valueType) const {
+    auto scalarRef = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
+    // Reject null/non-memref values and anything that is not rank 0.
+    if (!scalarRef || scalarRef.getType().getRank() != 0)
+      return {};
+    return memref::LoadOp::create(builder, loc, scalarRef);
+  }
+
+  /// Emits a `memref.store` of `valueToStore` into `destPtr`.
+  /// Returns false (emitting nothing) unless `destPtr` is a ranked memref of
+  /// rank 0: memref address computation is part of the store itself and this
+  /// API has no arguments for indices, so only scalar memrefs can be stored to.
+  bool genStore(Type pointer, OpBuilder &builder, Location loc,
+                Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+    auto scalarRef = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
+    const bool isScalarMemRef =
+        scalarRef && scalarRef.getType().getRank() == 0;
+    if (!isScalarMemRef)
+      return false;
+
+    memref::StoreOp::create(builder, loc, valueToStore, scalarRef);
+    return true;
+  }
+
+  /// Returns true when the memref type's memory space is a
+  /// `gpu::AddressSpaceAttr`; a missing memory space yields false. `var` is
+  /// not inspected.
+  bool isDeviceData(Type pointer, Value var) const {
+    return isa_and_nonnull<gpu::AddressSpaceAttr>(
+        cast<T>(pointer).getMemorySpace());
+  }
};
struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
Type getElementType(Type pointer) const { return Type(); }
+
+  /// Emits an `llvm.load` of type `valueType` from `srcPtr`. The model
+  /// reports no element type for LLVM pointers, so the caller must supply
+  /// `valueType`; a null `valueType` produces a null result and no load.
+  mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+                      TypedValue<PointerLikeType> srcPtr,
+                      Type valueType) const {
+    if (valueType)
+      return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
+    return {};
+  }
+
+  /// Emits an `llvm.store` of `valueToStore` into `destPtr`; always reports
+  /// success.
+  bool genStore(Type pointer, OpBuilder &builder, Location loc,
+                Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+    LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
+    return true;
+  }
+};
+
+/// External model exposing `memref.get_global` through
+/// `AddressOfGlobalOpInterface`.
+struct MemrefAddressOfGlobalModel
+    : public AddressOfGlobalOpInterface::ExternalModel<
+          MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
+  /// Returns the symbol stored in the op's name attribute.
+  SymbolRefAttr getSymbol(Operation *op) const {
+    return cast<memref::GetGlobalOp>(op).getNameAttr();
+  }
+};
+
+/// External model exposing `memref.global` through
+/// `GlobalVariableOpInterface`.
+struct MemrefGlobalVariableModel
+    : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
+                                                      memref::GlobalOp> {
+  /// Forwards the `constant` flag of the global.
+  bool isConstant(Operation *op) const {
+    return cast<memref::GlobalOp>(op).getConstant();
+  }
+
+  /// `memref.global` is initialized through attributes, so there is never an
+  /// init region to return.
+  Region *getInitRegion(Operation *op) const { return nullptr; }
+
+  /// Returns true when the global's memref type has a
+  /// `gpu::AddressSpaceAttr` memory space.
+  bool isDeviceData(Operation *op) const {
+    return isa_and_nonnull<gpu::AddressSpaceAttr>(
+        cast<memref::GlobalOp>(op).getType().getMemorySpace());
+  }
+};
+
+/// External model exposing `gpu.launch` through
+/// `acc::OffloadRegionOpInterface`.
+struct GPULaunchOffloadRegionModel
+    : public acc::OffloadRegionOpInterface::ExternalModel<
+          GPULaunchOffloadRegionModel, gpu::LaunchOp> {
+  /// The offload region of a `gpu.launch` is its body region.
+  mlir::Region &getOffloadRegion(mlir::Operation *op) const {
+    auto launchOp = cast<gpu::LaunchOp>(op);
+    return launchOp.getBody();
+  }
};
/// Helper function for any of the times we need to modify an ArrayAttr based on
@@ -274,9 +397,135 @@ void OpenACCDialect::initialize() {
// By attaching interfaces here, we make the OpenACC dialect dependent on
// the other dialects. This is probably better than having dialects like LLVM
// and memref be dependent on OpenACC.
- MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
+ MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
+ *getContext());
+ UnrankedMemRefType::attachInterface<
+ MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
+
+ // Attach operation interfaces
+ memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
+ *getContext());
+ memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
+ gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
+// acc.kernel_environment / acc.data / acc.host_data / acc.loop
+//===----------------------------------------------------------------------===//
+
+/// Successor computation shared by the single-region OpenACC ops. Branching
+/// from the parent op enters `region`; branching from any other point
+/// returns control to the parent op.
+static void
+getSingleRegionOpSuccessorRegions(Operation *op, Region &region,
+                                  RegionBranchPoint point,
+                                  SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.push_back(RegionSuccessor(&region));
+  else
+    regions.push_back(RegionSuccessor::parent());
+}
+
+/// Returns the values a successor receives: the results of `op` when the
+/// successor is the parent op, and nothing when it is the region.
+static ValueRange getSingleRegionSuccessorInputs(Operation *op,
+                                                 RegionSuccessor successor) {
+  if (successor.isParent())
+    return ValueRange(op->getResults());
+  return ValueRange();
+}
+
+/// `acc.kernels` uses the shared single-region successor computation.
+void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.kernels` uses the shared single-region successor-input computation.
+ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// `acc.parallel` uses the shared single-region successor computation.
+void ParallelOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.parallel` uses the shared single-region successor-input computation.
+ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// `acc.serial` uses the shared single-region successor computation.
+void SerialOp::getSuccessorRegions(RegionBranchPoint point,
+                                   SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.serial` uses the shared single-region successor-input computation.
+ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// `acc.kernel_environment` uses the shared single-region successor
+/// computation.
+void KernelEnvironmentOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.kernel_environment` uses the shared single-region successor-input
+/// computation.
+ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// `acc.data` uses the shared single-region successor computation.
+void DataOp::getSuccessorRegions(RegionBranchPoint point,
+                                 SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.data` uses the shared single-region successor-input computation.
+ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// `acc.host_data` uses the shared single-region successor computation.
+void HostDataOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  Region &body = getRegion();
+  getSingleRegionOpSuccessorRegions(getOperation(), body, point, regions);
+}
+
+/// `acc.host_data` uses the shared single-region successor-input computation.
+ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+/// Region successors of `acc.loop`.
+void LoopOp::getSuccessorRegions(RegionBranchPoint point,
+                                 SmallVectorImpl<RegionSuccessor> &regions) {
+  // Structured loops: model a loop-shaped region graph similar to scf.for.
+  // From every branch point the body and the parent are both possible
+  // successors.
+  if (!getUnstructured()) {
+    regions.push_back(RegionSuccessor(&getRegion()));
+    regions.push_back(RegionSuccessor::parent());
+    return;
+  }
+
+  // Unstructured loops: the body may contain arbitrary CFG and early exits.
+  // At the RegionBranch level, only model entry into the body and exit to
+  // the parent; any backedges are represented inside the region CFG.
+  if (point.isParent())
+    regions.push_back(RegionSuccessor(&getRegion()));
+  else
+    regions.push_back(RegionSuccessor::parent());
+}
+
+/// `acc.loop` uses the shared single-region successor-input computation.
+ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
+  Operation *self = getOperation();
+  return getSingleRegionSuccessorInputs(self, successor);
+}
+
+//===----------------------------------------------------------------------===//
+// RegionBranchTerminatorOpInterface
+//===----------------------------------------------------------------------===//
+
+MutableOperandRange
+TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*successor*/) {
+ // `acc.terminator` does not forward operands, so an empty range (start 0, length 0) is returned for every successor.
+ return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
//===----------------------------------------------------------------------===//
@@ -442,6 +691,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
return success();
}
+/// Verifies the recipe symbol attached to a private/firstprivate/reduction
+/// data operation. `operandName` is only used to build diagnostics.
+///
+/// Succeeds immediately for mappable variable types (except on reductions).
+/// Otherwise the op must carry a recipe attribute that resolves, from the
+/// nearest symbol table, to a `RecipeOpT`.
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
+  // Mappable types do not need a recipe because one can be generated from
+  // their API. Reductions are excluded because no such API exists for them
+  // at this time.
+  if constexpr (!std::is_same_v<OpT, acc::ReductionOp>) {
+    if (mlir::acc::isMappableType(op.getVar().getType()))
+      return success();
+  }
+
+  mlir::SymbolRefAttr recipeRef = op.getRecipeAttr();
+  if (!recipeRef)
+    return op->emitOpError() << "recipe expected for " << operandName;
+
+  if (SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, recipeRef))
+    return success();
+
+  return op->emitOpError()
+         << "expected symbol reference " << recipeRef << " to point to a "
+         << operandName << " declaration";
+}
+
static ParseResult parseVar(mlir::OpAsmParser &parser,
OpAsmParser::UnresolvedOperand &var) {
// Either `var` or `varPtr` keyword is required.
@@ -548,6 +819,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
}
}
+/// Custom-directive parser for a recipe symbol: parses one attribute into
+/// `recipeAttr` and propagates the parser's success or failure.
+static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
+                                  mlir::SymbolRefAttr &recipeAttr) {
+  return parser.parseAttribute(recipeAttr);
+}
+
+static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::SymbolRefAttr recipeAttr) {
+ p << recipeAttr; // Streams the recipe symbol reference to the printer; `op` is unused.
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -570,6 +853,9 @@ LogicalResult acc::PrivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(
+ checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private")))
+ return failure();
return success();
}
@@ -584,6 +870,23 @@ LogicalResult acc::FirstprivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>(
+ *this, "firstprivate")))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+  // The op must be tagged with the firstprivate data clause.
+  const bool clauseMatches =
+      getDataClause() == acc::DataClause::acc_firstprivate;
+  if (!clauseMatches)
+    return emitError("data clause associated with firstprivate operation must "
+                     "match its intent");
+  // Variable/type consistency and absence of modifiers; the first failing
+  // check short-circuits the second.
+  if (failed(checkVarAndVarType(*this)) || failed(checkNoModifier(*this)))
+    return failure();
return success();
}
@@ -598,6 +901,9 @@ LogicalResult acc::ReductionOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>(
+ *this, "reduction")))
+ return failure();
return success();
}
@@ -918,6 +1224,270 @@ bool acc::CacheOp::isCacheReadonly() {
acc::DataClauseModifier::readonly);
}
+//===----------------------------------------------------------------------===//
+// Data entry/exit operations - getEffects implementations
+//===----------------------------------------------------------------------===//
+
+// Returns true iff `op` has an ancestor that is one of the
+// ACC_COMPUTE_CONSTRUCT_OPS operations. `op` itself is not considered.
+// This is similar to the acc::getEnclosingComputeOp() utility,
+// but we cannot use it here.
+static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
+  for (mlir::Operation *ancestor = op->getParentOp(); ancestor;
+       ancestor = ancestor->getParentOp()) {
+    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(ancestor))
+      return true;
+  }
+  return false;
+}
+
+/// Appends one `EffectTy` effect instance per operand in the mutable range
+/// `operand`, each attached to the corresponding `OpOperand`. An empty range
+/// adds nothing.
+template <typename EffectTy>
+static void addOperandEffect(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects,
+    MutableOperandRange operand) {
+  for (OpOperand &use : operand)
+    effects.emplace_back(EffectTy::get(), &use);
+}
+
+/// Helper to add an `EffectTy` effect on a result value; `result` must be an `OpResult` (it is `cast`, not `dyn_cast`).
+template <typename EffectTy>
+static void addResultEffect(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects,
+ Value result) {
+ effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
+}
+
+// PrivateOp: accVar result write; CurrentDeviceIdResource read only when not nested in a compute construct.
+void acc::PrivateOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // If acc.private is enclosed in a compute operation,
+ // then it denotes the device-side privatization, hence
+ // it does not access the CurrentDeviceIdResource.
+ if (!isEnclosedIntoComputeOp(getOperation()))
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ // TODO: should this be MemoryEffects::Allocate?
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// FirstprivateOp: var read, accVar result write; CurrentDeviceIdResource read only when not nested in a compute construct.
+void acc::FirstprivateOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // If acc.firstprivate is enclosed in a compute operation,
+ // then it denotes the device-side privatization, hence
+ // it does not access the CurrentDeviceIdResource.
+ if (!isEnclosedIntoComputeOp(getOperation()))
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// FirstprivateMapInitialOp: CurrentDeviceIdResource read (unconditional), var read, accVar result write.
+void acc::FirstprivateMapInitialOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// ReductionOp: var read, accVar result write; CurrentDeviceIdResource read only when not nested in a compute construct.
+void acc::ReductionOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // If acc.reduction is enclosed in a compute operation,
+ // then it denotes the device-side reduction, hence
+ // it does not access the CurrentDeviceIdResource.
+ if (!isEnclosedIntoComputeOp(getOperation()))
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// DevicePtrOp: RuntimeCounters read, CurrentDeviceIdResource read.
+void acc::DevicePtrOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+}
+
+// PresentOp: RuntimeCounters read+write, CurrentDeviceIdResource read.
+void acc::PresentOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+}
+
+// CopyinOp: RuntimeCounters read+write, CurrentDeviceIdResource read, var read, accVar result write.
+void acc::CopyinOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// CreateOp: RuntimeCounters read+write, CurrentDeviceIdResource read, accVar result write.
+void acc::CreateOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ // TODO: should this be MemoryEffects::Allocate?
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// NoCreateOp: RuntimeCounters read+write, CurrentDeviceIdResource read.
+void acc::NoCreateOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+}
+
+// AttachOp: RuntimeCounters read+write, CurrentDeviceIdResource read, var read.
+void acc::AttachOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ // TODO: should we also add MemoryEffects::Write?
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+}
+
+// GetDevicePtrOp: RuntimeCounters read, CurrentDeviceIdResource read.
+void acc::GetDevicePtrOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+}
+
+// UpdateDeviceOp: CurrentDeviceIdResource read, var read, accVar result write.
+void acc::UpdateDeviceOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+ addResultEffect<MemoryEffects::Write>(effects, getAccVar());
+}
+
+// UseDeviceOp: RuntimeCounters read, CurrentDeviceIdResource read.
+void acc::UseDeviceOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+}
+
+// DeclareDeviceResidentOp: RuntimeCounters write, CurrentDeviceIdResource read, var read.
+void acc::DeclareDeviceResidentOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+}
+
+// DeclareLinkOp: RuntimeCounters write, CurrentDeviceIdResource read, var read.
+void acc::DeclareLinkOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
+}
+
+// CacheOp: no memory effects are reported; the effects list is left untouched.
+void acc::CacheOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {}
+
+// CopyoutOp: RuntimeCounters read+write, CurrentDeviceIdResource read, accVar read, var write.
+void acc::CopyoutOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
+ addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
+}
+
+// DeleteOp: RuntimeCounters read+write, CurrentDeviceIdResource read, accVar read.
+void acc::DeleteOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
+}
+
+// DetachOp: RuntimeCounters read+write, CurrentDeviceIdResource read, accVar read.
+void acc::DetachOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
+}
+
+// UpdateHostOp: RuntimeCounters read+write, CurrentDeviceIdResource read, accVar read, var write.
+void acc::UpdateHostOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ acc::RuntimeCounters::get());
+ effects.emplace_back(MemoryEffects::Read::get(),
+ acc::CurrentDeviceIdResource::get());
+ addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
+ addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
+}
+
template <typename StructureOp>
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
unsigned nRegions = 1) {
@@ -1003,6 +1573,197 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+/// Canonicalization pattern that removes an `acc.kernel_environment` whose
+/// body block contains no operations. When the op has wait operands or the
+/// wait-only attribute, it is replaced by an `acc.wait` carrying the same
+/// wait operands so that synchronization is preserved; otherwise it is
+/// erased.
+struct RemoveEmptyKernelEnvironment
+    : public OpRewritePattern<acc::KernelEnvironmentOp> {
+  using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
+                                PatternRewriter &rewriter) const override {
+    assert(op->getNumRegions() == 1 && "expected op to have one region");
+
+    // Only an empty body qualifies.
+    if (!op.getRegion().front().empty())
+      return failure();
+
+    // The checks below conservatively bail out whenever the wait clause of
+    // the kernel_environment cannot be fully represented by `acc.wait`.
+
+    // Any wait operand tied to a non-default device type.
+    if (auto deviceTypes = op.getWaitOperandsDeviceTypeAttr()) {
+      for (auto entry : deviceTypes) {
+        auto deviceType = mlir::dyn_cast<acc::DeviceTypeAttr>(entry);
+        if (deviceType && deviceType.getValue() != mlir::acc::DeviceType::None)
+          return failure();
+      }
+    }
+
+    // Any wait segment that carries a devnum.
+    if (auto devnumFlags = op.getHasWaitDevnumAttr()) {
+      for (auto entry : devnumFlags) {
+        auto flag = mlir::dyn_cast<mlir::BoolAttr>(entry);
+        if (flag && flag.getValue())
+          return failure();
+      }
+    }
+
+    // More than one wait segment.
+    auto waitSegments = op.getWaitOperandsSegmentsAttr();
+    if (waitSegments && waitSegments.size() > 1)
+      return failure();
+
+    // Nothing to synchronize on: the op can simply disappear.
+    if (op.getWaitOperands().empty() && !op.getWaitOnlyAttr()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Otherwise keep the synchronization as a standalone acc.wait.
+    rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
+                                             /*asyncOperand=*/Value(),
+                                             /*waitDevnum=*/Value(),
+                                             /*async=*/nullptr,
+                                             /*ifCond=*/Value());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Recipe Region Helpers
+//===----------------------------------------------------------------------===//
+
+/// Builds the init region of a privatization recipe.
+///
+/// A block is created in `initRegion` with arguments `(original variable,
+/// bounds...)`. Its body is produced by `MappableType::generatePrivateInit`
+/// when `varType` is mappable, and by `PointerLikeType::genAllocate`
+/// otherwise; the resulting value is yielded with `acc.yield`.
+///
+/// Returns failure when the type interface produces no value (the block is
+/// left in place, without a yield). `needsFree` is passed through to the
+/// type interface call.
+static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
+                                      Region &initRegion, Type varType,
+                                      StringRef varName, ValueRange bounds,
+                                      bool &needsFree) {
+  // Block signature: the original value followed by every bound.
+  SmallVector<Type> argTypes{varType};
+  SmallVector<Location> argLocs{loc};
+  for (Type boundType : bounds.getTypes()) {
+    argTypes.push_back(boundType);
+    argLocs.push_back(loc);
+  }
+
+  Block *initBlock = builder.createBlock(&initRegion);
+  initBlock->addArguments(argTypes, argLocs);
+  builder.setInsertionPointToStart(initBlock);
+
+  // Block argument standing in for the original variable.
+  Value originalArg = initBlock->getArgument(0);
+
+  Value privatizedValue;
+  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
+    privatizedValue = mappableTy.generatePrivateInit(
+        builder, loc, cast<TypedValue<MappableType>>(originalArg), varName,
+        bounds, {}, needsFree);
+  } else {
+    assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+    privatizedValue = cast<PointerLikeType>(varType).genAllocate(
+        builder, loc, varName, varType, originalArg, needsFree);
+  }
+  if (!privatizedValue)
+    return failure();
+
+  acc::YieldOp::create(builder, loc, privatizedValue);
+  return success();
+}
+
+/// Builds the copy region of a firstprivate recipe.
+///
+/// A block is created in `copyRegion` with arguments `(original value,
+/// privatized value, bounds...)`. For pointer-like types the copy is
+/// generated through `PointerLikeType::genCopy`, and the block is closed
+/// with `acc.terminator`.
+///
+/// Returns failure for types that are mappable but not pointer-like, and
+/// when `genCopy` fails; in both cases the block has already been created.
+/// TODO: Handle MappableType - it does not yet have a copy API.
+static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
+                                      Region &copyRegion, Type varType,
+                                      ValueRange bounds) {
+  // Block signature: original value, privatized value, then every bound.
+  SmallVector<Type> copyArgTypes{varType, varType};
+  SmallVector<Location> copyArgLocs{loc, loc};
+  for (Type boundType : bounds.getTypes()) {
+    copyArgTypes.push_back(boundType);
+    copyArgLocs.push_back(loc);
+  }
+
+  Block *copyBlock = builder.createBlock(&copyRegion);
+  copyBlock->addArguments(copyArgTypes, copyArgLocs);
+  builder.setInsertionPointToStart(copyBlock);
+
+  // Mappable-only types have no copy API yet. Types that are both mappable
+  // and pointer-like fall back to the pointer-like path.
+  const bool isPointerLike = isa<PointerLikeType>(varType);
+  if (isa<MappableType>(varType) && !isPointerLike)
+    return failure();
+
+  if (isPointerLike) {
+    auto originalPtr =
+        cast<TypedValue<PointerLikeType>>(copyBlock->getArgument(0));
+    auto privatizedPtr =
+        cast<TypedValue<PointerLikeType>>(copyBlock->getArgument(1));
+    if (!cast<PointerLikeType>(varType).genCopy(builder, loc, privatizedPtr,
+                                                originalPtr, varType))
+      return failure();
+  }
+
+  acc::TerminatorOp::create(builder, loc);
+  return success();
+}
+
+/// Create and populate a destroy region for privatization recipes.
+///
+/// A block is created in `destroyRegion` with arguments `(original value,
+/// privatized value, bounds...)`. The privatized value is released through
+/// `MappableType::generatePrivateDestroy` when `varType` is mappable, and
+/// through `PointerLikeType::genFree` otherwise (using `allocRes` to identify
+/// the allocation). The block is closed with `acc.terminator`.
+///
+/// Returns success if the region is populated, failure otherwise.
+static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
+                                         Region &destroyRegion, Type varType,
+                                         Value allocRes, ValueRange bounds) {
+  // Create destroy block with arguments: original value + privatized value +
+  // bounds
+  SmallVector<Type> destroyArgTypes{varType, varType};
+  SmallVector<Location> destroyArgLocs{loc, loc};
+  for (Value bound : bounds) {
+    destroyArgTypes.push_back(bound.getType());
+    destroyArgLocs.push_back(loc);
+  }
+
+  Block *destroyBlock = builder.createBlock(&destroyRegion);
+  destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
+  builder.setInsertionPointToStart(destroyBlock);
+
+  Value privatizedArg = destroyBlock->getArgument(1);
+  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
+    // Pass the value as-is: a mappable type is not necessarily pointer-like,
+    // so casting it to TypedValue<PointerLikeType> up front would assert for
+    // mappable-only types.
+    if (!mappableTy.generatePrivateDestroy(builder, loc, privatizedArg, bounds))
+      return failure();
+  } else {
+    assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+    auto pointerLikeTy = cast<PointerLikeType>(varType);
+    // Only the pointer-like path needs the pointer-like typed value.
+    auto varToFree = cast<TypedValue<PointerLikeType>>(privatizedArg);
+    if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
+      return failure();
+  }
+
+  acc::TerminatorOp::create(builder, loc);
+  return success();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1050,6 +1811,70 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() {
return success();
}
+/// Build a `PrivateRecipeOp` named `recipeName` for `varType` and fill in its
+/// init region and, when the init region reports that the allocation must be
+/// released, its destroy region. Returns std::nullopt when `varType` is
+/// neither mappable nor pointer-like, or when a region cannot be generated
+/// (the partially built recipe is erased in that case).
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+                                   StringRef recipeName, Type varType,
+                                   StringRef varName, ValueRange bounds) {
+  // Reject types that implement neither of the two supported interfaces.
+  if (!isa<MappableType>(varType) && !isa<PointerLikeType>(varType))
+    return std::nullopt;
+
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  // The recipe op is created up front so that the region builders below work
+  // inside the proper parent.
+  auto recipeOp = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+  // Shared failure path: discard the half-built recipe.
+  auto abandon = [&]() -> std::optional<PrivateRecipeOp> {
+    recipeOp.erase();
+    return std::nullopt;
+  };
+
+  // Init region; `requiresDealloc` is set by the region builder.
+  bool requiresDealloc = false;
+  if (failed(createInitRegion(builder, loc, recipeOp.getInitRegion(), varType,
+                              varName, bounds, requiresDealloc)))
+    return abandon();
+
+  // No destroy region is emitted when nothing has to be deallocated.
+  if (!requiresDealloc)
+    return recipeOp;
+
+  // The first operand yielded by the init block is the allocated value.
+  Block &initBlock = recipeOp.getInitRegion().front();
+  auto initYield = cast<acc::YieldOp>(initBlock.getTerminator());
+  Value allocated = initYield.getOperand(0);
+  if (failed(createDestroyRegion(builder, loc, recipeOp.getDestroyRegion(),
+                                 varType, allocated, bounds)))
+    return abandon();
+
+  return recipeOp;
+}
+
+/// Build a `PrivateRecipeOp` named `recipeName` from an existing
+/// `FirstprivateRecipeOp`: the new recipe takes the same type and receives
+/// clones of the init region and, if present, the destroy region.
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+                                   StringRef recipeName,
+                                   FirstprivateRecipeOp firstprivRecipe) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  // The private.recipe reuses the type of the firstprivate.recipe.
+  auto sharedType = firstprivRecipe.getType();
+  auto privRecipe =
+      PrivateRecipeOp::create(builder, loc, recipeName, sharedType);
+
+  // Init region: always cloned.
+  IRMapping initMapping;
+  firstprivRecipe.getInitRegion().cloneInto(&privRecipe.getInitRegion(),
+                                            initMapping);
+
+  // Destroy region: cloned with its own mapping, and only when the source
+  // recipe actually has one.
+  Region &sourceDestroy = firstprivRecipe.getDestroyRegion();
+  if (sourceDestroy.empty())
+    return privRecipe;
+
+  IRMapping destroyMapping;
+  sourceDestroy.cloneInto(&privRecipe.getDestroyRegion(), destroyMapping);
+  return privRecipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1080,6 +1905,55 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
return success();
}
+/// Build a `FirstprivateRecipeOp` named `recipeName` for `varType` and fill in
+/// its init and copy regions and, when the init region reports that the
+/// allocation must be released, its destroy region. Returns std::nullopt when
+/// `varType` is neither mappable nor pointer-like, or when any region cannot
+/// be generated (the partially built recipe is erased in that case).
+std::optional<FirstprivateRecipeOp>
+FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+                                        StringRef recipeName, Type varType,
+                                        StringRef varName, ValueRange bounds) {
+  // Reject types that implement neither of the two supported interfaces.
+  if (!isa<MappableType>(varType) && !isa<PointerLikeType>(varType))
+    return std::nullopt;
+
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  // The recipe op is created up front so that the region builders below work
+  // inside the proper parent.
+  auto recipeOp =
+      FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+
+  // Shared failure path: discard the half-built recipe.
+  auto abandon = [&]() -> std::optional<FirstprivateRecipeOp> {
+    recipeOp.erase();
+    return std::nullopt;
+  };
+
+  // Init region; `requiresDealloc` is set by the region builder.
+  bool requiresDealloc = false;
+  if (failed(createInitRegion(builder, loc, recipeOp.getInitRegion(), varType,
+                              varName, bounds, requiresDealloc)))
+    return abandon();
+
+  // Copy region.
+  if (failed(createCopyRegion(builder, loc, recipeOp.getCopyRegion(), varType,
+                              bounds)))
+    return abandon();
+
+  // No destroy region is emitted when nothing has to be deallocated.
+  if (!requiresDealloc)
+    return recipeOp;
+
+  // The first operand yielded by the init block is the allocated value.
+  Block &initBlock = recipeOp.getInitRegion().front();
+  auto initYield = cast<acc::YieldOp>(initBlock.getTerminator());
+  Value allocated = initYield.getOperand(0);
+  if (failed(createDestroyRegion(builder, loc, recipeOp.getDestroyRegion(),
+                                 varType, allocated, bounds)))
+    return abandon();
+
+  return recipeOp;
+}
+
//===----------------------------------------------------------------------===//
// ReductionRecipeOp
//===----------------------------------------------------------------------===//
@@ -1111,40 +1985,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
}
//===----------------------------------------------------------------------===//
-// Custom parser and printer verifier for private clause
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseSymOperandList(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
- llvm::SmallVector<SymbolRefAttr> attributes;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
- return success();
-}
-
-static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands,
- mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> attributes) {
- llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
- p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
-}
-
-//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -1163,45 +2003,19 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
-template <typename Op>
-static LogicalResult
-checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
- mlir::OperandRange operands, llvm::StringRef operandName,
- llvm::StringRef symbolName, bool checkOperandType = true) {
- if (!operands.empty()) {
- if (!attributes || attributes->size() != operands.size())
- return op->emitOpError()
- << "expected as many " << symbolName << " symbol reference as "
- << operandName << " operands";
- } else {
- if (attributes)
- return op->emitOpError()
- << "unexpected " << symbolName << " symbol reference";
- return success();
- }
-
+/// Verify that every value in `operands` is produced by an `OpT` and that no
+/// value appears more than once. Errors are emitted on `accConstructOp` and
+/// use `operandName` in the message. `RecipeOpT` is not used in the body.
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
+ const mlir::ValueRange &operands,
+ llvm::StringRef operandName) {
llvm::DenseSet<Value> set;
- for (auto args : llvm::zip(operands, *attributes)) {
- mlir::Value operand = std::get<0>(args);
-
+ for (mlir::Value operand : operands) {
+ // getDefiningOp() is null for block arguments, and isa<> asserts on null;
+ // the templated getDefiningOp<OpT>() handles both cases and yields null.
+ if (!operand.getDefiningOp<OpT>())
+ return accConstructOp->emitOpError()
+ << "expected " << operandName << " as defining op";
if (!set.insert(operand).second)
- return op->emitOpError()
+ return accConstructOp->emitOpError()
<< operandName << " operand appears more than once";
-
- mlir::Type varType = operand.getType();
- auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
- auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
- if (!decl)
- return op->emitOpError()
- << "expected symbol reference " << symbolRef << " to point to a "
- << operandName << " declaration";
-
- if (checkOperandType && decl.getType() && decl.getType() != varType)
- return op->emitOpError() << "expected " << operandName << " (" << varType
- << ") to be the same type as " << operandName
- << " declaration (" << decl.getType() << ")";
}
-
return success();
}
@@ -1258,17 +2072,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
}
LogicalResult acc::ParallelOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -1399,7 +2213,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
mlir::ValueRange gangPrivateOperands,
mlir::ValueRange gangFirstPrivateOperands,
mlir::ValueRange dataClauseOperands) {
-
ParallelOp::build(
odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
@@ -1408,9 +2221,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
/*numGangsDeviceType=*/nullptr, numWorkers,
/*numWorkersDeviceType=*/nullptr, vectorLength,
/*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
- /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
- gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
- /*firstprivatizations=*/nullptr, dataClauseOperands,
+ /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
+ gangFirstPrivateOperands, dataClauseOperands,
/*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
@@ -1487,46 +2299,22 @@ void acc::ParallelOp::addWaitOperands(
void acc::ParallelOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the private op itself; the construct
+ // only receives the op's result as an additional private operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the firstprivate op itself; the
+ // construct only receives the op's result as an additional operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ // The recipe symbol is now stored on the reduction op itself; the construct
+ // only receives the op's result as an additional reduction operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
static ParseResult parseNumGangs(
@@ -2094,17 +2882,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
}
LogicalResult acc::SerialOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -2168,46 +2956,22 @@ void acc::SerialOp::addWaitOperands(
void acc::SerialOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the private op itself; the construct
+ // only receives the op's result as an additional private operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the firstprivate op itself; the
+ // construct only receives the op's result as an additional operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ // The recipe symbol is now stored on the reduction op itself; the construct
+ // only receives the op's result as an additional reduction operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -2337,6 +3101,27 @@ LogicalResult acc::KernelsOp::verify() {
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
+/// Record `recipe` on the private op `op` as a symbol reference, then append
+/// the result of `op` to the private operands of this kernels construct.
+void acc::KernelsOp::addPrivatization(MLIRContext *context,
+                                      mlir::acc::PrivateOp op,
+                                      mlir::acc::PrivateRecipeOp recipe) {
+  auto recipeSymbol = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeSymbol);
+
+  auto privateOperands = getPrivateOperandsMutable();
+  privateOperands.append(op.getResult());
+}
+
+/// Record `recipe` on the firstprivate op `op` as a symbol reference, then
+/// append the result of `op` to the firstprivate operands of this kernels
+/// construct.
+void acc::KernelsOp::addFirstPrivatization(
+    MLIRContext *context, mlir::acc::FirstprivateOp op,
+    mlir::acc::FirstprivateRecipeOp recipe) {
+  auto recipeSymbol = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeSymbol);
+
+  auto firstprivateOperands = getFirstprivateOperandsMutable();
+  firstprivateOperands.append(op.getResult());
+}
+
+/// Record `recipe` on the reduction op `op` as a symbol reference, then append
+/// the result of `op` to the reduction operands of this kernels construct.
+void acc::KernelsOp::addReduction(MLIRContext *context,
+                                  mlir::acc::ReductionOp op,
+                                  mlir::acc::ReductionRecipeOp recipe) {
+  auto recipeSymbol = mlir::SymbolRefAttr::get(context, recipe.getSymName());
+  op.setRecipeAttr(recipeSymbol);
+
+  auto reductionOperands = getReductionOperandsMutable();
+  reductionOperands.append(op.getResult());
+}
+
void acc::KernelsOp::addNumWorkersOperand(
MLIRContext *context, mlir::Value newValue,
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
@@ -2417,9 +3202,17 @@ LogicalResult acc::HostDataOp::verify() {
return emitError("at least one operand must appear on the host_data "
"operation");
- for (mlir::Value operand : getDataClauseOperands())
- if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
+ llvm::SmallPtrSet<mlir::Value, 4> seenVars;
+ for (mlir::Value operand : getDataClauseOperands()) {
+ auto useDeviceOp =
+ mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
+ if (!useDeviceOp)
return emitError("expect data entry operation as defining op");
+
+ // Check for duplicate use_device clauses
+ if (!seenVars.insert(useDeviceOp.getVar()).second)
+ return emitError("duplicate use_device variable");
+ }
return success();
}
@@ -2429,6 +3222,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
+// KernelEnvironmentOp
+//===----------------------------------------------------------------------===//
+
+/// Register the canonicalization patterns of `acc.kernel_environment`; the
+/// only pattern added here is `RemoveEmptyKernelEnvironment`.
+void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<RemoveEmptyKernelEnvironment>(context);
+}
+
+//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
@@ -2637,19 +3439,21 @@ bool hasDuplicateDeviceTypes(
}
/// Check for duplicates in the DeviceType array attribute.
-LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
+/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
+/// A null `deviceTypes` attribute is treated as having no duplicates.
+/// NOTE(review): an element that is not a DeviceTypeAttr yields
+/// DeviceType::None, which callers cannot tell apart from a duplicated
+/// `none` entry - confirm this is the intended diagnostic.
+static std::optional<mlir::acc::DeviceType>
+checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
if (!deviceTypes)
- return success();
+ return std::nullopt;
for (auto attr : deviceTypes) {
auto deviceTypeAttr =
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
if (!deviceTypeAttr)
- return failure();
+ return mlir::acc::DeviceType::None;
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
- return failure();
+ return deviceTypeAttr.getValue();
}
- return success();
+ return std::nullopt;
}
LogicalResult acc::LoopOp::verify() {
@@ -2676,9 +3480,10 @@ LogicalResult acc::LoopOp::verify() {
getCollapseDeviceTypeAttr().getValue().size())
return emitOpError() << "collapse attribute count must match collapse"
<< " device_type count";
- if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
- return emitOpError()
- << "duplicate device_type found in collapseDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in collapseDeviceType attribute";
// Check gang
if (!getGangOperands().empty()) {
@@ -2691,8 +3496,12 @@ LogicalResult acc::LoopOp::verify() {
return emitOpError() << "gangOperandsArgType attribute count must match"
<< " gangOperands count";
}
- if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
- return emitOpError() << "duplicate device_type found in gang attribute";
+ if (getGangAttr()) {
+ if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in gang attribute";
+ }
if (failed(verifyDeviceTypeAndSegmentCountMatch(
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
@@ -2700,22 +3509,30 @@ LogicalResult acc::LoopOp::verify() {
return failure();
// Check worker
- if (failed(checkDeviceTypes(getWorkerAttr())))
- return emitOpError() << "duplicate device_type found in worker attribute";
- if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "workerNumOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in worker attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in workerNumOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
getWorkerNumOperandsDeviceTypeAttr(),
"worker")))
return failure();
// Check vector
- if (failed(checkDeviceTypes(getVectorAttr())))
- return emitOpError() << "duplicate device_type found in vector attribute";
- if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "vectorOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vector attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vectorOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
getVectorOperandsDeviceTypeAttr(),
"vector")))
@@ -2780,19 +3597,19 @@ LogicalResult acc::LoopOp::verify() {
}
}
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (getCombined().has_value() &&
@@ -2806,8 +3623,12 @@ LogicalResult acc::LoopOp::verify() {
if (getRegion().empty())
return emitError("expected non-empty body.");
- // When it is container-like - it is expected to hold a loop-like operation.
- if (isContainerLike()) {
+ if (getUnstructured()) {
+ if (!isContainerLike())
+ return emitError(
+ "unstructured acc.loop must not have induction variables");
+ } else if (isContainerLike()) {
+ // When it is container-like - it is expected to hold a loop-like operation.
// Obtain the maximum collapse count - we use this to check that there
// are enough loops contained.
uint64_t collapseCount = getCollapseValue().value_or(1);
@@ -3222,45 +4043,21 @@ void acc::LoopOp::addGangOperands(
void acc::LoopOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the private op itself; the loop only
+ // receives the op's result as an additional private operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ // The recipe symbol is now stored on the firstprivate op itself; the loop
+ // only receives the op's result as an additional firstprivate operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ // The recipe symbol is now stored on the reduction op itself; the loop only
+ // receives the op's result as an additional reduction operand.
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -3597,7 +4394,8 @@ LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
}
if (Value writeVal = op.getWriteOpVal()) {
- rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
+ rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
+ op.getIfCond());
return success();
}
@@ -3724,7 +4522,8 @@ LogicalResult acc::RoutineOp::verify() {
if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
- "be present at the same time";
+ "be present at the same time for device_type `"
+ << acc::stringifyDeviceType(dtype) << "`";
}
return success();
@@ -4021,6 +4820,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
return std::nullopt;
}
+/// Replace the `seq` attribute with the result of running the shared
+/// device-type helper on its current value and `effectiveDeviceTypes`.
+void RoutineOp::addSeq(MLIRContext *context,
+                       llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedSeq = addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
+                                                       effectiveDeviceTypes);
+  setSeqAttr(updatedSeq);
+}
+
+/// Replace the `vector` attribute with the result of running the shared
+/// device-type helper on its current value and `effectiveDeviceTypes`.
+void RoutineOp::addVector(MLIRContext *context,
+                          llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedVector = addDeviceTypeAffectedOperandHelper(
+      context, getVectorAttr(), effectiveDeviceTypes);
+  setVectorAttr(updatedVector);
+}
+
+/// Replace the `worker` attribute with the result of running the shared
+/// device-type helper on its current value and `effectiveDeviceTypes`.
+void RoutineOp::addWorker(MLIRContext *context,
+                          llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedWorker = addDeviceTypeAffectedOperandHelper(
+      context, getWorkerAttr(), effectiveDeviceTypes);
+  setWorkerAttr(updatedWorker);
+}
+
+/// Replace the `gang` attribute with the result of running the shared
+/// device-type helper on its current value and `effectiveDeviceTypes`.
+void RoutineOp::addGang(MLIRContext *context,
+                        llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  auto updatedGang = addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
+                                                        effectiveDeviceTypes);
+  setGangAttr(updatedGang);
+}
+
+/// Append the gang dimension `val` (as a 64-bit integer attribute) to the
+/// `gangDim` array, paired with one entry in `gangDimDeviceType` for every
+/// element of `effectiveDeviceTypes`. When `effectiveDeviceTypes` is empty a
+/// single entry for `DeviceType::None` is appended instead. Existing entries
+/// of both arrays are kept.
+void RoutineOp::addGang(MLIRContext *context,
+                        llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                        uint64_t val) {
+  llvm::SmallVector<mlir::Attribute> gangDims;
+  llvm::SmallVector<mlir::Attribute> gangDimDeviceTypes;
+
+  // Start from whatever is already attached to the op.
+  if (auto currentDims = getGangDimAttr())
+    llvm::append_range(gangDims, currentDims);
+  if (auto currentDeviceTypes = getGangDimDeviceTypeAttr())
+    llvm::append_range(gangDimDeviceTypes, currentDeviceTypes);
+
+  assert(gangDims.size() == gangDimDeviceTypes.size());
+
+  // Adds one (dimension, device type) pair to the two parallel arrays.
+  auto appendEntry = [&](DeviceType deviceType) {
+    gangDims.push_back(
+        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+    gangDimDeviceTypes.push_back(
+        acc::DeviceTypeAttr::get(context, deviceType));
+  };
+
+  if (effectiveDeviceTypes.empty()) {
+    appendEntry(acc::DeviceType::None);
+  } else {
+    for (DeviceType deviceType : effectiveDeviceTypes)
+      appendEntry(deviceType);
+  }
+  assert(gangDims.size() == gangDimDeviceTypes.size());
+
+  setGangDimAttr(mlir::ArrayAttr::get(context, gangDims));
+  setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, gangDimDeviceTypes));
+}
+
+/// Extend the `bindStrNameDeviceType` array through the shared device-type
+/// helper and append `val` to the `bindStrName` array once for every entry
+/// the helper added, so that both arrays grow by the same amount.
+void RoutineOp::addBindStrName(MLIRContext *context,
+                               llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                               mlir::StringAttr val) {
+  // Number of device-type entries before the update (0 if unset).
+  unsigned numBefore = 0;
+  if (getBindStrNameDeviceTypeAttr())
+    numBefore = getBindStrNameDeviceTypeAttr().size();
+
+  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
+      context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes);
+  setBindStrNameDeviceTypeAttr(updatedDeviceTypes);
+  unsigned numAfter = getBindStrNameDeviceTypeAttr().size();
+
+  // Keep the existing names and add one copy of `val` per new device type.
+  llvm::SmallVector<mlir::Attribute> names;
+  if (auto currentNames = getBindStrNameAttr())
+    llvm::append_range(names, currentNames);
+  for (unsigned added = numBefore; added != numAfter; ++added)
+    names.push_back(val);
+
+  setBindStrNameAttr(mlir::ArrayAttr::get(context, names));
+}
+
+/// Extend the `bindIdNameDeviceType` array through the shared device-type
+/// helper and append `val` to the `bindIdName` array once for every entry
+/// the helper added, so that both arrays grow by the same amount.
+void RoutineOp::addBindIDName(MLIRContext *context,
+                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+                              mlir::SymbolRefAttr val) {
+  // Number of device-type entries before the update (0 if unset).
+  unsigned numBefore = 0;
+  if (getBindIdNameDeviceTypeAttr())
+    numBefore = getBindIdNameDeviceTypeAttr().size();
+
+  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
+      context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes);
+  setBindIdNameDeviceTypeAttr(updatedDeviceTypes);
+  unsigned numAfter = getBindIdNameDeviceTypeAttr().size();
+
+  // Keep the existing symbols and add one copy of `val` per new device type.
+  llvm::SmallVector<mlir::Attribute> symbols;
+  if (auto currentSymbols = getBindIdNameAttr())
+    llvm::append_range(symbols, currentSymbols);
+  for (unsigned added = numBefore; added != numAfter; ++added)
+    symbols.push_back(val);
+
+  setBindIdNameAttr(mlir::ArrayAttr::get(context, symbols));
+}
+
//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//
@@ -4405,13 +5298,11 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
return dataOperands;
}
-mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
- mlir::Operation *parentOp = region.getParentOp();
- while (parentOp) {
- if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
- return parentOp;
- }
- parentOp = parentOp->getParentOp();
- }
- return nullptr;
+/// Return the recipe attribute of `accOp` when it is one of the
+/// ACC_DATA_ENTRY_OPS, and a null SymbolRefAttr for any other operation.
+mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
+ auto recipe{
+ llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp)
+ .Case<ACC_DATA_ENTRY_OPS>(
+ [&](auto entry) { return entry.getRecipeAttr(); })
+ // Anything that is not a data entry op has no recipe to report.
+ .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
+ return recipe;
}