//===-- CUFLaunchAttachAttr.cpp -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/CUF/CUFDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" namespace fir { #define GEN_PASS_DEF_CUFLAUNCHATTACHATTR #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir using namespace mlir; namespace { static constexpr llvm::StringRef cudaKernelInfix = "_cufk_"; class CUFGPUAttachAttrPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::gpu::LaunchFuncOp op, PatternRewriter &rewriter) const override { op->setAttr(cuf::getProcAttrName(), cuf::ProcAttributeAttr::get(op.getContext(), cuf::ProcAttribute::Global)); return mlir::success(); } }; struct CUFLaunchAttachAttr : public fir::impl::CUFLaunchAttachAttrBase { void runOnOperation() override { auto *context = &this->getContext(); mlir::RewritePatternSet patterns(context); patterns.add(context); mlir::ConversionTarget target(*context); target.addIllegalOp(); target.addDynamicallyLegalOp( [&](mlir::gpu::LaunchFuncOp op) -> bool { if (op.getKernelName().getValue().contains(cudaKernelInfix)) { if (op.getOperation()->getAttrOfType( cuf::getProcAttrName())) return true; return false; } return true; }); if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(context), "Pattern conversion failed\n"); this->signalPassFailure(); } } }; } // end anonymous namespace