aboutsummaryrefslogtreecommitdiff
path: root/mlir/unittests/Dialect/Transform/Preload.cpp
blob: 8504928d85cb29383fbd7733b1156ae47a3419e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
//===- Preload.cpp - Test transform library preloading --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"

using namespace mlir;

// Forward declaration of the factory for the test transform dialect
// interpreter pass. Only the declaration lives here; the definition is not
// part of this file and has to be supplied at link time.
namespace mlir::test {
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
} // namespace mlir::test

// Transform library used by the test: a module that *defines* the private
// named sequence @__transform_main, whose body emits the remark
// "from external symbol" at its argument and yields.
static constexpr llvm::StringLiteral library = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op
    transform.yield
  }
})MLIR";

// Input module used by the test: it only *declares* the private named sequence
// @__transform_main (no body) and contains a transform.sequence that includes
// that symbol, so the definition has to come from somewhere else.
static constexpr llvm::StringLiteral input = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})

  transform.sequence failures(propagate) {
  ^bb0(%arg0: !transform.any_op):
    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
  }
})MLIR";

/// Checks that a transform library parsed from a string can be loaded into the
/// transform dialect of a context, retrieved back from that context, merged
/// into an input module that only declares the library symbol, and finally
/// applied as a named sequence.
///
/// Steps whose result is dereferenced or consumed by a later step use ASSERT_*
/// instead of EXPECT_*: on failure the test stops with its diagnostic message
/// rather than going on to dereference a null module or operation.
TEST(Preload, ContextPreloadConstructedLibrary) {
  registerPassManagerCLOptions();

  // Set up a context with the transform dialect loaded and the debug extension
  // (used by the library string) applied.
  MLIRContext context;
  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
  DialectRegistry registry;
  mlir::transform::registerDebugExtension(registry);
  registry.applyExtensions(&context);
  ParserConfig parserConfig(&context);

  // `inputModule` is dereferenced below, so a parse failure must be fatal.
  OwningOpRef<ModuleOp> inputModule =
      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
  ASSERT_TRUE(inputModule) << "failed to parse input module";

  // Parse the library and hand its ownership over to the dialect.
  OwningOpRef<ModuleOp> transformLibrary =
      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
  LogicalResult diag =
      dialect->loadIntoLibraryModule(std::move(transformLibrary));
  ASSERT_TRUE(succeeded(diag)) << "failed to load transform library";

  // The module is cloned through `->` right below; bail out if it is null.
  ModuleOp retrievedTransformLibrary =
      transform::detail::getPreloadedTransformModule(&context);
  ASSERT_TRUE(retrievedTransformLibrary)
      << "failed to retrieve transform module";

  // Merge a clone so that the module held by the context is left as is and can
  // still be passed to the calls below.
  OwningOpRef<Operation *> clonedTransformModule(
      retrievedTransformLibrary->clone());

  LogicalResult res = transform::detail::mergeSymbolsInto(
      inputModule->getOperation(), std::move(clonedTransformModule));
  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";

  // A null entry point must not be forwarded to applyTransformNamedSequence.
  transform::TransformOpInterface entryPoint =
      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
                                                 retrievedTransformLibrary);
  ASSERT_TRUE(entryPoint) << "failed to find entry point";

  // Last check of the test: nothing depends on it, so EXPECT_* is sufficient.
  transform::TransformOptions options;
  res = transform::applyTransformNamedSequence(
      inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
      options);
  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
}