aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Support/BalancedPartitioningTest.cpp
blob: ebe518a8e89cacff460cd49bc30c5d3bcb466fc7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Testing/Support/SupportHelpers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <algorithm>
#include <iterator>
#include <optional>
#include <random>
#include <vector>

using testing::Each;
using testing::Field;
using testing::Not;
using testing::UnorderedElementsAre;
using testing::UnorderedElementsAreArray;

namespace llvm {

void PrintTo(const BPFunctionNode &Node, std::ostream *OS) {
  raw_os_ostream ROS(*OS);
  Node.dump(ROS);
}

class BalancedPartitioningTest : public ::testing::Test {
protected:
  BalancedPartitioningConfig Config;
  BalancedPartitioning Bp;
  BalancedPartitioningTest() : Bp(Config) {}

  static std::vector<BPFunctionNode::IDT>
  getIds(std::vector<BPFunctionNode> Nodes) {
    std::vector<BPFunctionNode::IDT> Ids;
    for (auto &N : Nodes)
      Ids.push_back(N.Id);
    return Ids;
  }
};

/// Five nodes: 0 and 1 both use utility nodes {1, 2}, 2 and 3 both use
/// {3, 4}, and 4 uses only {4}. After run(), every node is expected to sit in
/// the bucket whose number equals its Id.
TEST_F(BalancedPartitioningTest, Basic) {
  // Input order is intentionally kept as-is (it is not sorted by Id).
  std::vector<BPFunctionNode> Nodes = {
      BPFunctionNode(0, {1, 2}),
      BPFunctionNode(2, {3, 4}),
      BPFunctionNode(1, {1, 2}),
      BPFunctionNode(3, {3, 4}),
      BPFunctionNode(4, {4}),
  };

  Bp.run(Nodes);

  // Matches a node carrying ExpectedId whose Bucket equals ExpectedBucket.
  auto HasIdInBucket = [](BPFunctionNode::IDT ExpectedId,
                          std::optional<uint32_t> ExpectedBucket) {
    auto IdMatches = Field("Id", &BPFunctionNode::Id, ExpectedId);
    auto BucketMatches =
        Field("Bucket", &BPFunctionNode::Bucket, ExpectedBucket);
    return AllOf(IdMatches, BucketMatches);
  };

  EXPECT_THAT(Nodes, UnorderedElementsAre(
                         HasIdInBucket(0, 0), HasIdInBucket(1, 1),
                         HasIdInBucket(2, 2), HasIdInBucket(3, 3),
                         HasIdInBucket(4, 4)));
}

/// Stress test: ProblemSize function nodes, each using a randomly sized random
/// subset of ProblemSize utility nodes. After run(), every node must have a
/// bucket and the multiset of node Ids must be unchanged (order may differ).
TEST_F(BalancedPartitioningTest, Large) {
  constexpr int ProblemSize = 1000;

  std::vector<BPFunctionNode::UtilityNodeT> AllUNs;
  AllUNs.reserve(ProblemSize);
  for (int I = 0; I < ProblemSize; ++I)
    AllUNs.emplace_back(I);

  // Default-seeded engine: the generated problem is the same on every run
  // with a given standard library (distribution algorithms are
  // implementation-defined, so it may differ between libraries).
  std::mt19937 RNG;
  // The bounds never change, so build the distribution once instead of once
  // per node. The bound is spelled in int directly to avoid narrowing
  // AllUNs.size() - 1 (a size_t) into the int parameter.
  std::uniform_int_distribution<int> SampleSizeDist(0, ProblemSize - 1);

  std::vector<BPFunctionNode> Nodes;
  Nodes.reserve(ProblemSize);
  for (int I = 0; I < ProblemSize; ++I) {
    const int SampleSize = SampleSizeDist(RNG);
    std::vector<BPFunctionNode::UtilityNodeT> UNs;
    UNs.reserve(SampleSize);
    std::sample(AllUNs.begin(), AllUNs.end(), std::back_inserter(UNs),
                SampleSize, RNG);
    Nodes.emplace_back(I, UNs);
  }

  const auto OrigIds = getIds(Nodes);

  Bp.run(Nodes);

  // Every node must have been assigned some bucket...
  EXPECT_THAT(
      Nodes, Each(Not(Field("Bucket", &BPFunctionNode::Bucket, std::nullopt))));
  // ...and no node may be dropped or duplicated.
  EXPECT_THAT(getIds(Nodes), UnorderedElementsAreArray(OrigIds));
}

/// Checks moveGain() against hand-built signatures for utility nodes 0..2.
TEST_F(BalancedPartitioningTest, MoveGain) {
  // NOTE(review): element layout presumably is
  // {LeftCount, RightCount, CachedGainLR, CachedGainRL, CachedGainIsValid};
  // confirm against BalancedPartitioning.h.
  BalancedPartitioning::SignaturesT Signatures = {
      {10, 10, 10.f, 0.f, true}, // utility node 0
      {10, 10, 0.f, 10.f, true}, // utility node 1
      {10, 10, 0.f, 20.f, true}, // utility node 2
  };

  // A node with no utility nodes is expected to gain nothing.
  const BPFunctionNode NoUtilities(0, {});
  EXPECT_FLOAT_EQ(Bp.moveGain(NoUtilities, true, Signatures), 0.f);

  const BPFunctionNode UsesZeroAndOne(0, {0, 1});
  EXPECT_FLOAT_EQ(Bp.moveGain(UsesZeroAndOne, true, Signatures), 10.f);

  const BPFunctionNode UsesOneAndTwo(0, {1, 2});
  EXPECT_FLOAT_EQ(Bp.moveGain(UsesOneAndTwo, false, Signatures), 30.f);
}

} // end namespace llvm