aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
blob: 23b41304823143e0b8ca3b1089a2c0a63e1dc3eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
//===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
#include "llvm/Support/DebugLog.h"

using namespace mlir;
using namespace LLVM;

#define DEBUG_TYPE "llvm-di-expression-simplifier"

//===----------------------------------------------------------------------===//
// DIExpressionRewriter
//===----------------------------------------------------------------------===//

/// Registers `pattern` with this rewriter, taking ownership of it. The pattern
/// is stored after all previously added patterns.
void DIExpressionRewriter::addPattern(
    std::unique_ptr<ExprRewritePattern> pattern) {
  patterns.push_back(std::move(pattern));
}

/// Rewrites the operators of `expr` by applying the registered patterns from
/// front to back and returns the resulting expression attribute.
///
/// At each step the patterns are tried, in the order they were added, against
/// the operators that have not been finalized yet. The first pattern that
/// matches a non-empty prefix has that prefix replaced by its replacement, and
/// matching restarts at the front of the replacement. If no pattern matches,
/// the front operator is finalized unchanged.
///
/// If `maxNumRewrites` is set, at most that many replacements are performed;
/// operators that are still unprocessed at that point are emitted as-is.
DIExpressionAttr
DIExpressionRewriter::simplify(DIExpressionAttr expr,
                               std::optional<uint64_t> maxNumRewrites) const {
  ArrayRef<OperatorT> operators = expr.getOperations();

  // `inputs` contains the unprocessed postfix of operators.
  // `result` contains the already finalized prefix of operators.
  // Invariant: concat(result, inputs) is equivalent to `operators` after some
  // application of the rewrite patterns.
  // Using a deque for inputs so that we have efficient front insertion and
  // removal. Random access is not necessary for patterns.
  std::deque<OperatorT> inputs(operators.begin(), operators.end());
  SmallVector<OperatorT> result;

  uint64_t numRewrites = 0;
  while (!inputs.empty() &&
         (!maxNumRewrites || numRewrites < *maxNumRewrites)) {
    bool foundMatch = false;
    for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
      // A pattern signals "no match" by returning the begin iterator, i.e. an
      // empty matched prefix.
      ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
      if (matchEnd == inputs.begin())
        continue;

      foundMatch = true;
      // Compute the replacement before erasing: `matchEnd` points into
      // `inputs` and is only valid until the deque is modified.
      SmallVector<OperatorT> replacement =
          pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
      inputs.erase(inputs.begin(), matchEnd);
      // Put the replacement back at the front so that it is itself subject to
      // further rewriting.
      inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
      ++numRewrites;
      break;
    }

    if (!foundMatch) {
      // If no match, pass along the current operator.
      result.push_back(inputs.front());
      inputs.pop_front();
    }
  }

  // The loop only leaves operators behind when the rewrite budget ran out, so
  // leftover inputs are the precise signal that simplification was cut short.
  // Checking the counter alone would also report the case where the final
  // rewrite landed exactly on the limit and everything was processed.
  if (!inputs.empty()) {
    LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
           << maxNumRewrites << ")";
    // Skip rewriting the rest.
    result.append(inputs.begin(), inputs.end());
  }

  return LLVM::DIExpressionAttr::get(expr.getContext(), result);
}