aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
blob: a963b3f063a8abc58c4da4c1f82aef971a9b387c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
//===- DebugExtensionOps.cpp - Debug extension for the Transform dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h"

#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "llvm/Support/InterleavedRange.h"

using namespace mlir;

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"

/// Emits a remark carrying this op's message for every payload entity
/// associated with the `at` handle.
///
/// - Operation handle: the remark is emitted on each payload operation.
/// - Value handle: the remark is emitted at each payload value's location,
///   with an attached note describing the value either as an op result
///   (result index) or as a block argument (argument index, the block's
///   position within its parent region, and that region's number).
///
/// Any other handle kind is a programming error (asserted). Always returns
/// success.
DiagnosedSilenceableFailure
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  auto atType = getAt().getType();

  // Operation handle: each payload op reports the remark itself.
  if (isa<TransformHandleTypeInterface>(atType)) {
    for (Operation *payloadOp : state.getPayloadOps(getAt()))
      payloadOp->emitRemark() << getMessage();
    return DiagnosedSilenceableFailure::success();
  }

  // Anything that is not an operation handle must be a value handle.
  assert(isa<transform::TransformValueHandleTypeInterface>(atType) &&
         "unhandled kind of transform type");

  for (Value payloadValue : state.getPayloadValues(getAt())) {
    InFlightDiagnostic remark =
        ::emitRemark(payloadValue.getLoc()) << getMessage();

    // Attach a note describing where the payload value comes from.
    Diagnostic &note = remark.attachNote();
    note << "value handle points to ";

    if (auto result = llvm::dyn_cast<OpResult>(payloadValue)) {
      note << "an op result #" << result.getResultNumber();
      continue;
    }

    // Not an op result, so it is a block argument: identify its block by
    // position inside the enclosing region.
    auto blockArg = llvm::cast<BlockArgument>(payloadValue);
    auto *ownerBlock = blockArg.getOwner();
    auto *ownerRegion = ownerBlock->getParent();
    auto blockIndex =
        std::distance(ownerRegion->begin(), ownerBlock->getIterator());
    note << "a block argument #" << blockArg.getArgNumber() << " in block #"
         << blockIndex << " in region #" << ownerRegion->getRegionNumber();
  }

  return DiagnosedSilenceableFailure::success();
}

/// Emits a remark whose text is the parameters associated with the `param`
/// handle (rendered with `llvm::interleaved`), prefixed by this op's message
/// and a single space when a message is present.
///
/// Without an `anchor` handle, the remark is emitted on this transform op.
/// Otherwise, one remark with the same text is emitted at the location of each
/// payload operation associated with `anchor`. Always returns success.
DiagnosedSilenceableFailure
transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,
                                      transform::TransformState &state) {
  // Render the optional message prefix followed by the parameter list.
  std::string text;
  {
    llvm::raw_string_ostream stream(text);
    if (auto message = getMessage())
      stream << *message << " ";
    stream << llvm::interleaved(state.getParams(getParam()));
  }

  if (getAnchor()) {
    // One remark per anchored payload op, at that op's location.
    for (Operation *anchorOp : state.getPayloadOps(getAnchor()))
      ::mlir::emitRemark(anchorOp->getLoc()) << text;
  } else {
    // No anchor: report on the transform op itself.
    emitRemark() << text;
  }
  return DiagnosedSilenceableFailure::success();
}