aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
blob: 4c7b228eefeb5d7e489775d635897a7e6d8445c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//===-- AssumedRankOpConversion.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Lower/BuiltinModules.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/Support.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/support.h"
#include "flang/Support/Fortran.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;

namespace {

/// Select the CFI attribute code describing \p boxType:
/// CFI_attribute_allocatable when fir::isAllocatableType holds,
/// CFI_attribute_pointer when fir::isPointerType holds, and
/// CFI_attribute_other in every remaining case. The allocatable check is
/// made first.
static int getCFIAttribute(mlir::Type boxType) {
  return fir::isAllocatableType(boxType) ? CFI_attribute_allocatable
         : fir::isPointerType(boxType)   ? CFI_attribute_pointer
                                         : CFI_attribute_other;
}

/// Translate the FIR-level lower bound modifier attribute into the
/// same-named Fortran::runtime::LowerBoundModifier enumerator.
static Fortran::runtime::LowerBoundModifier
getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) {
  using FirModifier = fir::LowerBoundModifierAttribute;
  using RuntimeModifier = Fortran::runtime::LowerBoundModifier;
  // The switch is intentionally left without a default case so that the
  // compiler can diagnose enumerators that are not handled here.
  switch (modifier) {
  case FirModifier::Preserve:
    return RuntimeModifier::Preserve;
  case FirModifier::SetToOnes:
    return RuntimeModifier::SetToOnes;
  case FirModifier::SetToZeroes:
    return RuntimeModifier::SetToZeroes;
  }
  llvm_unreachable("bad modifier code");
}

/// Rewrites fir.rebox_assumed_rank into runtime-based code: the input box and
/// a temporary descriptor of maximum rank are handed to
/// fir::runtime::genCopyAndUpdateDescriptor, then the temporary is loaded and
/// converted to the result box type, which replaces the original operation.
class ReboxAssumedRankConv
    : public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// \p symbolTable (kept as a non-owning pointer) and \p kindMap are passed
  /// to the fir::FirOpBuilder that is created for every rewrite.
  ReboxAssumedRankConv(mlir::MLIRContext *context,
                       mlir::SymbolTable *symbolTable, fir::KindMapping kindMap)
      : mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context),
        symbolTable{symbolTable}, kindMap{kindMap} {}

  /// Always succeeds, except for inputs that are box addresses
  /// (fir.ref<fir.box<>>), which hit a TODO.
  llvm::LogicalResult
  matchAndRewrite(fir::ReboxAssumedRankOp rebox,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
    mlir::Location loc = rebox.getLoc();
    auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType());
    // The temporary descriptor is given the maximum rank since the actual
    // rank of the input is not known here.
    mlir::Type newMaxRankBoxType =
        newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank);
    // CopyAndUpdateDescriptor FIR interface requires loading
    // !fir.ref<fir.box> input which is expensive with assumed-rank. It could
    // be best to add an entry point that takes a non "const" from to cover
    // this case, but it would be good to indicate to LLVM that from does not
    // get modified.
    if (fir::isBoxAddress(rebox.getBox().getType()))
      TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
    mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType);
    mlir::Type newEleType = newBoxType.unwrapInnerType();
    auto oldBoxType = mlir::cast<fir::BaseBoxType>(
        fir::unwrapRefType(rebox.getBox().getType()));
    auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType);
    // A type descriptor is only passed when the result is a non-polymorphic
    // derived type while the input is either polymorphic or has a different
    // element type. In all other cases a null constant is passed instead.
    const bool needsNewDtype =
        newDerivedType && !fir::isPolymorphicType(newBoxType) &&
        (fir::isPolymorphicType(oldBoxType) ||
         newEleType != oldBoxType.unwrapInnerType());
    mlir::Value newDtype;
    if (needsNewDtype) {
      newDtype = fir::TypeDescOp::create(builder, loc,
                                         mlir::TypeAttr::get(newDerivedType));
    } else {
      newDtype = builder.createNullConstant(loc);
    }
    // The attribute code is passed as an 8-bit integer constant and the lower
    // bound modifier code as a 32-bit integer constant.
    mlir::Value newAttribute = builder.createIntegerConstant(
        loc, builder.getIntegerType(8), getCFIAttribute(newBoxType));
    int lbsModifierCode =
        static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier()));
    mlir::Value lowerBoundModifier = builder.createIntegerConstant(
        loc, builder.getIntegerType(32), lbsModifierCode);
    fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc,
                                             rebox.getBox(), newDtype,
                                             newAttribute, lowerBoundModifier);

    // Load the temporary and convert it from the maximum-rank box type back
    // to the box type expected by the users of the original operation.
    mlir::Value descValue = fir::LoadOp::create(builder, loc, tempDesc);
    mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue);
    rewriter.replaceOp(rebox, castDesc);
    return mlir::success();
  }

private:
  mlir::SymbolTable *symbolTable = nullptr;
  fir::KindMapping kindMap;
};

/// Rewrites fir.is_assumed_size by replacing it with the value produced by
/// fir::runtime::genIsAssumedSize for the operation's operand.
class IsAssumedSizeConv : public mlir::OpRewritePattern<fir::IsAssumedSizeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// \p symbolTable (kept as a non-owning pointer) and \p kindMap are passed
  /// to the fir::FirOpBuilder that is created for every rewrite.
  IsAssumedSizeConv(mlir::MLIRContext *context, mlir::SymbolTable *symbolTable,
                    fir::KindMapping kindMap)
      : mlir::OpRewritePattern<fir::IsAssumedSizeOp>(context),
        symbolTable{symbolTable}, kindMap{kindMap} {}

  /// Unconditionally replaces \p isAssumedSizeOp and reports success.
  llvm::LogicalResult
  matchAndRewrite(fir::IsAssumedSizeOp isAssumedSizeOp,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder firBuilder{rewriter, kindMap, symbolTable};
    rewriter.replaceOp(isAssumedSizeOp,
                       fir::runtime::genIsAssumedSize(
                           firBuilder, isAssumedSizeOp.getLoc(),
                           isAssumedSizeOp.getVal()));
    return mlir::success();
  }

private:
  mlir::SymbolTable *symbolTable = nullptr;
  fir::KindMapping kindMap;
};

/// Module pass that rewrites the assumed-rank related operations
/// fir.rebox_assumed_rank and fir.is_assumed_size using the
/// ReboxAssumedRankConv and IsAssumedSizeConv patterns.
class AssumedRankOpConversion
    : public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> {
public:
  void runOnOperation() override {
    auto *context = &getContext();
    mlir::ModuleOp mod = getOperation();
    // Both patterns share one symbol table and one kind mapping, built once
    // from the module. The symbol table is a local, passed to the patterns by
    // pointer.
    mlir::SymbolTable symbolTable(mod);
    fir::KindMapping kindMap = fir::getKindMapping(mod);
    mlir::RewritePatternSet patterns(context);
    patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
    patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
    // Region simplification is turned off for the greedy driver.
    mlir::GreedyRewriteConfig config;
    config.setRegionSimplificationLevel(
        mlir::GreedySimplifyRegionLevel::Disabled);
    // The result of the greedy driver is explicitly discarded: the pass does
    // not signal a failure based on it.
    (void)applyPatternsGreedily(mod, std::move(patterns), config);
  }
};
} // namespace