diff options
Diffstat (limited to 'mlir/lib/Dialect/OpenACC/IR')
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 1355 |
2 files changed, 1124 insertions, 232 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt index ed7425b..2bd41d9 100644 --- a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIROpenACCDialect LINK_LIBS PUBLIC MLIRIR + MLIRGPUDialect MLIRLLVMDialect MLIRMemRefDialect MLIROpenACCMPCommon diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..460314f 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -4,10 +4,11 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -15,8 +16,10 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -39,11 +42,21 @@ static bool isScalarLikeType(Type type) { return type.isIntOrIndexOrFloat() || isa<ComplexType>(type); } +/// Helper function to attach the `VarName` attribute to an operation +/// if a variable name is provided. 
+static void attachVarNameAttr(Operation *op, OpBuilder &builder, + StringRef varName) { + if (!varName.empty()) { + auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName); + op->setAttr(acc::getVarNameAttrName(), varNameAttr); + } +} + +template <typename T> struct MemRefPointerLikeModel - : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, - MemRefType> { + : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> { Type getElementType(Type pointer) const { - return cast<MemRefType>(pointer).getElementType(); + return cast<T>(pointer).getElementType(); } mlir::acc::VariableTypeCategory @@ -52,7 +65,7 @@ struct MemRefPointerLikeModel if (auto mappableTy = dyn_cast<MappableType>(varType)) { return mappableTy.getTypeCategory(varPtr); } - auto memrefTy = cast<MemRefType>(pointer); + auto memrefTy = cast<T>(pointer); if (!memrefTy.hasRank()) { // This memref is unranked - aka it could have any rank, including a // rank of 0 which could mean scalar. For now, return uncategorized. @@ -74,14 +87,18 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) - return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy); + attachVarNameAttr(allocaOp, builder, varName); + return allocaOp.getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. 
@@ -99,8 +116,11 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } - return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) - .getResult(); + needsFree = true; // alloc needs deallocation + auto allocOp = + memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes); + attachVarNameAttr(allocOp, builder, varName); + return allocOp.getResult(); } // TODO: Unranked not yet supported. @@ -108,10 +128,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +174,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -181,12 +205,111 @@ struct MemRefPointerLikeModel return false; } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // Load from a memref - only valid for scalar memrefs (rank 0). + // This is because the address computation for memrefs is part of the load + // (and not computed separately), but the API does not have arguments for + // indexing. 
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr); + if (!memrefValue) + return {}; + + auto memrefTy = memrefValue.getType(); + + // Only load from scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return {}; + + return memref::LoadOp::create(builder, loc, memrefValue); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + // Store to a memref - only valid for scalar memrefs (rank 0) + // This is because the address computation for memrefs is part of the store + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr); + if (!memrefValue) + return false; + + auto memrefTy = memrefValue.getType(); + + // Only store to scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return false; + + memref::StoreOp::create(builder, loc, valueToStore, memrefValue); + return true; + } + + bool isDeviceData(Type pointer, Value var) const { + auto memrefTy = cast<T>(pointer); + Attribute memSpace = memrefTy.getMemorySpace(); + return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace); + } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // For LLVM pointers, we need the valueType to determine what to load + if (!valueType) + return {}; + + return LLVM::LoadOp::create(builder, loc, valueType, srcPtr); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + LLVM::StoreOp::create(builder, loc, valueToStore, destPtr); + return true; + } +}; + +struct MemrefAddressOfGlobalModel + : public 
AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } + + Region *getInitRegion(Operation *op) const { + // GlobalOp uses attributes for initialization, not regions + return nullptr; + } + + bool isDeviceData(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + Attribute memSpace = globalOp.getType().getMemorySpace(); + return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace); + } +}; + +struct GPULaunchOffloadRegionModel + : public acc::OffloadRegionOpInterface::ExternalModel< + GPULaunchOffloadRegionModel, gpu::LaunchOp> { + mlir::Region &getOffloadRegion(mlir::Operation *op) const { + return cast<gpu::LaunchOp>(op).getBody(); + } }; /// Helper function for any of the times we need to modify an ArrayAttr based on @@ -274,9 +397,135 @@ void OpenACCDialect::initialize() { // By attaching interfaces here, we make the OpenACC dialect dependent on // the other dialects. This is probably better than having dialects like LLVM // and memref be dependent on OpenACC. 
- MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); + MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>( + *getContext()); + UnrankedMemRefType::attachInterface< + MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); + gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext()); +} + +//===----------------------------------------------------------------------===// +// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial / +// acc.kernel_environment / acc.data / acc.host_data / acc.loop +//===----------------------------------------------------------------------===// + +/// Generic helper for single-region OpenACC ops that execute their body once +/// and then return to the parent operation with their results (if any). +static void +getSingleRegionOpSuccessorRegions(Operation *op, Region &region, + RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> &regions) { + if (point.isParent()) { + regions.push_back(RegionSuccessor(&region)); + return; + } + + regions.push_back(RegionSuccessor::parent()); +} + +static ValueRange getSingleRegionSuccessorInputs(Operation *op, + RegionSuccessor successor) { + return successor.isParent() ? 
ValueRange(op->getResults()) : ValueRange(); +} + +void KernelsOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void ParallelOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void SerialOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void KernelEnvironmentOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void DataOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void HostDataOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, + regions); +} + +ValueRange 
HostDataOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +void LoopOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> &regions) { + // Unstructured loops: the body may contain arbitrary CFG and early exits. + // At the RegionBranch level, only model entry into the body and exit to the + // parent; any backedges are represented inside the region CFG. + if (getUnstructured()) { + if (point.isParent()) { + regions.push_back(RegionSuccessor(&getRegion())); + return; + } + regions.push_back(RegionSuccessor::parent()); + return; + } + + // Structured loops: model a loop-shaped region graph similar to scf.for. + regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) { + return getSingleRegionSuccessorInputs(getOperation(), successor); +} + +//===----------------------------------------------------------------------===// +// RegionBranchTerminatorOpInterface +//===----------------------------------------------------------------------===// + +MutableOperandRange +TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) { + // `acc.terminator` does not forward operands. + return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } //===----------------------------------------------------------------------===// @@ -442,6 +691,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { return success(); } +template <typename OpT, typename RecipeOpT> +static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) { + // Mappable types do not need a recipe because it is possible to generate one + // from its API. Reject reductions though because no API is available for them + // at this time. 
+ if (mlir::acc::isMappableType(op.getVar().getType()) && + !std::is_same_v<OpT, acc::ReductionOp>) + return success(); + + mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr(); + if (!operandRecipe) + return op->emitOpError() << "recipe expected for " << operandName; + + auto decl = + SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe); + if (!decl) + return op->emitOpError() + << "expected symbol reference " << operandRecipe << " to point to a " + << operandName << " declaration"; + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -548,6 +819,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, } } +static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, + mlir::SymbolRefAttr &recipeAttr) { + if (failed(parser.parseAttribute(recipeAttr))) + return failure(); + return success(); +} + +static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::SymbolRefAttr recipeAttr) { + p << recipeAttr; +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -570,6 +853,9 @@ LogicalResult acc::PrivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed( + checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private"))) + return failure(); return success(); } @@ -584,6 +870,23 @@ LogicalResult acc::FirstprivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>( + *this, "firstprivate"))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// FirstprivateMapInitialOp 
+//===----------------------------------------------------------------------===// +LogicalResult acc::FirstprivateMapInitialOp::verify() { + if (getDataClause() != acc::DataClause::acc_firstprivate) + return emitError("data clause associated with firstprivate operation must " + "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -598,6 +901,9 @@ LogicalResult acc::ReductionOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>( + *this, "reduction"))) + return failure(); return success(); } @@ -918,6 +1224,270 @@ bool acc::CacheOp::isCacheReadonly() { acc::DataClauseModifier::readonly); } +//===----------------------------------------------------------------------===// +// Data entry/exit operations - getEffects implementations +//===----------------------------------------------------------------------===// + +// This function returns true iff the given operation is enclosed +// in any ACC_COMPUTE_CONSTRUCT_OPS operation. +// It is quite alike acc::getEnclosingComputeOp() utility, +// but we cannot use it here. +static bool isEnclosedIntoComputeOp(mlir::Operation *op) { + mlir::Operation *parentOp = op->getParentOp(); + while (parentOp) { + if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) + return true; + parentOp = parentOp->getParentOp(); + } + return false; +} + +/// Helper to add an effect on an operand, referenced by its mutable range. +template <typename EffectTy> +static void addOperandEffect( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects, + MutableOperandRange operand) { + for (unsigned i = 0, e = operand.size(); i < e; ++i) + effects.emplace_back(EffectTy::get(), &operand[i]); +} + +/// Helper to add an effect on a result value. 
+template <typename EffectTy> +static void addResultEffect( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects, + Value result) { + effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result)); +} + +// PrivateOp: accVar result write. +void acc::PrivateOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + // If acc.private is enclosed into a compute operation, + // then it denotes the device side privatization, hence + // it does not access the CurrentDeviceIdResource. + if (!isEnclosedIntoComputeOp(getOperation())) + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + // TODO: should this be MemoryEffects::Allocate? + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// FirstprivateOp: var read, accVar result write. +void acc::FirstprivateOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + // If acc.firstprivate is enclosed into a compute operation, + // then it denotes the device side privatization, hence + // it does not access the CurrentDeviceIdResource. + if (!isEnclosedIntoComputeOp(getOperation())) + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// FirstprivateMapInitialOp: var read, accVar result write. +void acc::FirstprivateMapInitialOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// ReductionOp: var read, accVar result write. 
+void acc::ReductionOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + // If acc.reduction is enclosed into a compute operation, + // then it denotes the device side reduction, hence + // it does not access the CurrentDeviceIdResource. + if (!isEnclosedIntoComputeOp(getOperation())) + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// DevicePtrOp: RuntimeCounters read. +void acc::DevicePtrOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); +} + +// PresentOp: RuntimeCounters read+write. +void acc::PresentOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); +} + +// CopyinOp: RuntimeCounters read+write, var read, accVar result write. +void acc::CopyinOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// CreateOp: RuntimeCounters read+write, accVar result write. 
+void acc::CreateOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + // TODO: should this be MemoryEffects::Allocate? + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// NoCreateOp: RuntimeCounters read+write. +void acc::NoCreateOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); +} + +// AttachOp: RuntimeCounters read+write, var read. +void acc::AttachOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + // TODO: should we also add MemoryEffects::Write? + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); +} + +// GetDevicePtrOp: RuntimeCounters read. +void acc::GetDevicePtrOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); +} + +// UpdateDeviceOp: var read, accVar result write. 
+void acc::UpdateDeviceOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); + addResultEffect<MemoryEffects::Write>(effects, getAccVar()); +} + +// UseDeviceOp: RuntimeCounters read. +void acc::UseDeviceOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); +} + +// DeclareDeviceResidentOp: RuntimeCounters write, var read. +void acc::DeclareDeviceResidentOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); +} + +// DeclareLinkOp: RuntimeCounters write, var read. +void acc::DeclareLinkOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getVarMutable()); +} + +// CacheOp: NoMemoryEffect +void acc::CacheOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) {} + +// CopyoutOp: RuntimeCounters read+write, accVar read, var write. 
+void acc::CopyoutOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable()); + addOperandEffect<MemoryEffects::Write>(effects, getVarMutable()); +} + +// DeleteOp: RuntimeCounters read+write, accVar read. +void acc::DeleteOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable()); +} + +// DetachOp: RuntimeCounters read+write, accVar read. +void acc::DetachOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable()); +} + +// UpdateHostOp: RuntimeCounters read+write, accVar read, var write. 
+void acc::UpdateHostOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Write::get(), + acc::RuntimeCounters::get()); + effects.emplace_back(MemoryEffects::Read::get(), + acc::CurrentDeviceIdResource::get()); + addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable()); + addOperandEffect<MemoryEffects::Write>(effects, getVarMutable()); +} + template <typename StructureOp> static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions = 1) { @@ -1003,6 +1573,197 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +/// Remove empty acc.kernel_environment operations. If the operation has wait +/// operands, create a acc.wait operation to preserve synchronization. +struct RemoveEmptyKernelEnvironment + : public OpRewritePattern<acc::KernelEnvironmentOp> { + using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op, + PatternRewriter &rewriter) const override { + assert(op->getNumRegions() == 1 && "expected op to have one region"); + + Block &block = op.getRegion().front(); + if (!block.empty()) + return failure(); + + // Conservatively disable canonicalization of empty acc.kernel_environment + // operations if the wait operands in the kernel_environment cannot be fully + // represented by acc.wait operation. 
+ + // Disable canonicalization if device type is not the default + if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) { + for (auto attr : deviceTypeAttr) { + if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) { + if (dtAttr.getValue() != mlir::acc::DeviceType::None) + return failure(); + } + } + } + + // Disable canonicalization if any wait segment has a devnum + if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) { + for (auto attr : hasDevnumAttr) { + if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) { + if (boolAttr.getValue()) + return failure(); + } + } + } + + // Disable canonicalization if there are multiple wait segments + if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) { + if (segmentsAttr.size() > 1) + return failure(); + } + + // Remove empty kernel environment. + // Preserve synchronization by creating acc.wait operation if needed. + if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr()) + rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(), + /*asyncOperand=*/Value(), + /*waitDevnum=*/Value(), + /*async=*/nullptr, + /*ifCond=*/Value()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns success if the region is populated, failure otherwise. +/// Sets needsFree to indicate if the allocated memory requires deallocation. 
+static LogicalResult createInitRegion(OpBuilder &builder, Location loc, + Region &initRegion, Type varType, + StringRef varName, ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + Block *initBlock = builder.createBlock(&initRegion); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return failure(); + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return failure(); + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return success(); +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns success if the region is populated, failure otherwise. +/// TODO: Handle MappableType - it does not yet have a copy API. 
+static LogicalResult createCopyRegion(OpBuilder &builder, Location loc, + Region ©Region, Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + Block *copyBlock = builder.createBlock(©Region); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return failure(); + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return failure(); + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return success(); +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns success if the region is populated, failure otherwise. 
+static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc, + Region &destroyRegion, Type varType, + Value allocRes, ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + Block *destroyBlock = builder.createBlock(&destroyRegion); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock); + + auto varToFree = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds)) + return failure(); + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType)) + return failure(); + } + + acc::TerminatorOp::create(builder, loc); + return success(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1811,70 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + OpBuilder::InsertionGuard guard(builder); + + // Create the recipe operation first so regions have proper parent context + auto recipe = 
PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Populate the init region + bool needsFree = false; + if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType, + varName, bounds, needsFree))) { + recipe.erase(); + return std::nullopt; + } + + // Only create destroy region if the allocation needs deallocation + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = + cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(), + varType, allocRes, bounds))) { + recipe.erase(); + return std::nullopt; + } + } + + return recipe; +} + +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, + FirstprivateRecipeOp firstprivRecipe) { + // Create the private.recipe op with the same type as the firstprivate.recipe. + OpBuilder::InsertionGuard guard(builder); + auto varType = firstprivRecipe.getType(); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Clone the init region + IRMapping mapping; + firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping); + + // Clone destroy region if the firstprivate.recipe has one. 
+ if (!firstprivRecipe.getDestroyRegion().empty()) { + IRMapping mapping; + firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(), + mapping); + } + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1905,55 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); } +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + OpBuilder::InsertionGuard guard(builder); + + // Create the recipe operation first so regions have proper parent context + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Populate the init region + bool needsFree = false; + if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType, + varName, bounds, needsFree))) { + recipe.erase(); + return std::nullopt; + } + + // Populate the copy region + if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType, + bounds))) { + recipe.erase(); + return std::nullopt; + } + + // Only create destroy region if the allocation needs deallocation + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = + cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(), + varType, allocRes, bounds))) { + recipe.erase(); + return std::nullopt; + } + } + + return recipe; +} + 
//===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// @@ -1111,40 +1985,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() { } //===----------------------------------------------------------------------===// -// Custom parser and printer verifier for private clause -//===----------------------------------------------------------------------===// - -static ParseResult parseSymOperandList( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, - llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { - llvm::SmallVector<SymbolRefAttr> attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), - attributes.end()); - symbols = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional<mlir::ArrayAttr> attributes) { - llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { - p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " - << std::get<1>(it).getType(); - }); -} - -//===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -1163,45 +2003,19 @@ static LogicalResult checkDataOperands(Op op, return success(); } -template <typename Op> -static LogicalResult -checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, - mlir::OperandRange operands, 
llvm::StringRef operandName, - llvm::StringRef symbolName, bool checkOperandType = true) { - if (!operands.empty()) { - if (!attributes || attributes->size() != operands.size()) - return op->emitOpError() - << "expected as many " << symbolName << " symbol reference as " - << operandName << " operands"; - } else { - if (attributes) - return op->emitOpError() - << "unexpected " << symbolName << " symbol reference"; - return success(); - } - +template <typename OpT, typename RecipeOpT> +static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, + const mlir::ValueRange &operands, + llvm::StringRef operandName) { llvm::DenseSet<Value> set; - for (auto args : llvm::zip(operands, *attributes)) { - mlir::Value operand = std::get<0>(args); - + for (mlir::Value operand : operands) { + if (!mlir::isa<OpT>(operand.getDefiningOp())) + return accConstructOp->emitOpError() + << "expected " << operandName << " as defining op"; if (!set.insert(operand).second) - return op->emitOpError() + return accConstructOp->emitOpError() << operandName << " operand appears more than once"; - - mlir::Type varType = operand.getType(); - auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); - auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); - if (!decl) - return op->emitOpError() - << "expected symbol reference " << symbolRef << " to point to a " - << operandName << " declaration"; - - if (checkOperandType && decl.getType() && decl.getType() != varType) - return op->emitOpError() << "expected " << operandName << " (" << varType - << ") to be the same type as " << operandName - << " declaration (" << decl.getType() << ")"; } - return success(); } @@ -1258,17 +2072,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( } LogicalResult acc::ParallelOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if 
(failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -1399,7 +2213,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, mlir::ValueRange gangPrivateOperands, mlir::ValueRange gangFirstPrivateOperands, mlir::ValueRange dataClauseOperands) { - ParallelOp::build( odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, @@ -1408,9 +2221,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, /*numGangsDeviceType=*/nullptr, numWorkers, /*numWorkersDeviceType=*/nullptr, vectorLength, /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, - /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, - gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, - /*firstprivatizations=*/nullptr, dataClauseOperands, + /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands, + gangFirstPrivateOperands, dataClauseOperands, /*defaultAttr=*/nullptr, /*combined=*/nullptr); } @@ -1487,46 +2299,22 @@ void acc::ParallelOp::addWaitOperands( void acc::ParallelOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, 
mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } static ParseResult parseNumGangs( @@ -2094,17 +2882,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { } LogicalResult acc::SerialOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", 
/*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -2168,46 +2956,22 @@ void acc::SerialOp::addWaitOperands( void acc::SerialOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - 
recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -2337,6 +3101,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { @@ -2417,9 +3202,17 @@ LogicalResult acc::HostDataOp::verify() { return emitError("at least one operand must 
appear on the host_data " "operation"); - for (mlir::Value operand : getDataClauseOperands()) - if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp())) + llvm::SmallPtrSet<mlir::Value, 4> seenVars; + for (mlir::Value operand : getDataClauseOperands()) { + auto useDeviceOp = + mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp()); + if (!useDeviceOp) return emitError("expect data entry operation as defining op"); + + // Check for duplicate use_device clauses + if (!seenVars.insert(useDeviceOp.getVar()).second) + return emitError("duplicate use_device variable"); + } return success(); } @@ -2429,6 +3222,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// KernelEnvironmentOp +//===----------------------------------------------------------------------===// + +void acc::KernelEnvironmentOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<RemoveEmptyKernelEnvironment>(context); +} + +//===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// @@ -2637,19 +3439,21 @@ bool hasDuplicateDeviceTypes( } /// Check for duplicates in the DeviceType array attribute. -LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { +/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found. 
+static std::optional<mlir::acc::DeviceType> +checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; if (!deviceTypes) - return success(); + return std::nullopt; for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); if (!deviceTypeAttr) - return failure(); + return mlir::acc::DeviceType::None; if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) - return failure(); + return deviceTypeAttr.getValue(); } - return success(); + return std::nullopt; } LogicalResult acc::LoopOp::verify() { @@ -2676,9 +3480,10 @@ LogicalResult acc::LoopOp::verify() { getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; - if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) - return emitOpError() - << "duplicate device_type found in collapseDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { @@ -2691,8 +3496,12 @@ LogicalResult acc::LoopOp::verify() { return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } - if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) - return emitOpError() << "duplicate device_type found in gang attribute"; + if (getGangAttr()) { + if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in gang attribute"; + } if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), @@ -2700,22 +3509,30 @@ LogicalResult acc::LoopOp::verify() { return failure(); // Check worker - if 
(failed(checkDeviceTypes(getWorkerAttr()))) - return emitOpError() << "duplicate device_type found in worker attribute"; - if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "workerNumOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in worker attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector - if (failed(checkDeviceTypes(getVectorAttr()))) - return emitOpError() << "duplicate device_type found in vector attribute"; - if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "vectorOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vector attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getVectorOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) @@ -2780,19 +3597,19 @@ LogicalResult acc::LoopOp::verify() { } } - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", false))) + if 
(failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (getCombined().has_value() && @@ -2806,8 +3623,12 @@ LogicalResult acc::LoopOp::verify() { if (getRegion().empty()) return emitError("expected non-empty body."); - // When it is container-like - it is expected to hold a loop-like operation. - if (isContainerLike()) { + if (getUnstructured()) { + if (!isContainerLike()) + return emitError( + "unstructured acc.loop must not have induction variables"); + } else if (isContainerLike()) { + // When it is container-like - it is expected to hold a loop-like operation. // Obtain the maximum collapse count - we use this to check that there // are enough loops contained. 
uint64_t collapseCount = getCollapseValue().value_or(1); @@ -3222,45 +4043,21 @@ void acc::LoopOp::addGangOperands( void acc::LoopOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -3597,7 +4394,8 @@ LogicalResult 
AtomicUpdateOp::canonicalize(AtomicUpdateOp op, } if (Value writeVal = op.getWriteOpVal()) { - rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal); + rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal, + op.getIfCond()); return success(); } @@ -3724,7 +4522,8 @@ LogicalResult acc::RoutineOp::verify() { if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " - "be present at the same time"; + "be present at the same time for device_type `" + << acc::stringifyDeviceType(dtype) << "`"; } return success(); @@ -4021,6 +4820,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { return std::nullopt; } +void RoutineOp::addSeq(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addVector(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addWorker(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + uint64_t val) { + llvm::SmallVector<mlir::Attribute> dimValues; + llvm::SmallVector<mlir::Attribute> deviceTypes; + + if (getGangDimAttr()) + llvm::copy(getGangDimAttr(), std::back_inserter(dimValues)); + if (getGangDimDeviceTypeAttr()) + llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes)); + + 
assert(dimValues.size() == deviceTypes.size()); + + if (effectiveDeviceTypes.empty()) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back( + acc::DeviceTypeAttr::get(context, acc::DeviceType::None)); + } else { + for (DeviceType dt : effectiveDeviceTypes) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt)); + } + } + assert(dimValues.size() == deviceTypes.size()); + + setGangDimAttr(mlir::ArrayAttr::get(context, dimValues)); + setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes)); +} + +void RoutineOp::addBindStrName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::StringAttr val) { + unsigned before = getBindStrNameDeviceTypeAttr() + ? getBindStrNameDeviceTypeAttr().size() + : 0; + + setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindStrNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindStrNameAttr()) + llvm::copy(getBindStrNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindStrNameAttr(mlir::ArrayAttr::get(context, vals)); +} + +void RoutineOp::addBindIDName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::SymbolRefAttr val) { + unsigned before = + getBindIdNameDeviceTypeAttr() ? 
getBindIdNameDeviceTypeAttr().size() : 0; + + setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindIdNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindIdNameAttr()) + llvm::copy(getBindIdNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindIdNameAttr(mlir::ArrayAttr::get(context, vals)); +} + //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// @@ -4405,13 +5298,11 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { return dataOperands; } -mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { - mlir::Operation *parentOp = region.getParentOp(); - while (parentOp) { - if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) { - return parentOp; - } - parentOp = parentOp->getParentOp(); - } - return nullptr; +mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) { + auto recipe{ + llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp) + .Case<ACC_DATA_ENTRY_OPS>( + [&](auto entry) { return entry.getRecipeAttr(); }) + .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })}; + return recipe; } |
