aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Func/Utils/Utils.cpp
blob: b4cb0932ef631085c0decbef5894909f250e8f33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Func dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "func-utils"

using namespace mlir;

/// This method creates an inverse mapping of the provided map `oldToNew`.
/// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps
/// to new index `j`,
/// This method returns a vector where `result[j]` contains all old indices
/// that map to new index `j`.
///
/// Example:
/// ```
/// oldIdxToNewIdx = [0, 1, 2, 2, 3]
/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
/// ```
///
static llvm::SmallVector<llvm::SmallVector<int>>
getInverseMapping(ArrayRef<int> oldIdxToNewIdx) {
  int numOfNewIdxs = 0;
  if (!oldIdxToNewIdx.empty())
    numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
  llvm::SmallVector<llvm::SmallVector<int>> newToOldIdxs(numOfNewIdxs);
  for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
    newToOldIdxs[newIdx].push_back(oldIdx);
  return newToOldIdxs;
}

/// This method returns a new vector of elements that are mapped from the
/// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes
/// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there
/// is at least one old index that maps to it. Also, It assumes that mapping to
/// the same old index has the same element in the `origElements` vector.
template <typename Element>
static SmallVector<Element> getMappedElements(
    ArrayRef<Element> origElements,
    const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
  SmallVector<Element> newElements;
  for (const auto &oldIdxs : newIdxToOldIdxs) {
    assert(llvm::all_of(oldIdxs,
                        [&origElements](int idx) -> bool {
                          return idx >= 0 &&
                                 static_cast<size_t>(idx) < origElements.size();
                        }) &&
           "idx must be less than the number of elements in the original "
           "elements");
    assert(!oldIdxs.empty() && "oldIdx must not be empty");
    Element origTypeToCheck = origElements[oldIdxs.front()];
    assert(llvm::all_of(oldIdxs,
                        [&](int idx) -> bool {
                          return origElements[idx] == origTypeToCheck;
                        }) &&
           "all oldIdxs must be equal");
    newElements.push_back(origTypeToCheck);
  }
  return newElements;
}

FailureOr<func::FuncOp>
func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
                                ArrayRef<int> oldArgIdxToNewArgIdx,
                                ArrayRef<int> oldResIdxToNewResIdx) {
  // Generate an empty new function operation with the same name as the
  // original.
  assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
         "oldArgIdxToNewArgIdx must match the number of arguments in the "
         "function");
  assert(
      funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
      "oldResIdxToNewResIdx must match the number of results in the function");

  if (!funcOp.getBody().hasOneBlock())
    return rewriter.notifyMatchFailure(
        funcOp, "expected function to have exactly one block");

  // We may have some duplicate arguments in the old function, i.e.
  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
  // there may be multiple old argument indices.
  llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
      getInverseMapping(oldArgIdxToNewArgIdx);
  SmallVector<Type> newInputTypes = getMappedElements(
      funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);

  SmallVector<Location> locs;
  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());

  llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
      getInverseMapping(oldResIdxToNewResIdx);
  SmallVector<Type> newOutputTypes = getMappedElements(
      funcOp.getFunctionType().getResults(), newResToOldResIdxs);

  rewriter.setInsertionPoint(funcOp);
  auto newFuncOp = func::FuncOp::create(
      rewriter, funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(newInputTypes, newOutputTypes));

  Region &newRegion = newFuncOp.getBody();
  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
  newFuncOp.setVisibility(funcOp.getVisibility());

  // Map the arguments of the original function to the new function in
  // the new order and adjust the attributes accordingly.
  IRMapping operandMapper;
  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
  funcOp.getAllArgAttrs(argAttrs);
  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
    operandMapper.map(funcOp.getArgument(oldArgIdx),
                      newFuncOp.getArgument(newArgIdx));
  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);

  funcOp.getAllResultAttrs(resultAttrs);
  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);

  // Clone the operations from the original function to the new function.
  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
  for (Operation &op : funcOp.getOps())
    rewriter.clone(op, operandMapper);

  // Handle the return operation.
  auto returnOp = cast<func::ReturnOp>(
      newFuncOp.getFunctionBody().begin()->getTerminator());
  SmallVector<Value> newReturnValues;
  for (const auto &oldResIdxs : newResToOldResIdxs)
    newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));

  rewriter.setInsertionPoint(returnOp);
  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
  rewriter.eraseOp(returnOp);

  rewriter.eraseOp(funcOp);

  return newFuncOp;
}

func::CallOp
func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
                                  ArrayRef<int> oldArgIdxToNewArgIdx,
                                  ArrayRef<int> oldResIdxToNewResIdx) {
  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
         "oldArgIdxToNewArgIdx must match the number of operands in the call "
         "operation");
  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
         "oldResIdxToNewResIdx must match the number of results in the call "
         "operation");

  SmallVector<Value> origOperands = callOp.getOperands();
  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
      getInverseMapping(oldArgIdxToNewArgIdx);
  SmallVector<Value> newOperandsValues =
      getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
      getInverseMapping(oldResIdxToNewResIdx);
  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
  SmallVector<Type> newResultTypes =
      getMappedElements<Type>(origResultTypes, newResToOldResIdxs);

  // Replace the kernel call operation with a new one that has the
  // mapped arguments.
  rewriter.setInsertionPoint(callOp);
  auto newCallOp =
      func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
                           newResultTypes, newOperandsValues);
  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
                                newCallOp.getResult(newResIdx));
  rewriter.eraseOp(callOp);

  return newCallOp;
}

FailureOr<std::pair<func::FuncOp, func::CallOp>>
func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
                              ModuleOp moduleOp) {
  SmallVector<func::CallOp> callOps;
  auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
    if (callOp.getCallee() == funcOp.getSymName()) {
      if (!callOps.empty())
        // Only support one callOp for now
        return WalkResult::interrupt();
      callOps.push_back(callOp);
    }
    return WalkResult::advance();
  });

  if (traversalResult.wasInterrupted()) {
    LDBG() << "function " << funcOp.getName() << " has more than one callOp";
    return failure();
  }

  if (callOps.empty()) {
    LDBG() << "function " << funcOp.getName() << " does not have any callOp";
    return failure();
  }

  func::CallOp callOp = callOps.front();

  // Create mapping for arguments (deduplicate operands)
  SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
  llvm::DenseMap<Value, int> valueToNewArgIdx;
  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
    auto [iterator, inserted] = valueToNewArgIdx.insert(
        {operand, static_cast<int>(valueToNewArgIdx.size())});
    // Reduce the duplicate operands and maintain the original order.
    oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
  }

  bool hasDuplicateOperands =
      valueToNewArgIdx.size() != callOp.getNumOperands();
  if (!hasDuplicateOperands) {
    LDBG() << "function " << funcOp.getName()
           << " does not have duplicate operands";
    return failure();
  }

  // Create identity mapping for results (no deduplication needed)
  SmallVector<int> oldResIdxToNewResIdx(callOp.getNumResults());
  for (int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
    oldResIdxToNewResIdx[resultIdx] = resultIdx;

  // Apply the transformation to create new function and call operations
  FailureOr<func::FuncOp> newFuncOpOrFailure = replaceFuncWithNewMapping(
      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
  if (failed(newFuncOpOrFailure)) {
    LDBG() << "failed to replace function signature with name "
           << funcOp.getName() << " with new order";
    return failure();
  }

  func::CallOp newCallOp = replaceCallOpWithNewMapping(
      rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);

  return std::make_pair(*newFuncOpOrFailure, newCallOp);
}