aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
blob: bcbfb5294c56f06389c24e3bf9ddaa5aa468f01a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include <string_view>

#define DEBUG_TYPE "flang-cuf-function-rewrite"

namespace fir {
#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

namespace {

// Signature of a replacement generator: it is handed the pattern rewriter and
// the matched fir.call, and returns the value that should replace the call.
// CallConversion::matchAndRewrite treats a null value as "erase the call".
using genFunctionType =
    std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>;

/// Rewrite pattern that replaces calls to a fixed set of known functions with
/// IR produced by a per-function generator. The callee's root symbol name is
/// used as the key into `genMappings_`.
class CallConversion : public OpRewritePattern<fir::CallOp> {
public:
  explicit CallConversion(MLIRContext *context)
      : OpRewritePattern<fir::CallOp>(context) {}

  /// Rewrite \p op when its callee name has a registered generator.
  ///
  /// Returns failure (leaving the IR untouched) when the call has no callee
  /// symbol or when the name is not in `genMappings_`. Otherwise the generator
  /// is invoked; a non-null result replaces the call, a null result erases it.
  LogicalResult
  matchAndRewrite(fir::CallOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto callee = op.getCallee();
    if (!callee)
      return failure();

    // One lookup is enough: find() already tells us whether the key exists,
    // so there is no need to hash the name a second time with contains().
    auto mapping = genMappings_.find(callee->getRootReference().getValue());
    if (mapping == genMappings_.end())
      return failure();

    if (mlir::Value replacement = mapping->second(rewriter, op))
      rewriter.replaceOp(op, replacement);
    else
      rewriter.eraseOp(op);
    return success();
  }

private:
  /// Generator for `on_device`: produces an i1 constant that is 1 when the
  /// call is nested inside a gpu.module and 0 otherwise, converted to the type
  /// of the call's first result.
  static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
                                 fir::CallOp op) {
    assert(op.getArgs().size() == 0 && "expect 0 arguments");
    mlir::Location loc = op.getLoc();
    // The answer is known statically from where the call sits in the IR.
    unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
    mlir::Type i1Ty = rewriter.getIntegerType(1);
    mlir::Value t = mlir::arith::ConstantOp::create(
        rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
    // NOTE(review): assumes the call has at least one result — confirm that
    // every `on_device` call site is a function reference with a result.
    return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
  }

  /// Callee name -> generator producing the replacement for that call.
  const llvm::StringMap<genFunctionType> genMappings_ = {
      {"on_device", &genOnDevice}};
};

/// Pass driver: greedily applies CallConversion to the operation the pass runs
/// on, and signals pass failure if the greedy driver reports failure.
class CUFFunctionRewrite
    : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
public:
  void runOnOperation() override {
    auto *context = &getContext();

    // Collect the rewrite patterns owned by this pass.
    mlir::RewritePatternSet rewrites(context);
    rewrites.insert<CallConversion>(rewrites.getContext());

    const mlir::LogicalResult status =
        mlir::applyPatternsGreedily(getOperation(), std::move(rewrites));
    if (mlir::succeeded(status))
      return;

    // The driver failed: report it and mark the pass as failed.
    mlir::emitError(mlir::UnknownLoc::get(context),
                    "error in CUFFunctionRewrite op conversion\n");
    signalPassFailure();
  }
};

} // namespace