aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp
blob: 41a0e5c7dceece5313b540118ca05abd016f7b0a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//===-- CUFLaunchAttachAttr.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace fir {
#define GEN_PASS_DEF_CUFLAUNCHATTACHATTR
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

namespace {

/// Kernel-name substring identifying `gpu.launch_func` ops that must carry the
/// CUF procedure attribute. NOTE(review): presumably this marks kernels
/// generated from CUF kernel constructs — confirm against the kernel-outlining
/// code that produces these names.
constexpr llvm::StringRef cudaKernelInfix = "_cufk_";

/// Tags a `gpu.launch_func` with the CUF procedure attribute
/// (`cuf::getProcAttrName()`) set to `cuf::ProcAttribute::Global`.
///
/// The pattern itself does no filtering: it is only invoked on launches that
/// the conversion target in `CUFLaunchAttachAttr` reports as illegal, i.e.
/// launches of `cudaKernelInfix` kernels that lack the attribute.
class CUFGPUAttachAttrPattern
    : public OpRewritePattern<mlir::gpu::LaunchFuncOp> {
  using OpRewritePattern<mlir::gpu::LaunchFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mlir::gpu::LaunchFuncOp op,
                                PatternRewriter &rewriter) const override {
    // In-place changes must be reported to the rewriter. Under dialect
    // conversion this lets the driver record (and, if needed, roll back) the
    // modification; mutating `op` directly trips the driver's "expected
    // pattern to replace the root operation" debug assertion.
    rewriter.modifyOpInPlace(op, [&]() {
      op->setAttr(cuf::getProcAttrName(),
                  cuf::ProcAttributeAttr::get(op.getContext(),
                                              cuf::ProcAttribute::Global));
    });
    return mlir::success();
  }
};

/// Pass that ensures every `gpu.launch_func` whose kernel name contains
/// `cudaKernelInfix` carries the CUF `Global` procedure attribute. Launches of
/// other kernels, and launches that already have the attribute, are left
/// untouched.
struct CUFLaunchAttachAttr
    : public fir::impl::CUFLaunchAttachAttrBase<CUFLaunchAttachAttr> {

  void runOnOperation() override {
    auto *context = &this->getContext();

    mlir::RewritePatternSet patterns(context);
    patterns.add<CUFGPUAttachAttrPattern>(context);

    // A launch is illegal (and therefore rewritten by the pattern above) only
    // when it targets a `cudaKernelInfix` kernel that does not yet carry a
    // `cuf::ProcAttributeAttr`. The dynamic-legality callback fully determines
    // the op's action, so no separate addIllegalOp registration is needed.
    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>(
        [](mlir::gpu::LaunchFuncOp op) -> bool {
          if (!op.getKernelName().getValue().contains(cudaKernelInfix))
            return true;
          return static_cast<bool>(
              op.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
                  cuf::getProcAttrName()));
        });

    if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target,
                                                  std::move(patterns)))) {
      // Anchor the diagnostic on the pass root rather than an unknown location
      // so the failure can be traced back to the input.
      mlir::emitError(this->getOperation()->getLoc(),
                      "Pattern conversion failed\n");
      this->signalPassFailure();
    }
  }
};

} // end anonymous namespace