aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
blob: 145455e1c1b3d2b287a140180ae09d59dce75791 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
//===- TransformInterpreter.cpp - C Interface for Transform dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// C interface to the transform dialect interpreter and its options.
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

using namespace mlir;

// Defines the wrap()/unwrap() conversions between the opaque C handle
// MlirTransformOptions and a transform::TransformOptions pointer; every
// MlirTransformOptions entry point below goes through these helpers.
DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)

extern "C" {

MlirTransformOptions mlirTransformOptionsCreate() {
  // The options object lives on the heap and is handed to the caller as an
  // opaque handle; mlirTransformOptionsDestroy is what deletes it again.
  auto *options = new transform::TransformOptions;
  return wrap(options);
}

void mlirTransformOptionsEnableExpensiveChecks(
    MlirTransformOptions transformOptions, bool enable) {
  // Forward the flag verbatim to the wrapped C++ options object.
  auto &options = *unwrap(transformOptions);
  options.enableExpensiveChecks(enable);
}

bool mlirTransformOptionsGetExpensiveChecksEnabled(
    MlirTransformOptions transformOptions) {
  // Query the expensive-checks flag from the wrapped C++ options object.
  auto &options = *unwrap(transformOptions);
  bool enabled = options.getExpensiveChecksEnabled();
  return enabled;
}

void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
    MlirTransformOptions transformOptions, bool enable) {
  // The C name spells "TopLevel" while the C++ setter spells "Toplevel";
  // both refer to the same option.
  auto &options = *unwrap(transformOptions);
  options.enableEnforceSingleToplevelTransformOp(enable);
}

bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
    MlirTransformOptions transformOptions) {
  // Read back the single-top-level-transform-op option (C++ spelling:
  // "Toplevel").
  auto &options = *unwrap(transformOptions);
  bool enforced = options.getEnforceSingleToplevelTransformOp();
  return enforced;
}

void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
  // Releases the heap object allocated by mlirTransformOptionsCreate.
  auto *options = unwrap(transformOptions);
  delete options;
}

MlirLogicalResult mlirTransformApplyNamedSequence(
    MlirOperation payload, MlirOperation transformRoot,
    MlirOperation transformModule, MlirTransformOptions transformOptions) {
  // Runs the transform interpreter: `transformRoot` is applied to `payload`,
  // with `transformModule` passed along as the module the interpreter may
  // consult (e.g. for named sequences), under the given options.
  //
  // Both transform-side operations are validated first. On a mismatch an
  // error diagnostic is emitted on the offending op and failure is returned
  // without invoking the interpreter.

  Operation *transformRootOp = unwrap(transformRoot);
  if (!isa<transform::TransformOpInterface>(transformRootOp)) {
    transformRootOp->emitError()
        << "must implement TransformOpInterface to be used as transform root";
    return mlirLogicalResultFailure();
  }

  // A single dyn_cast both checks and converts, so the value validated here
  // is exactly the one forwarded below (instead of isa<> + a re-unwrapped
  // cast<>).
  Operation *transformModuleOp = unwrap(transformModule);
  auto transformModuleAsModule = dyn_cast<ModuleOp>(transformModuleOp);
  if (!transformModuleAsModule) {
    transformModuleOp->emitError()
        << "must be a " << ModuleOp::getOperationName();
    return mlirLogicalResultFailure();
  }

  return wrap(transform::applyTransformNamedSequence(
      unwrap(payload), transformRootOp, transformModuleAsModule,
      *unwrap(transformOptions)));
}

MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
                                                MlirOperation other) {
  // `other` itself is not handed over: a fresh clone is made, wrapped in an
  // OwningOpRef, and that owning reference is passed to mergeSymbolsInto
  // together with `target`.
  Operation *clonedOther = unwrap(other)->clone();
  return wrap(transform::detail::mergeSymbolsInto(
      unwrap(target), OwningOpRef<Operation *>(clonedOther)));
}
}