aboutsummaryrefslogtreecommitdiff
path: root/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp
blob: 020c0fe770bfc3f440084164e3aacab7867c2e93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::pdl_to_pdl_interp;

namespace {

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

/// The test fixture for constructing root ordering tests and verifying results.
/// This fixture constructs the test values v. The test populates the graph
/// with the desired costs and then calls check(), passing the expected optimal
/// cost and the list of edges in the preorder traversal of the optimal
/// branching.
class RootOrderingTest : public ::testing::Test {
protected:
  RootOrderingTest() {
    context.loadDialect<ArithDialect>();
    createValues();
  }

  /// Creates the test values. These values simply act as vertices / vertex IDs
  /// in the cost graph, rather than being a part of an IR.
  void createValues() {
    OpBuilder builder(&context);
    builder.setInsertionPointToStart(&block);
    for (int i = 0; i < 4; ++i)
      // Ops will be deleted when `block` is destroyed.
      v[i] = ConstantIntOp::create(builder, builder.getUnknownLoc(), i, 32);
  }

  /// Checks that optimal branching on graph has the given cost and
  /// its preorder traversal results in the specified edges.
  void check(unsigned cost, const OptimalBranching::EdgeList &edges) {
    OptimalBranching opt(graph, v[0]);
    EXPECT_EQ(opt.solve(), cost);
    EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
    for (std::pair<Value, Value> edge : edges)
      EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
  }

protected:
  /// The context for creating the values.
  MLIRContext context;

  /// Block holding all the operations.
  Block block;

  /// Values used in the graph definition. We always use leading `n` values.
  Value v[4];

  /// The graph being tested on.
  RootOrderingGraph graph;
};

//===----------------------------------------------------------------------===//
// Simple 3-node graphs
//===----------------------------------------------------------------------===//

TEST_F(RootOrderingTest, simpleA) {
  // Name the vertices; the insertion order into `graph` is kept unchanged.
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  graph[a][root].cost = {1, 10};
  graph[b][root].cost = {1, 11};
  graph[a][b].cost = {2, 12};
  graph[b][a].cost = {2, 13};

  // Expected: both v[1] and v[2] attach directly to the root, total cost 2.
  check(2, {{root, {}}, {a, root}, {b, root}});
}

TEST_F(RootOrderingTest, simpleB) {
  // Name the vertices; the insertion order into `graph` is kept unchanged.
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  graph[a][root].cost = {1, 10};
  graph[b][root].cost = {2, 11};
  graph[a][b].cost = {1, 12};
  graph[b][a].cost = {1, 13};

  // Expected: the chain v[0] -> v[1] -> v[2], total cost 2.
  check(2, {{root, {}}, {a, root}, {b, a}});
}

TEST_F(RootOrderingTest, simpleC) {
  // Name the vertices; the insertion order into `graph` is kept unchanged.
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];

  graph[a][root].cost = {2, 10};
  graph[b][root].cost = {2, 11};
  graph[a][b].cost = {1, 12};
  graph[b][a].cost = {1, 13};

  // Expected: the chain v[0] -> v[1] -> v[2], total cost 3.
  check(3, {{root, {}}, {a, root}, {b, a}});
}

//===----------------------------------------------------------------------===//
// Graph for testing contraction
//===----------------------------------------------------------------------===//

TEST_F(RootOrderingTest, contraction) {
  // Name the vertices; the insertion order into `graph` is kept unchanged.
  Value root = v[0];
  Value a = v[1];
  Value b = v[2];
  Value c = v[3];

  graph[a][root].cost = {10, 0};
  graph[b][root].cost = {5, 0};
  // The entries among v[1], v[2] and v[3] form a cycle (a -> b -> c -> a).
  graph[b][a].cost = {1, 0};
  graph[c][b].cost = {2, 0};
  graph[a][c].cost = {3, 0};

  // Expected: v[0] -> v[2] -> v[3] -> v[1], total cost 10.
  check(10, {{root, {}}, {b, root}, {c, b}, {a, c}});
}

} // namespace