aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Query/Matcher/MatchFinder.cpp
blob: 1d4817e32417db8d3a45ac450a303b9a2ffc45da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
//===- MatchFinder.cpp - --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the method definitions for the `MatchFinder` class
//
//===----------------------------------------------------------------------===//

#include "mlir/Query/Matcher/MatchFinder.h"
namespace mlir::query::matcher {

MatchFinder::MatchResult::MatchResult(Operation *rootOp,
                                      std::vector<Operation *> matchedOps)
    : rootOp(rootOp), matchedOps(std::move(matchedOps)) {}

/// Walks the operations reachable from `root` (via `Operation::walk`) and
/// records one MatchResult for every operation accepted by `matcher`.
///
/// Each operation is first tried with the plain `match(op)` overload; on
/// success the result holds just that operation. Otherwise the overload that
/// fills a capture set is tried, and on success the captured operations (in
/// the set's iteration order) become the result's matched ops.
std::vector<MatchFinder::MatchResult>
MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
  std::vector<MatchResult> collected;
  // Scratch set reused by the capturing overload. It is cleared after every
  // capturing attempt, so it is empty whenever the callback starts.
  llvm::SetVector<Operation *> captured;
  root->walk([&](Operation *op) {
    if (matcher.match(op)) {
      // `captured` was never touched on this path, so it is still empty.
      collected.emplace_back(op, std::vector<Operation *>{op});
      return;
    }
    if (matcher.match(op, captured))
      collected.emplace_back(
          op, std::vector<Operation *>(captured.begin(), captured.end()));
    // Drop anything a (possibly failed) capturing match left behind.
    captured.clear();
  });
  return collected;
}

void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
                             Operation *op) const {
  auto fileLoc = cast<FileLineColLoc>(op->getLoc());
  SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
  llvm::SMDiagnostic diag =
      qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
  diag.print("", os, true, false, true);
}

void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
                             Operation *op, const std::string &binding) const {
  auto fileLoc = cast<FileLineColLoc>(op->getLoc());
  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
                                     "\"" + binding + "\" binds here");
}

std::vector<Operation *>
MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
  std::vector<Operation *> newVector;
  for (auto &result : matches) {
    newVector.insert(newVector.end(), result.matchedOps.begin(),
                     result.matchedOps.end());
  }
  return newVector;
}

} // namespace mlir::query::matcher