aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp')
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp111
1 files changed, 111 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6598ac1..6564a4e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -7,6 +7,7 @@
// =============================================================================
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -44,6 +45,7 @@ struct MemRefPointerLikeModel
Type getElementType(Type pointer) const {
return cast<MemRefType>(pointer).getElementType();
}
+
mlir::acc::VariableTypeCategory
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
Type varType) const {
@@ -70,6 +72,115 @@ struct MemRefPointerLikeModel
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
return mlir::acc::VariableTypeCategory::array;
}
+
+ mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
+ StringRef varName, Type varType,
+ Value originalVar) const {
+ auto memrefTy = cast<MemRefType>(pointer);
+
+ // Check if this is a static memref (all dimensions are known) - if yes
+ // then we can generate an alloca operation.
+ if (memrefTy.hasStaticShape())
+ return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+
+ // For dynamic memrefs, extract sizes from the original variable if
+ // provided. Otherwise they cannot be handled.
+ if (originalVar && originalVar.getType() == memrefTy &&
+ memrefTy.hasRank()) {
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
+ if (memrefTy.isDynamicDim(i)) {
+ // Extract the size of dimension i from the original variable
+ auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
+ auto dimSize =
+ memref::DimOp::create(builder, loc, originalVar, indexValue);
+ dynamicSizes.push_back(dimSize);
+ }
+ // Note: We only add dynamic sizes to the dynamicSizes array
+ // Static dimensions are handled automatically by AllocOp
+ }
+ return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
+ .getResult();
+ }
+
+ // TODO: Unranked not yet supported.
+ return {};
+ }
+
+ bool genFree(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> varPtr, Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ // Walk through casts to find the original allocation
+ Value currentValue = memrefValue;
+ Operation *originalAlloc = nullptr;
+
+ // Follow the chain of operations to find the original allocation
+ // even if a casted result is provided.
+ while (currentValue) {
+ if (auto *definingOp = currentValue.getDefiningOp()) {
+ // Check if this is an allocation operation
+ if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
+ originalAlloc = definingOp;
+ break;
+ }
+
+ // Check if this is a cast operation we can look through
+ if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
+ currentValue = castOp.getSource();
+ continue;
+ }
+
+ // Check for other cast-like operations
+ if (auto reinterpretCastOp =
+ dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
+ currentValue = reinterpretCastOp.getSource();
+ continue;
+ }
+
+ // If we can't look through this operation, stop
+ break;
+ }
+ // This is a block argument or similar - can't trace further.
+ break;
+ }
+
+ if (originalAlloc) {
+ if (isa<memref::AllocaOp>(originalAlloc)) {
+ // This is an alloca - no dealloc needed, but return true (success)
+ return true;
+ }
+ if (isa<memref::AllocOp>(originalAlloc)) {
+ // This is an alloc - generate dealloc
+ memref::DeallocOp::create(builder, loc, memrefValue);
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ bool genCopy(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> destination,
+ TypedValue<PointerLikeType> source, Type varType) const {
+ // Generate a copy operation between two memrefs
+ auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
+ auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
+
+ // As per memref documentation, source and destination must have same
+ // element type and shape in order to be compatible. We do not want to fail
+ // with an IR verification error - thus check that before generating the
+ // copy operation.
+ if (destMemref && srcMemref &&
+ destMemref.getType().getElementType() ==
+ srcMemref.getType().getElementType() &&
+ destMemref.getType().getShape() == srcMemref.getType().getShape()) {
+ memref::CopyOp::create(builder, loc, srcMemref, destMemref);
+ return true;
+ }
+
+ return false;
+ }
};
struct LLVMPointerPointerLikeModel