about summary refs log tree commit diff
path: root/offload/unittests/Conformance/include/mathtest/RangeBasedGenerator.hpp
blob: 5e1e1139aba96b3b244767d5b69b374de563cf2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the definition of the RangeBasedGenerator class, a base
/// class for input generators that operate on a sequence of ranges.
///
//===----------------------------------------------------------------------===//

#ifndef MATHTEST_RANGEBASEDGENERATOR_HPP
#define MATHTEST_RANGEBASEDGENERATOR_HPP

#include "mathtest/IndexedRange.hpp"
#include "mathtest/InputGenerator.hpp"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Parallel.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>

namespace mathtest {

/// CRTP base class for input generators driven by one IndexedRange per input.
///
/// The generator walks a flat index space [0, Size). Each fill() call claims
/// the next contiguous chunk of that space and delegates the computation of
/// the actual input values to the derived class. The calls are made through
/// llvm::parallelFor, so they may run concurrently.
///
/// \tparam Derived  The concrete generator. It must provide a `writeInputs`
///                  member callable as
///                  `writeInputs(BaseFlatIndex, Offset, BufferPtrs)`, where
///                  BufferPtrs is a `std::tuple<InTypes *...>` holding the
///                  data pointers of the buffers passed to fill().
/// \tparam InTypes  The type of each input; at least one is required.
template <typename Derived, typename... InTypes>
class [[nodiscard]] RangeBasedGenerator : public InputGenerator<InTypes...> {
public:
  /// Rewinds the generator so the next fill() starts again at flat index 0.
  void reset() noexcept override {
    NextFlatIndex = 0;
  }

  /// Writes the next batch of inputs into \p Buffers.
  ///
  /// Every buffer must be non-empty, and all buffers must have the same
  /// length; both are checked by assertions only. The batch covers the flat
  /// indices [NextFlatIndex, NextFlatIndex + N), where N is the smaller of the
  /// buffer length and the number of flat indices not yet visited.
  ///
  /// \returns N, or 0 once the whole flat index space has been consumed.
  [[nodiscard]] std::size_t
  fill(llvm::MutableArrayRef<InTypes>... Buffers) noexcept override {
    const std::array<std::size_t, NumInputs> Lengths = {Buffers.size()...};
    const std::size_t Capacity = Lengths.front();

    assert((Capacity != 0) && "Buffer size cannot be zero");
    assert(std::all_of(Lengths.begin(), Lengths.end(),
                       [Capacity](std::size_t Length) {
                         return Length == Capacity;
                       }) &&
           "All input buffers must have the same size");

    // The flat index space is exhausted; nothing left to generate.
    if (!(NextFlatIndex < Size))
      return 0;

    const uint64_t BaseFlatIndex = NextFlatIndex;
    const uint64_t Remaining = Size - BaseFlatIndex;
    const uint64_t Count = std::min<uint64_t>(Capacity, Remaining);

    // Claim this chunk up front so the next call resumes right after it.
    NextFlatIndex = BaseFlatIndex + Count;

    auto BufferPtrs = std::make_tuple(Buffers.data()...);
    Derived &Self = static_cast<Derived &>(*this);

    // One writeInputs call per slot; Offset is relative to BaseFlatIndex.
    llvm::parallelFor(0, Count, [&](std::size_t Offset) {
      Self.writeInputs(BaseFlatIndex, Offset, BufferPtrs);
    });

    return static_cast<std::size_t>(Count);
  }

protected:
  using RangesTupleType = std::tuple<IndexedRange<InTypes>...>;

  /// Number of inputs each generated sample consists of.
  static constexpr std::size_t NumInputs = sizeof...(InTypes);
  static_assert(NumInputs > 0, "The number of inputs must be at least 1");

  /// Builds a generator over \p Ranges whose flat index space is
  /// [0, \p FlatSize).
  explicit constexpr RangeBasedGenerator(
      uint64_t FlatSize, const IndexedRange<InTypes> &...Ranges) noexcept
      : RangesTuple(Ranges...), Size(FlatSize) {}

  /// Builds a generator over \p Ranges with an empty flat index space
  /// (Size == 0).
  /// NOTE(review): derived classes using this overload presumably assign
  /// Size themselves; confirm against the callers.
  explicit constexpr RangeBasedGenerator(
      const IndexedRange<InTypes> &...Ranges) noexcept
      : RangeBasedGenerator(uint64_t{0}, Ranges...) {}

  /// One IndexedRange per input, in the same order as InTypes.
  RangesTupleType RangesTuple;
  /// Number of flat indices fill() hands out in total.
  uint64_t Size = 0;

private:
  /// First flat index not yet handed out by fill().
  uint64_t NextFlatIndex = 0;
};
} // namespace mathtest

#endif // MATHTEST_RANGEBASEDGENERATOR_HPP