aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/CUFOpConversion.cpp')
-rw-r--r--flang/lib/Optimizer/Transforms/CUFOpConversion.cpp37
1 files changed, 35 insertions, 2 deletions
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index de5c515..8c525fc 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -8,6 +8,7 @@
#include "flang/Optimizer/Transforms/CUFOpConversion.h"
#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
@@ -15,7 +16,6 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
-#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/allocatable.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
@@ -788,6 +788,38 @@ private:
const mlir::SymbolTable &symTab;
};
+struct CUFSyncDescriptorOpConversion
+ : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::SyncDescriptorOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
+ if (!globalOp)
+ return mlir::failure();
+
+ auto hostAddr = builder.create<fir::AddrOfOp>(
+ loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, callee, args);
+ op.erase();
+ return mlir::success();
+ }
+};
+
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
@@ -848,7 +880,8 @@ void cuf::populateCUFToFIRConversionPatterns(
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
- CUFFreeOpConversion>(patterns.getContext());
+ CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
+ patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);