aboutsummaryrefslogtreecommitdiff
path: root/offload/unittests/Conformance/include/mathtest/HostRefChecker.hpp
blob: 488aefda67ef4ef849da28f1090d84b0a0eca485 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the definition of the HostRefChecker class, which
/// verifies the results of a device computation against a reference
/// implementation on the host.
///
//===----------------------------------------------------------------------===//

#ifndef MATHTEST_HOSTREFCHECKER_HPP
#define MATHTEST_HOSTREFCHECKER_HPP

#include "mathtest/Numerics.hpp"
#include "mathtest/Support.hpp"
#include "mathtest/TestResult.hpp"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Parallel.h"

#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

namespace mathtest {

template <auto Func> class HostRefChecker {
  using FunctionTraits = FunctionTraits<Func>;
  using InTypesTuple = typename FunctionTraits::ArgTypesTuple;

  using FunctionConfig = FunctionConfig<Func>;

  template <typename... Ts>
  using BuffersTupleType = std::tuple<llvm::ArrayRef<Ts>...>;

public:
  using OutType = typename FunctionTraits::ReturnType;

private:
  template <typename... Ts>
  using PartialResultType = TestResult<OutType, Ts...>;

public:
  using ResultType = ApplyTupleTypes_t<InTypesTuple, PartialResultType>;
  using InBuffersTupleType = ApplyTupleTypes_t<InTypesTuple, BuffersTupleType>;

  HostRefChecker() = delete;

  static ResultType check(InBuffersTupleType InBuffersTuple,
                          llvm::ArrayRef<OutType> OutBuffer) noexcept {
    const std::size_t BufferSize = OutBuffer.size();
    std::apply(
        [&](const auto &...InBuffers) {
          assert(
              ((InBuffers.size() == BufferSize) && ...) &&
              "All input buffers must have the same size as the output buffer");
        },
        InBuffersTuple);

    assert((BufferSize != 0) && "Buffer size cannot be zero");

    ResultType Init;

    auto Transform = [&](std::size_t Index) {
      auto CurrentInputsTuple = std::apply(
          [&](const auto &...InBuffers) {
            return std::make_tuple(InBuffers[Index]...);
          },
          InBuffersTuple);

      const OutType Actual = OutBuffer[Index];
      const OutType Expected = std::apply(Func, CurrentInputsTuple);

      const auto UlpDistance = computeUlpDistance(Actual, Expected);
      const bool IsFailure = UlpDistance > FunctionConfig::UlpTolerance;

      return ResultType(UlpDistance, IsFailure,
                        typename ResultType::TestCase(
                            std::move(CurrentInputsTuple), Actual, Expected));
    };

    auto Reduce = [](ResultType A, const ResultType &B) {
      A.accumulate(B);
      return A;
    };

    const auto Indexes = llvm::seq(BufferSize);
    return llvm::parallelTransformReduce(Indexes.begin(), Indexes.end(), Init,
                                         Reduce, Transform);
  }
};
} // namespace mathtest

#endif // MATHTEST_HOSTREFCHECKER_HPP