diff options
| -rw-r--r-- | openmp/runtime/src/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | openmp/runtime/src/i18n/en_US.txt | 5 | ||||
| -rw-r--r-- | openmp/runtime/src/kmp_traits.cpp | 306 | ||||
| -rw-r--r-- | openmp/runtime/src/kmp_traits.h | 436 | ||||
| -rw-r--r-- | openmp/runtime/unittests/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | openmp/runtime/unittests/Traits/CMakeLists.txt | 6 | ||||
| -rw-r--r-- | openmp/runtime/unittests/Traits/MockOMP.cpp | 24 | ||||
| -rw-r--r-- | openmp/runtime/unittests/Traits/TestOMPTraitParser.cpp | 1038 | ||||
| -rw-r--r-- | openmp/runtime/unittests/Traits/TestOMPTraits.cpp | 1132 |
9 files changed, 2949 insertions, 0 deletions
diff --git a/openmp/runtime/src/CMakeLists.txt b/openmp/runtime/src/CMakeLists.txt index 53f83c006b04..424739d6a94b 100644 --- a/openmp/runtime/src/CMakeLists.txt +++ b/openmp/runtime/src/CMakeLists.txt @@ -96,6 +96,7 @@ else() kmp_str.cpp kmp_tasking.cpp kmp_threadprivate.cpp + kmp_traits.cpp kmp_utility.cpp kmp_barrier.cpp kmp_wait_release.cpp diff --git a/openmp/runtime/src/i18n/en_US.txt b/openmp/runtime/src/i18n/en_US.txt index 08e837d3dea1..906cf202a866 100644 --- a/openmp/runtime/src/i18n/en_US.txt +++ b/openmp/runtime/src/i18n/en_US.txt @@ -362,6 +362,11 @@ TopologyGeneric "%1$s: %2$s (%3$d total cores)" AffGranularityBad "%1$s: granularity setting: %2$s does not exist in topology. Using granularity=%3$s instead." TopologyHybrid "%1$s: hybrid core type detected: %2$d %3$s cores." TopologyHybridCoreEff "%1$s: %2$d with core efficiency %3$d." +TraitParserInvalidDeviceKind "trait parser while parsing %1$s: invalid device kind (allowed: host/nohost/cpu/gpu/fpga)" +TraitParserInvalidUID "trait parser while parsing %1$s: invalid uid (%2$s)" +TraitParserMaxRecursion "trait parser while parsing %1$s: max recursion depth (%2$d) exceeded" +TraitParserFailed "trait parser while parsing %1$s: failed to parse trait specification (%2$s)" +TraitParserValueTooLarge "trait parser while parsing %1$s: value %2$d above limit (%3$d)" # --- OpenMP errors detected at runtime --- # diff --git a/openmp/runtime/src/kmp_traits.cpp b/openmp/runtime/src/kmp_traits.cpp new file mode 100644 index 000000000000..9d44ca2affe2 --- /dev/null +++ b/openmp/runtime/src/kmp_traits.cpp @@ -0,0 +1,306 @@ +/* + * kmp_traits.cpp -- Handle OpenMP context traits + * + * OpenMP 6.0 specifies the following trait sets: + * - construct + * - device + * - target device + * - implementation + * - extension + * - dynamic + * Currently, the implementation in this file supports traits from the (target) + * device and implementation trait sets that are relevant for implementing the + * OMP_DEFAULT_DEVICE and OMP_AVAILABLE_DEVICES environment variables. + */ + +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "kmp_traits.h" +#include "kmp_i18n.h" + +using namespace kmp_traits; + +// OpenMP trait grammar (in EBNF), currently used for parsing the +// OMP_DEFAULT_DEVICE/OMP_AVAILABLE_DEVICES environment variables +// +// Notes about the grammar: +// - Device traits are going to be translated into device numbers (aka integers) +// later in the runtime. The parser handles device numbers as device traits that +// have already been translated. +// - "*" is also not a trait, strictly speaking. But it's also supported by the +// parser and converted into a "match any" wildcard trait. +// - OpenMP 6.0 explicitly excludes "&&" and "||" from appearing in the same +// grouping level. +// - This grammar currently only supports plain integers for array subsripts / +// sections, no expressions. +// - TODO: +// - Add support for more traits +// +// TODOs regarding the implementation (not the grammar): +// - Implement array subscript/section parsing +// - Implement grammar TODOs after they have been incorporated into the grammar +// +// list = [clause {',' clause}] +// clause = +// device_number +// | "*" [index_expr] +// | trait_expr_group +// | trait_expr index_expr +// device_number = ["-"] integer0 +// trait_expr_group = +// trait_expr +// | trait_expr {"&&" trait_expr} +// | trait_expr {"||" trait_expr} +// trait_expr = +// trait_expr_single +// | trait_expr_group_paren +// trait_expr_single = ["!"] trait +// trait_expr_group_paren = ["!"] "(" trait_expr_group ")" +// trait = +// "uid" "(" uid_value ")" +// uid_value = (letter | digit0 | symbol) {letter | digit0 | symbol} +// +// index_expr = "[" integer0 "]" | "[" array_section "]" +// array_section = +// lower_bound ":" length ":" stride +// | lower_bound ":" length ":" +// | lower_bound ":" length +// | lower_bound "::" stride +// | lower_bound "::" +// | lower_bound ":" +// | ":" length ":" stride +// | ":" length ":" +// | ":" length +// | "::" stride +// | "::" +// | ":" +// lower_bound = integer0 +// length = integer0 +// stride = integer +// +// integer0 = 0 | integer +// integer = digit {digit0} +// +// letter = +// "A" | "B" | "C" | "D" | "E" | "F" | "G" | "H" | "I" | "J" | "K" | "L" +// | "M" | "N" | "O" | "P" | "Q" | "R" | "S" | "T" | "U" | "V" | "W" | "X" +// | "Y" | "Z" | "a" | "b" | "c" | "d" | "e" | "f" | "g" | "h" | "i" | "j" +// | "k" | "l" | "m" | "n" | "o" | "p" | "q" | "r" | "s" | "t" | "u" | "v" +// | "w" | "x" | "y" | "z" +// digit0 = "0" | digit +// digit = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" +// symbol = "-" | "_" + +namespace parser { + +#define MAX_RECURSION_DEPTH 64 + +using namespace kmp_traits; + +static kmp_str_ref consume_uid_value(kmp_str_ref &scan, const char *dbg_name) { + scan.skip_space(); + kmp_str_ref uid = scan.take_while([](char c) { + return isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '_'; + }); + scan.drop_front(uid.length()); + if (uid.empty() || !scan.consume_front(")")) + KMP_FATAL(TraitParserInvalidUID, dbg_name, uid.copy()); + return uid; +} + +static bool consume_trait(kmp_trait_expr_single &expr, kmp_str_ref &scan, + const char *dbg_name) { + scan.skip_space(); + if (!scan.consume_front("uid(")) + return false; + kmp_str_ref uid = consume_uid_value(scan, dbg_name); + expr.set_trait(new kmp_uid_trait(uid)); + return true; +} + +static bool consume_trait_expr_single(kmp_trait_expr_single &expr, + kmp_str_ref &scan, const char *dbg_name) { + kmp_str_ref orig_scan = scan; + + scan.skip_space(); + if (scan.consume_front("!")) + expr.set_negated(); + if (consume_trait(expr, scan, dbg_name)) + return true; + scan = orig_scan; + return false; +} + +// forward declaration +static bool consume_trait_expr_group(kmp_trait_expr_group &group, + kmp_str_ref &scan, int max_recursion, + const char *dbg_name); + +static bool consume_trait_expr_group_paren(kmp_trait_expr_group &group, + kmp_str_ref &scan, int max_recursion, + const char *dbg_name) { + if (max_recursion-- <= 0) + KMP_FATAL(TraitParserMaxRecursion, dbg_name, MAX_RECURSION_DEPTH); + kmp_str_ref orig_scan = scan; + + scan.skip_space(); + if (scan.consume_front("!")) + group.set_negated(); + + scan.skip_space(); + if (!scan.consume_front("(") || + !consume_trait_expr_group(group, scan, max_recursion, dbg_name)) { + scan = orig_scan; + return false; + } + + scan.skip_space(); + if (!scan.consume_front(")")) { + scan = orig_scan; + return false; + } + return true; +} + +static bool consume_trait_expr(kmp_trait_expr *&expr, kmp_str_ref &scan, + int max_recursion, const char *dbg_name) { + if (max_recursion-- <= 0) + KMP_FATAL(TraitParserMaxRecursion, dbg_name, MAX_RECURSION_DEPTH); + + // Parse a single trait expression + kmp_trait_expr_single *single_expr = new kmp_trait_expr_single(); + if (consume_trait_expr_single(*single_expr, scan, dbg_name)) { + expr = single_expr; + return true; + } + delete single_expr; + + // Parse a parenthesized group trait expression + kmp_trait_expr_group *group_expr = new kmp_trait_expr_group(); + if (consume_trait_expr_group_paren(*group_expr, scan, max_recursion, + dbg_name)) { + expr = group_expr; + return true; + } + delete group_expr; + + return false; +} + +static bool consume_trait_expr_group(kmp_trait_expr_group &group, + kmp_str_ref &scan, int max_recursion, + const char *dbg_name) { + if (max_recursion-- <= 0) + KMP_FATAL(TraitParserMaxRecursion, dbg_name, MAX_RECURSION_DEPTH); + + kmp_trait_expr *expr = nullptr; + if (!consume_trait_expr(expr, scan, max_recursion, dbg_name)) + return false; + + group.add_expr(expr); + const char *op = nullptr; + + scan.skip_space(); + if (scan.consume_front("||")) { + group.set_group_type(kmp_trait_expr_group::OR); + op = "||"; + } else if (scan.consume_front("&&")) { + group.set_group_type(kmp_trait_expr_group::AND); + op = "&&"; + } else { + return true; // single trait expression, no group + } + + // Now that we got an operator, we need at least one more trait expr. + do { + if (!consume_trait_expr(expr, scan, max_recursion, dbg_name)) + return false; + group.add_expr(expr); + scan.skip_space(); + } while (scan.consume_front(op)); + + return true; +} + +static bool consume_clause(kmp_trait_clause &clause, kmp_str_ref &scan, + const char *dbg_name) { + kmp_str_ref orig_scan = scan; + scan.skip_space(); + + // Parse wildcard "trait" + if (scan.consume_front("*")) { + clause.set_expr(new kmp_wildcard_trait()); + return true; + } + + // Parse a literal device number + int value; + if (scan.consume_integer(value)) { + clause.set_expr(new kmp_literal_trait(value)); + return true; + } + + // Parse a trait expression group + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + if (consume_trait_expr_group(*group, scan, MAX_RECURSION_DEPTH, dbg_name)) { + clause.set_expr(group); + return true; + } + delete group; + + scan = orig_scan; + return false; +} + +static bool consume_list(kmp_trait_context &context, kmp_str_ref &scan, + const char *dbg_name) { + kmp_str_ref orig_scan = scan; + scan.skip_space(); + + while (!scan.empty()) { + kmp_trait_clause *clause = new kmp_trait_clause(); + if (!consume_clause(*clause, scan, dbg_name)) { + delete clause; + scan = orig_scan; + return false; + } + context.add_clause(clause); + orig_scan = scan; + + scan.skip_space(); + if (!scan.consume_front(",") && !scan.empty()) { + scan = orig_scan; + return false; + } + } + + return true; +} + +} // namespace parser + +kmp_trait_context *kmp_trait_context::parse_from_spec(kmp_str_ref spec, + const char *dbg_name) { + kmp_trait_context *context = new kmp_trait_context(); + if (!parser::consume_list(*context, spec, dbg_name)) + KMP_FATAL(TraitParserFailed, dbg_name, spec.copy()); + return context; +} + +int kmp_trait_context::parse_single_device(kmp_str_ref spec, + int device_num_limit, + const char *dbg_name) { + int device_num; + spec.skip_space(); + if (!spec.consume_integer(device_num)) + KMP_FATAL(TraitParserFailed, dbg_name, spec.copy()); + if (device_num > device_num_limit) + KMP_FATAL(TraitParserValueTooLarge, dbg_name, device_num, device_num_limit); + return device_num; +} diff --git a/openmp/runtime/src/kmp_traits.h b/openmp/runtime/src/kmp_traits.h new file mode 100644 index 000000000000..3f93fa74f0b7 --- /dev/null +++ b/openmp/runtime/src/kmp_traits.h @@ -0,0 +1,436 @@ +//===----------- Traits.h - OpenMP context traits -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of OpenMP context traits. +// +//===----------------------------------------------------------------------===// + +#ifndef OPENMP_TRAITS_H +#define OPENMP_TRAITS_H + +#include "kmp.h" +#include "kmp_adt.h" + +namespace kmp_traits { + +extern "C" int omp_get_num_devices(); +extern "C" const char *omp_get_uid_from_device(int device_num); + +class kmp_trait { +protected: + enum trait_type { WILDCARD_T, LITERAL_T, UID_T }; + trait_type _type; + + kmp_trait(trait_type type) : _type(type) {} + +public: + virtual ~kmp_trait() = default; + + kmp_trait(const kmp_trait &) = delete; + kmp_trait(kmp_trait &&) = delete; + kmp_trait &operator=(const kmp_trait &) = delete; + kmp_trait &operator=(kmp_trait &&) = delete; + + virtual bool match(int device) const = 0; + + // Use KMP_INTERNAL_MALLOC/KMP_INTERNAL_FREE for memory management. + void *operator new(size_t size) { return KMP_INTERNAL_MALLOC(size); } + void operator delete(void *ptr) { KMP_INTERNAL_FREE(ptr); } + + virtual bool operator==(const kmp_trait &other) const { + return _type == other._type; + } +}; + +/// Represents a wildcard trait that matches any device. +class kmp_wildcard_trait final : public kmp_trait { +public: + kmp_wildcard_trait() : kmp_trait(WILDCARD_T) {} + + bool match([[maybe_unused]] int device) const override { return true; } + + bool operator==(const kmp_trait &other) const override { + return kmp_trait::operator==(other); + } +}; + +/// Represents a specific device number. +class kmp_literal_trait final : public kmp_trait { + int device_num; + +public: + kmp_literal_trait(int device_num) + : kmp_trait(LITERAL_T), device_num(device_num) { + assert(device_num >= 0 && "Device number must be non-negative"); + } + + bool match(int device) const override { return device_num == device; } + + bool operator==(const kmp_trait &other) const override { + return kmp_trait::operator==(other) && + device_num == + static_cast<const kmp_literal_trait &>(other).device_num; + } +}; + +/// Represents a specific UID. +/// UID is deliberately not resolved at construction time since libomptarget +/// might not be initialized yet. This is why we delay calls to +/// omp_get_uid_from_device / omp_get_device_from_uid until the trait is +/// evaluated. +class kmp_uid_trait final : public kmp_trait { + char *uid; + // Can be used by unit tests to mock omp_get_uid_from_device. + const char *(*get_uid_from_device)(int device) = omp_get_uid_from_device; + +public: + kmp_uid_trait(kmp_str_ref uid) : kmp_trait(UID_T), uid(uid.copy()) {} + + ~kmp_uid_trait() override { + if (uid) + KMP_INTERNAL_FREE(uid); + } + + bool match(int device) const override { + const char *device_uid = get_uid_from_device(device); + if (!device_uid || !uid) + return false; + return strcmp(device_uid, uid) == 0; + } + + // For testing purposes only: set the function that returns the UID from a + // device. + void set_uid_from_device(const char *(*uid_from_device)(int)) { + get_uid_from_device = uid_from_device; + } + + bool operator==(const kmp_trait &other) const override { + if (!kmp_trait::operator==(other)) + return false; + const char *other_uid = static_cast<const kmp_uid_trait &>(other).uid; + return uid && other_uid ? strcmp(uid, other_uid) == 0 : uid == other_uid; + } +}; + +/// Abstract class representing either a single trait expression or a collection +/// of trait expressions that are ANDed or ORed together. +class kmp_trait_expr { +protected: + enum expr_type { SINGLE_T, GROUP_T }; + expr_type _type; + // Determines if the expression is negated (true) or not (false). + bool negated = false; + // Can be used by unit tests to mock omp_get_num_devices. + int (*get_num_devices)() = omp_get_num_devices; + + kmp_trait_expr(expr_type type) : _type(type) {} + kmp_trait_expr(expr_type type, bool negated) + : _type(type), negated(negated) {} + + virtual bool match_impl(int device, int num_devices) const = 0; + +public: + virtual ~kmp_trait_expr() = default; + + kmp_trait_expr(const kmp_trait_expr &) = delete; + kmp_trait_expr(kmp_trait_expr &&) = delete; + kmp_trait_expr &operator=(const kmp_trait_expr &) = delete; + kmp_trait_expr &operator=(kmp_trait_expr &&) = delete; + + bool is_negated() const { return negated; } + + // Check if the device matches the expression. + bool match(int device, int num_devices = -1) const { + if (num_devices == -1) + num_devices = get_num_devices(); + if (device < 0 || device >= num_devices) + return false; + return match_impl(device, num_devices); + } + + void set_negated(bool neg = true) { negated = neg; } + + // For testing purposes only: set the function that returns the number of + // devices. + void set_num_devices(int (*num_devices)()) { get_num_devices = num_devices; } + + // Use KMP_INTERNAL_MALLOC/KMP_INTERNAL_FREE for memory management. + void *operator new(size_t size) { return KMP_INTERNAL_MALLOC(size); } + void operator delete(void *ptr) { KMP_INTERNAL_FREE(ptr); } + + virtual bool operator==(const kmp_trait_expr &other) const { + return _type == other._type && negated == other.negated; + } +}; + +/// Represents a single (possibly negated) trait. +class kmp_trait_expr_single final : public kmp_trait_expr { + kmp_trait *trait = nullptr; + +protected: + bool match_impl(int device, [[maybe_unused]] int num_devices) const override { + assert(trait); + bool result = trait->match(device); + return negated ? !result : result; + } + +public: + kmp_trait_expr_single() : kmp_trait_expr(SINGLE_T) {} + kmp_trait_expr_single(bool negated) : kmp_trait_expr(SINGLE_T, negated) {} + kmp_trait_expr_single(kmp_trait *trait) + : kmp_trait_expr(SINGLE_T), trait(trait) { + assert(trait && "kmp_trait_expr_single requires a non-null trait"); + } + ~kmp_trait_expr_single() override { delete trait; } + + void set_trait(kmp_trait *new_trait) { + assert(new_trait); + if (trait) + delete trait; + trait = new_trait; + } + + bool operator==(const kmp_trait_expr &other) const override { + if (!kmp_trait_expr::operator==(other)) + return false; + const kmp_trait_expr_single &other_single = + static_cast<const kmp_trait_expr_single &>(other); + return trait && other_single.trait ? *trait == *other_single.trait + : trait == other_single.trait; + } +}; + +/// Represents a (possibly negated) collection of traits that are either ANDed +/// or ORed together. +class kmp_trait_expr_group final : public kmp_trait_expr { +public: + enum group_type { AND, OR }; + +private: + kmp_vector<kmp_trait_expr *> exprs; + // Determines if all traits have to match (true) or any of them (false). + group_type type = OR; + +protected: + bool match_impl(int device, int num_devices) const override { + size_t matched = 0; + for (const kmp_trait_expr *expr : exprs) { + if (expr->match(device, num_devices)) + matched++; + } + // Note: AND evaluates to true for an empty group. + bool result = type == AND ? matched == exprs.size() : matched > 0; + return negated ? !result : result; + } + +public: + kmp_trait_expr_group() : kmp_trait_expr(GROUP_T) {} + kmp_trait_expr_group(bool negated) : kmp_trait_expr(GROUP_T, negated) {} + ~kmp_trait_expr_group() override { + for (kmp_trait_expr *expr : exprs) + delete expr; + } + + void add_expr(kmp_trait *trait) { + assert(trait); + add_expr(new kmp_trait_expr_single(trait)); + } + void add_expr(kmp_trait_expr *expr) { + assert(expr); + exprs.push_back(expr); + // Propagate get_num_devices to the expression. + expr->set_num_devices(get_num_devices); + } + + group_type get_group_type() const { return type; } + + void set_group_type(group_type new_type) { type = new_type; } + + void set_num_devices(int (*num_devices)()) { + kmp_trait_expr::set_num_devices(num_devices); + for (kmp_trait_expr *expr : exprs) + expr->set_num_devices(num_devices); + } + + bool operator==(const kmp_trait_expr &other) const override { + if (!kmp_trait_expr::operator==(other)) + return false; + const kmp_trait_expr_group &other_group = + static_cast<const kmp_trait_expr_group &>(other); + return exprs.is_set_equal( + other_group.exprs, [](kmp_trait_expr *const &a, + kmp_trait_expr *const &b) { return *a == *b; }); + } +}; + +class kmp_trait_clause final { + kmp_trait_expr *expr = nullptr; + +public: + kmp_trait_clause() = default; + ~kmp_trait_clause() { delete expr; } + + kmp_trait_clause(const kmp_trait_clause &) = delete; + kmp_trait_clause(kmp_trait_clause &&) = delete; + kmp_trait_clause &operator=(const kmp_trait_clause &) = delete; + kmp_trait_clause &operator=(kmp_trait_clause &&) = delete; + + kmp_trait_expr *get_expr() { return expr; } + + bool match(int device, int num_devices = -1) const { + assert(expr); + return expr->match(device, num_devices); + } + + void set_expr(kmp_trait *trait) { + assert(trait); + if (expr) + delete expr; + expr = new kmp_trait_expr_single(trait); + } + void set_expr(kmp_trait_expr *new_expr) { + assert(new_expr); + if (expr) + delete expr; + expr = new_expr; + } + + // Use KMP_INTERNAL_MALLOC/KMP_INTERNAL_FREE for memory management. + void *operator new(size_t size) { return KMP_INTERNAL_MALLOC(size); } + void operator delete(void *ptr) { KMP_INTERNAL_FREE(ptr); } + + bool operator==(const kmp_trait_clause &other) const { + return expr && other.expr ? *expr == *other.expr : expr == other.expr; + } +}; + +} // namespace kmp_traits + +class kmp_trait_context final { + using kmp_trait_clause = kmp_traits::kmp_trait_clause; + using kmp_trait_expr = kmp_traits::kmp_trait_expr; + + kmp_vector<kmp_trait_clause *> clauses; + // List of devices that have been evaluated. + kmp_vector<int> devices; + bool evaluated = false; + // Can be used by unit tests to mock omp_get_num_devices. + int (*get_num_devices)() = kmp_traits::omp_get_num_devices; + + void _evaluate() { + devices.clear(); + for (int d = 0; d < get_num_devices(); ++d) { + if (_match(d)) + devices.push_back(d); + } + evaluated = true; + } + + bool _match(int device) const { + if (device < 0 || device >= get_num_devices()) + return false; + for (kmp_trait_clause *clause : clauses) { + if (clause->match(device)) + return true; + } + return false; + } + +public: + kmp_trait_context() = default; + ~kmp_trait_context() { + for (kmp_trait_clause *clause : clauses) + delete clause; + } + + kmp_trait_context(const kmp_trait_context &) = delete; + kmp_trait_context(kmp_trait_context &&) = delete; + kmp_trait_context &operator=(const kmp_trait_context &) = delete; + kmp_trait_context &operator=(kmp_trait_context &&) = delete; + + // Parse a trait specification from a string. + // If dbg_name is provided, it will be used in error messages to identify the + // source of the trait specification. + static kmp_trait_context *parse_from_spec(kmp_str_ref spec, + const char *dbg_name = nullptr); + + // Parse only a single device number from the spec. + // This is useful for backward compatibility with legacy code. + // If dbg_name is provided, it will be used in error messages to identify the + // source of the device number. + static int parse_single_device(kmp_str_ref spec, int device_num_limit, + const char *dbg_name = nullptr); + + void add_clause(kmp_trait_clause *clause) { + assert(clause); + clauses.push_back(clause); + // Propagate get_num_devices to the clause. + if (kmp_trait_expr *expr = clause->get_expr()) + expr->set_num_devices(get_num_devices); + } + + // Returns the list of devices that match the trait specification represented + // by the context. The list contains devices numbers forming a set and sorted + // in ascending order. + // Note to future developers: if we want to add an option to force + // re-evaluation, we need to consider that the devices vector and thus the + // context iterators are invalidated. + const kmp_vector<int> &evaluate() { + trigger_evaluation(); + return devices; + } + + const kmp_vector<int> &evaluate() const { + assert(evaluated && "kmp_trait_context not evaluated"); + return devices; + } + + // Check if the device matches the trait specification represented by the + // context. + bool match(int device) { return evaluate().contains(device); } + + bool match(int device) const { + assert(evaluated && "kmp_trait_context not evaluated"); + return devices.contains(device); + } + + // For testing purposes only: set the function that returns the number of + // devices. + void set_num_devices(int (*num_devices)()) { + get_num_devices = num_devices; + for (kmp_trait_clause *clause : clauses) { + if (kmp_trait_expr *expr = clause->get_expr()) + expr->set_num_devices(num_devices); + } + } + + // Triggers lazy evaluation if not already evaluated. + void trigger_evaluation() { + if (!evaluated) + _evaluate(); + } + + // Use KMP_INTERNAL_MALLOC/KMP_INTERNAL_FREE for memory management. + void *operator new(size_t size) { return KMP_INTERNAL_MALLOC(size); } + void operator delete(void *ptr) { KMP_INTERNAL_FREE(ptr); } + + bool operator==(const kmp_trait_context &other) const { + auto clause_comp = [](kmp_trait_clause *const &a, + kmp_trait_clause *const &b) { return *a == *b; }; + return clauses.is_set_equal(other.clauses, clause_comp); + } + + // Iterator support (returns the iterators of the devices vector; triggers + // lazy evaluation if not already evaluated and if the context is not const). + const int *begin() { return evaluate().begin(); } + const int *end() { return evaluate().end(); } + const int *begin() const { return evaluate().begin(); } + const int *end() const { return evaluate().end(); } +}; + +#endif // OPENMP_TRAITS_H diff --git a/openmp/runtime/unittests/CMakeLists.txt b/openmp/runtime/unittests/CMakeLists.txt index dada4f9fc65f..64cd31940dca 100644 --- a/openmp/runtime/unittests/CMakeLists.txt +++ b/openmp/runtime/unittests/CMakeLists.txt @@ -88,3 +88,4 @@ add_openmp_testsuite( add_subdirectory(ADT) add_subdirectory(String) +add_subdirectory(Traits) diff --git a/openmp/runtime/unittests/Traits/CMakeLists.txt b/openmp/runtime/unittests/Traits/CMakeLists.txt new file mode 100644 index 000000000000..d53db903c9f3 --- /dev/null +++ b/openmp/runtime/unittests/Traits/CMakeLists.txt @@ -0,0 +1,6 @@ +add_openmp_unittest(TraitsTests + MockOMP.cpp + TestOMPTraits.cpp + TestOMPTraitParser.cpp +) + diff --git a/openmp/runtime/unittests/Traits/MockOMP.cpp b/openmp/runtime/unittests/Traits/MockOMP.cpp new file mode 100644 index 000000000000..2c8623798b7d --- /dev/null +++ b/openmp/runtime/unittests/Traits/MockOMP.cpp @@ -0,0 +1,24 @@ +//===- MockOMP.cpp - Mock OMP functions for testing ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Mock implementations for OMP functions (libomptarget not available in tests). +// +//===----------------------------------------------------------------------===// + +// Mock device count for testing +static int MockNumDevices = 4; +static const char *MockDeviceUIDs[] = {"device-0", "device-1", "device-2", + "device-3"}; + +extern "C" int omp_get_num_devices() { return MockNumDevices; } + +extern "C" const char *omp_get_uid_from_device(int DeviceNum) { + if (DeviceNum >= 0 && DeviceNum < MockNumDevices) + return MockDeviceUIDs[DeviceNum]; + return ""; +} diff --git a/openmp/runtime/unittests/Traits/TestOMPTraitParser.cpp b/openmp/runtime/unittests/Traits/TestOMPTraitParser.cpp new file mode 100644 index 000000000000..7d78ca0d202c --- /dev/null +++ b/openmp/runtime/unittests/Traits/TestOMPTraitParser.cpp @@ -0,0 +1,1038 @@ +//===- TestOMPTraitParser.cpp - Tests for OMP Trait Parser ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "kmp_traits.h" +#include "gtest/gtest.h" + +namespace { + +//===----------------------------------------------------------------------===// +// Helper to parse and auto-cleanup +//===----------------------------------------------------------------------===// + +class ParserTest : public ::testing::Test { +protected: + kmp_trait_context *context = nullptr; + + void parse(const char *spec, const char *dbg_name = nullptr) { + context = kmp_trait_context::parse_from_spec(kmp_str_ref(spec), dbg_name); + } + + void TearDown() override { + if (context) { + delete context; + context = nullptr; + } + } +}; + +template <bool expected_result> +static void check_result_single(kmp_trait_context *context, + const kmp_vector<int> &result, + int expected_device_num) { + EXPECT_EQ(context->match(expected_device_num), expected_result); + EXPECT_EQ(result.contains(expected_device_num), expected_result); +} + +template <bool expected_result, int... device_nums> +static void check_result(kmp_trait_context *context, + const kmp_vector<int> &result) { + (check_result_single<expected_result>(context, result, device_nums), ...); +} + +template <bool expected_result, int... device_nums> +static void check_result(kmp_trait_context *context) { + const kmp_vector<int> &result = context->evaluate(); + check_result<expected_result, device_nums...>(context, result); +} + +//===----------------------------------------------------------------------===// +// Literal Device Numbers +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, SingleLiteral) { + parse("5"); + + ASSERT_NE(context, nullptr); + // Device 5 is out of range (mock has 4 devices: 0-3), so match returns false + + EXPECT_EQ(context->evaluate().size(), 0u); + check_result<false, 5, 0, 4, 6>(context); +} + +TEST_F(ParserTest, ZeroLiteral) { + parse("0"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, MultipleLiterals) { + parse("1,2,3"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 1, 2, 3>(context, result); + check_result<false, 0, 4>(context, result); +} + +TEST_F(ParserTest, LiteralsWithSpaces) { + parse("1, 2, 3"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 1, 2, 3>(context, result); + check_result<false, 0, 4>(context, result); +} + +TEST_F(ParserTest, LiteralsWithLeadingSpaces) { + parse(" 1, 2, 3"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 1, 2, 3>(context, result); + check_result<false, 0, 4>(context, result); +} + +TEST_F(ParserTest, LargeLiteral) { + parse("12345"); + + ASSERT_NE(context, nullptr); + // Device 12345 is out of range, so match returns false + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 0u); + check_result<false, 12345, 0>(context, result); +} + +//===----------------------------------------------------------------------===// +// Wildcard +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, Wildcard) { + parse("*"); + + ASSERT_NE(context, nullptr); + // Wildcard matches all 4 mock devices + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); + check_result<false, 100>(context, result); +} + +TEST_F(ParserTest, WildcardWithLiterals) { + parse("1, *, 3"); + + ASSERT_NE(context, nullptr); + // Wildcard makes all in-range devices match + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); + check_result<false, 100>(context, result); +} + +//===----------------------------------------------------------------------===// +// UID Traits +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, UIDTrait) { + parse("uid(device-0)"); + + ASSERT_NE(context, nullptr); + // Uses mock: device-0 is at index 0 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, UIDTraitWithUnderscore) { + parse("uid(my_device_123)"); + + ASSERT_NE(context, nullptr); + // This UID doesn't match any mock device + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 0u); + check_result<false, 0, 1>(context, result); +} + +TEST_F(ParserTest, UIDTraitWithDash) { + parse("uid(device-2)"); + + ASSERT_NE(context, nullptr); + // Uses mock: device-2 is at index 2 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 2>(context, result); + check_result<false, 0, 1, 3>(context, result); +} + +TEST_F(ParserTest, MultipleUIDTraits) { + parse("uid(device-1), uid(device-3)"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 1, 3>(context, result); + check_result<false, 0, 2>(context, result); +} + +TEST_F(ParserTest, MixedLiteralsAndUIDs) { + parse("0, uid(device-2), 1, uid(device-3)"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Negation +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, NegatedUID) { + parse("!uid(device-0)"); + + ASSERT_NE(context, nullptr); + // Negated: everything except device-0 matches + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<false, 0>(context, result); + check_result<true, 1, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Grouping with Parentheses +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, SimpleGroup) { + parse("(uid(device-1))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<false, 0>(context, result); + check_result<true, 1>(context, result); +} + +TEST_F(ParserTest, GroupWithOR) { + parse("(uid(device-0) || uid(device-2))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 2>(context, result); + check_result<false, 1, 3>(context, result); +} + +TEST_F(ParserTest, GroupWithAND) { + parse("(uid(device-0) && uid(device-0))"); + + ASSERT_NE(context, nullptr); + // Both refer to same device, so AND passes for device 0 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, NegatedGroup) { + parse("!(uid(device-0) || uid(device-1))"); + + ASSERT_NE(context, nullptr); + // Negated: matches devices NOT in {0, 1} + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<false, 0, 1>(context, result); + check_result<true, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Complex Expressions +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, ComplexMixed) { + parse("0, 1, uid(device-2), *"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); + check_result<false, 100>(context, result); +} + +TEST_F(ParserTest, MultipleORGroups) { + parse("(uid(device-0) || uid(device-1)), (uid(device-2) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Complex Boolean Operators +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, ThreeWayOR) { + // Three UIDs combined with OR + parse("(uid(device-0) || uid(device-1) || uid(device-2))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 1, 2>(context, result); + check_result<false, 3>(context, result); +} + +TEST_F(ParserTest, FourWayOR) { + // All four mock devices via OR + parse("(uid(device-0) || uid(device-1) || uid(device-2) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ThreeWayAND) { + // Three identical UIDs with AND - all must match same device + parse("(uid(device-1) && uid(device-1) && uid(device-1))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 1>(context, result); + check_result<false, 0, 2, 3>(context, result); +} + +TEST_F(ParserTest, ANDWithDifferentUIDs) { + // AND with different UIDs - can never match (device can't have two UIDs) + parse("(uid(device-0) && uid(device-1))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 0u); + check_result<false, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedThreeWayOR) { + // Negate a group of three UIDs - matches devices NOT in {0, 1, 2} + parse("!(uid(device-0) || uid(device-1) || uid(device-2))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<false, 0, 1, 2>(context, result); + check_result<true, 3>(context, result); +} + +TEST_F(ParserTest, NegatedAND) { + // Negate an AND group - since AND never matches, negation matches all + parse("!(uid(device-0) && uid(device-1))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedANDWithSameUID) { + // Negate an AND that matches device-0 - matches everything except 0 + parse("!(uid(device-0) && uid(device-0))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<false, 0>(context, result); + check_result<true, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NestedParensWithOR) { + // Nested parentheses around OR + parse("((uid(device-0) || uid(device-1)))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 1>(context, result); + check_result<false, 2, 3>(context, result); +} + +TEST_F(ParserTest, NestedParensWithAND) { + // Nested parentheses around AND + parse("((uid(device-2) && uid(device-2)))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 2>(context, result); + check_result<false, 0, 1, 3>(context, result); +} + +TEST_F(ParserTest, DoubleNegation) { + // Double negation: !!uid(device-0) should match device-0 + parse("!(!uid(device-0))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedNestedOR) { + // Negate nested OR group + parse("!((uid(device-0) || uid(device-1)))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<false, 0, 1>(context, result); + check_result<true, 2, 3>(context, result); +} + +TEST_F(ParserTest, MultipleNegatedExprs) { + // Multiple negated clauses - OR semantics between clauses + parse("!uid(device-0), !uid(device-1)"); + + ASSERT_NE(context, nullptr); + // First clause matches 1,2,3; Second clause matches 0,2,3 + // OR between clauses: union = all devices + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, MixedNegatedAndNonNegated) { + // Mix negated and non-negated clauses + parse("uid(device-0), !uid(device-0)"); + + ASSERT_NE(context, nullptr); + // First matches 0, second matches 1,2,3 -> union = all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ComplexORGroupsInSeparateExprs) { + // Complex OR groups as separate clauses + parse("(uid(device-0) || uid(device-1)), (uid(device-2) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + // First matches 0,1; Second matches 2,3 -> union = all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedORGroupWithLiteral) { + // Negated OR group combined with literal in separate clauses + parse("!(uid(device-0) || uid(device-1)), 0"); + + ASSERT_NE(context, nullptr); + // First matches 2,3; Second matches 0 -> union = 0,2,3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 2, 3>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, DeeplyNestedWithOperators) { + // Deeply nested with operators + parse("(((uid(device-0) || uid(device-1))))"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 1>(context, result); + check_result<false, 2, 3>(context, result); +} + +TEST_F(ParserTest, ORWithSpacesAroundOperators) { + // OR with lots of whitespace + parse("( uid(device-0) || uid(device-2) || uid(device-3) )"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 2, 3>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, ANDWithSpacesAroundOperators) { + // AND with lots of whitespace + parse("( uid(device-1) && uid(device-1) )"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 1>(context, result); + check_result<false, 0, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Mixed && and || (in separate clauses/groups) +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, ORExprAndANDExpr) { + // OR group in first clause, AND group in second clause + parse("(uid(device-0) || uid(device-1)), (uid(device-2) && uid(device-2))"); + + ASSERT_NE(context, nullptr); + // First clause matches 0,1; Second clause matches 2 -> union = 0,1,2 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 1, 2>(context, result); + check_result<false, 3>(context, result); +} + +TEST_F(ParserTest, ANDExprAndORExpr) { + // AND group first, OR group second + parse("(uid(device-0) && uid(device-0)), (uid(device-2) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + // First clause matches 0; Second clause matches 2,3 -> union = 0,2,3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 2, 3>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, MultipleANDAndORExprs) { + // Multiple clauses alternating between AND and OR + parse("(uid(device-0) && uid(device-0)), (uid(device-1) || uid(device-2)), " + "(uid(device-3) && uid(device-3))"); + + ASSERT_NE(context, nullptr); + // Expr 1 matches 0; Expr 2 matches 1,2; Expr 3 matches 3 -> all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedORWithAND) { + // Negated OR clause combined with AND clause + parse("!(uid(device-0) || uid(device-1)), (uid(device-0) && uid(device-0))"); + + ASSERT_NE(context, nullptr); + // First clause matches 2,3; Second clause matches 0 -> union = 0,2,3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 2, 3>(context, result); + check_result<false, 1>(context, result); +} + +TEST_F(ParserTest, NegatedANDWithOR) { + // Negated AND clause combined with OR clause + parse("!(uid(device-0) && uid(device-0)), (uid(device-0) || uid(device-1))"); + + ASSERT_NE(context, nullptr); + // First clause matches 1,2,3; Second clause matches 0,1 -> all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ComplexMixedOperators) { + // Complex mix: OR, AND, negated OR, literal + parse("(uid(device-0) || uid(device-1)), (uid(device-2) && uid(device-2)), " + "!(uid(device-0) || uid(device-1) || uid(device-2)), 0"); + + ASSERT_NE(context, nullptr); + // Expr 1: 0,1; Expr 2: 2; Expr 3: NOT(0,1,2) = 3; Expr 4: 0 + // Union = all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ANDNeverMatchesWithOR) { + // AND that never matches combined with OR that does + parse("(uid(device-0) && uid(device-1)), (uid(device-2) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + // First clause: never matches (different UIDs); Second: 2,3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<false, 0, 1>(context, result); + check_result<true, 2, 3>(context, result); +} + +TEST_F(ParserTest, ORNeverMatchesWithAND) { + // OR with non-existent UIDs combined with AND that matches + parse("(uid(nonexistent-a) || uid(nonexistent-b)), (uid(device-0) && " + "uid(device-0))"); + + ASSERT_NE(context, nullptr); + // First clause: no match; Second: 0 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ThreeWayORAndThreeWayAND) { + // Three-way OR and three-way AND in separate clauses + parse("(uid(device-0) || uid(device-1) || uid(device-2)), (uid(device-3) && " + "uid(device-3) && uid(device-3))"); + + ASSERT_NE(context, nullptr); + // First: 0,1,2; Second: 3 -> all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedMixedExprs) { + // Both clauses negated with different operators + parse("!(uid(device-0) || uid(device-1)), !(uid(device-2) && uid(device-2))"); + + ASSERT_NE(context, nullptr); + // First: NOT(0,1) = 2,3; Second: NOT(2) = 0,1,3 + // Union = all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, LiteralsWithMixedOperatorExprs) { + // Literals combined with both OR and AND clauses + parse("0, (uid(device-1) || uid(device-2)), 3, (uid(device-0) && " + "uid(device-0))"); + + ASSERT_NE(context, nullptr); + // Literals: 0,3; OR clause: 1,2; AND clause: 0 + // Union = all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 4u); + check_result<true, 0, 1, 2, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Nested Mixed Operators (|| and && at different nesting levels) +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, ORContainingANDGroup) { + // Outer OR with inner AND group: (A || (B && C)) + // For (B && C) to match, both B and C must match same device + parse("(uid(device-0) || (uid(device-1) && uid(device-1)))"); + + ASSERT_NE(context, nullptr); + // device-0 matches via first operand + // device-1 matches via (device-1 && device-1) + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 1>(context, result); + check_result<false, 2, 3>(context, result); +} + +TEST_F(ParserTest, ANDContainingORGroup) { + // Outer AND with inner OR group: (A && (B || C)) + // Both the trait A and the group (B || C) must match + // Since A is uid(device-0), only device-0 can satisfy A + // (B || C) must also match device-0 for AND to succeed + parse("(uid(device-0) && (uid(device-0) || uid(device-1)))"); + + ASSERT_NE(context, nullptr); + // device-0: uid(device-0) matches AND (uid(device-0) || uid(device-1)) + // matches -> true device-1: uid(device-0) doesn't match -> false + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ORWithTwoANDGroups) { + // ((A && B) || (C && D)) - OR of two AND groups + parse( + "((uid(device-0) && uid(device-0)) || (uid(device-2) && uid(device-2)))"); + + ASSERT_NE(context, nullptr); + // First AND matches device-0; Second AND matches device-2 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 2>(context, result); + check_result<false, 1, 3>(context, result); +} + +TEST_F(ParserTest, ANDWithTwoORGroups) { + // ((A || B) && (C || D)) - AND of two OR groups + // For a device to match: must match (A || B) AND must match (C || D) + parse( + "((uid(device-0) || uid(device-1)) && (uid(device-0) || uid(device-2)))"); + + ASSERT_NE(context, nullptr); + // device-0: matches (0||1) AND matches (0||2) -> true + // device-1: matches (0||1) but NOT (0||2) -> false + // device-2: NOT (0||1) -> false + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ORWithNestedANDContainingOR) { + // (A || (B && (C || D))) - three levels of nesting + parse( + "(uid(device-3) || (uid(device-0) && (uid(device-0) || uid(device-1))))"); + + ASSERT_NE(context, nullptr); + // device-0: inner (0||1) matches, uid(device-0) matches -> AND matches; OR + // satisfied device-1: inner (0||1) matches, but uid(device-0) doesn't -> AND + // fails; outer uid(device-3) fails device-3: outer uid(device-3) matches + // directly + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 3>(context, result); + check_result<false, 1, 2>(context, result); +} + +TEST_F(ParserTest, ANDWithNestedORContainingAND) { + // (A && (B || (C && D))) - three levels of nesting + parse( + "(uid(device-0) && (uid(device-0) || (uid(device-1) && uid(device-1))))"); + + ASSERT_NE(context, nullptr); + // device-0: uid(device-0) matches; inner (uid(device-0) || ...) matches -> + // AND satisfied device-1: uid(device-0) doesn't match -> AND fails + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedORContainingAND) { + // !(A || (B && C)) - negated complex expression + parse("!(uid(device-0) || (uid(device-1) && uid(device-1)))"); + + ASSERT_NE(context, nullptr); + // Without negation: matches 0, 1 + // With negation: matches 2, 3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<false, 0, 1>(context, result); + check_result<true, 2, 3>(context, result); +} + +TEST_F(ParserTest, NegatedANDContainingOR) { + // !(A && (B || C)) - negated complex expression + parse("!(uid(device-0) && (uid(device-0) || uid(device-1)))"); + + ASSERT_NE(context, nullptr); + // Without negation: matches only 0 + // With negation: matches 1, 2, 3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<false, 0>(context, result); + check_result<true, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ComplexNestedWithAllDevices) { + // ((A || B) && (C || D)) where union covers all but AND restricts + parse( + "((uid(device-0) || uid(device-1)) && (uid(device-1) || uid(device-2)))"); + + ASSERT_NE(context, nullptr); + // device-0: (0||1)=true, (1||2)=false -> AND=false + // device-1: (0||1)=true, (1||2)=true -> AND=true + // device-2: (0||1)=false -> AND=false + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 1>(context, result); + check_result<false, 0, 2, 3>(context, result); +} + +TEST_F(ParserTest, TripleNestedMixedOperators) { + // (((A || B) && C) || D) - deeply nested with alternating operators + parse( + "(((uid(device-0) || uid(device-1)) && uid(device-0)) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + // Inner (0||1): matches 0, 1 + // Middle ((0||1) && 0): matches only 0 + // Outer (... || 3): matches 0, 3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 2u); + check_result<true, 0, 3>(context, result); + check_result<false, 1, 2>(context, result); +} + +TEST_F(ParserTest, ANDChainWithNestedOR) { + // (A && (B || C) && D) - wait, this mixes operators at same level + // Actually: ((A && (B || C)) is valid - let's do that + // Let's do: (uid(device-0) && (uid(device-0) || uid(device-1)) && + // uid(device-0)) This is three-way AND where middle operand is an OR group + parse("(uid(device-0) && (uid(device-0) || uid(device-1)) && uid(device-0))"); + + ASSERT_NE(context, nullptr); + // All three must match: uid(device-0), (0||1), uid(device-0) + // Only device-0 satisfies all + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 1u); + check_result<true, 0>(context, result); + check_result<false, 1, 2, 3>(context, result); +} + +TEST_F(ParserTest, ORChainWithNestedAND) { + // (A || (B && C) || D) - three-way OR where middle is AND group + parse("(uid(device-0) || (uid(device-1) && uid(device-1)) || uid(device-3))"); + + ASSERT_NE(context, nullptr); + // Any of: device-0, (device-1 && device-1), device-3 + // Matches: 0, 1, 3 + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 1, 3>(context, result); + check_result<false, 2>(context, result); +} + +TEST_F(ParserTest, NestedMixedWithSpaces) { + // Nested mixed operators with lots of whitespace + parse("( uid(device-0) || ( uid(device-1) && uid(device-1) ) || " + "uid(device-2) )"); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 3u); + check_result<true, 0, 1, 2>(context, result); + check_result<false, 3>(context, result); +} + +//===----------------------------------------------------------------------===// +// Empty Input +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, EmptyString) { + parse(""); + + ASSERT_NE(context, nullptr); + // Empty context matches nothing + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 0u); + check_result<false, 0, 1>(context, result); +} + +TEST_F(ParserTest, OnlyWhitespace) { + parse(" "); + + ASSERT_NE(context, nullptr); + kmp_vector<int> result = context->evaluate(); + + EXPECT_EQ(result.size(), 0u); + check_result<false, 0>(context, result); +} + +//===----------------------------------------------------------------------===// +// Error Cases +//===----------------------------------------------------------------------===// + +TEST_F(ParserTest, OnlyComma) { + ASSERT_DEATH( + parse(",", "test_only_comma"), + "OMP: Error #[0-9]+: trait parser while parsing test_only_comma: " + "failed to parse trait specification \\(,\\)"); +} + +TEST_F(ParserTest, OnlyCommaNullDbgName) { + ASSERT_DEATH(parse(","), + "OMP: Error #[0-9]+: trait parser while parsing \\(null\\): " + "failed to parse trait specification \\(,\\)"); +} + +TEST_F(ParserTest, MixedAndOrSameLevel) { + // OpenMP 6.0 explicitly excludes "&&" and "||" from appearing in the same + // grouping level. + ASSERT_DEATH( + parse("uid(a) && uid(b) || uid(c)", "mixed_and_or_same_level"), + "OMP: Error #[0-9]+: trait parser while parsing mixed_and_or_same_level: " + "failed to parse trait specification " + "\\(\\|\\| uid\\(c\\)\\)"); +} + +TEST_F(ParserTest, MixedOrAndSameLevel) { + ASSERT_DEATH( + parse("uid(a) || uid(b) && uid(c)", "mixed_or_and_same_level"), + "OMP: Error #[0-9]+: trait parser while parsing mixed_or_and_same_level: " + "failed to parse trait specification " + "\\(&& uid\\(c\\)\\)"); +} + +TEST_F(ParserTest, InvalidUID) { + // Empty UID is not allowed + ASSERT_DEATH(parse("uid()", "invalid_uid"), + "OMP: Error #[0-9]+: trait parser while parsing invalid_uid: " + "invalid uid \\(\\)"); +} + +TEST_F(ParserTest, UnclosedParenthesis) { + ASSERT_DEATH( + parse("(uid(a)", "unclosed_parenthesis"), + "OMP: Error #[0-9]+: trait parser while parsing unclosed_parenthesis: " + "failed to parse trait specification \\(\\(uid\\(a\\)\\)"); +} + +TEST_F(ParserTest, UnmatchedClosingParenthesis) { + ASSERT_DEATH(parse("uid(a))", "unmatched_closing_parenthesis"), + "OMP: Error #[0-9]+: trait parser while parsing " + "unmatched_closing_parenthesis: " + "failed to parse trait specification \\(\\)\\)"); +} + +TEST_F(ParserTest, EmptyParentheses) { + ASSERT_DEATH( + parse("()", "empty_parentheses"), + "OMP: Error #[0-9]+: trait parser while parsing empty_parentheses: " + "failed to parse trait specification \\(\\(\\)\\)"); +} + +TEST_F(ParserTest, TrailingOperator) { + ASSERT_DEATH( + parse("uid(a) &&", "trailing_operator"), + "OMP: Error #[0-9]+: trait parser while parsing trailing_operator: " + "failed to parse trait specification \\(uid\\(a\\) &&\\)"); +} + +TEST_F(ParserTest, LeadingOperator) { + ASSERT_DEATH( + parse("&& uid(a)", "leading_operator"), + "OMP: Error #[0-9]+: trait parser while parsing leading_operator: " + "failed to parse trait specification \\(&& uid\\(a\\)\\)"); +} + +TEST_F(ParserTest, DoubleComma) { + ASSERT_DEATH(parse("uid(a),,uid(b)", "double_comma"), + "OMP: Error #[0-9]+: trait parser while parsing double_comma: " + "failed to parse trait specification \\(,,uid\\(b\\)\\)"); +} + +//===----------------------------------------------------------------------===// +// parse_single_device Tests +//===----------------------------------------------------------------------===// + +TEST(ParseSingleDeviceTest, ValidSingleDigit) { + int result = kmp_trait_context::parse_single_device(kmp_str_ref("5"), 10); + EXPECT_EQ(result, 5); +} + +TEST(ParseSingleDeviceTest, ValidMultiDigit) { + int result = kmp_trait_context::parse_single_device(kmp_str_ref("123"), 200); + EXPECT_EQ(result, 123); +} + +TEST(ParseSingleDeviceTest, Zero) { + int result = kmp_trait_context::parse_single_device(kmp_str_ref("0"), 10); + EXPECT_EQ(result, 0); +} + +TEST(ParseSingleDeviceTest, AtLimit) { + int result = kmp_trait_context::parse_single_device(kmp_str_ref("10"), 10); + EXPECT_EQ(result, 10); +} + +TEST(ParseSingleDeviceTest, AboveLimit) { + ASSERT_DEATH(kmp_trait_context::parse_single_device(kmp_str_ref("11"), 10, + "above_limit"), + "OMP: Error #[0-9]+: trait parser while parsing above_limit: " + "value 11 above limit \\(10\\)"); +} + +TEST(ParseSingleDeviceTest, NonInteger) { + ASSERT_DEATH(kmp_trait_context::parse_single_device(kmp_str_ref("abc"), 10, + "non_integer"), + "OMP: Error #[0-9]+: trait parser while parsing non_integer: " + "failed to parse trait specification \\(abc\\)"); +} + +TEST(ParseSingleDeviceTest, EmptyString) { + ASSERT_DEATH( + kmp_trait_context::parse_single_device(kmp_str_ref(""), 10, "empty"), + "OMP: Error #[0-9]+: trait parser while parsing empty: " + "failed to parse trait specification \\(\\)"); +} + +TEST(ParseSingleDeviceTest, LeadingSpaces) { + // consume_integer skips leading spaces + int result = kmp_trait_context::parse_single_device(kmp_str_ref(" 7"), 10); + EXPECT_EQ(result, 7); +} + +TEST(ParseSingleDeviceTest, LargeNumber) { + int result = + kmp_trait_context::parse_single_device(kmp_str_ref("999999"), 1000000); + EXPECT_EQ(result, 999999); +} + +} // namespace diff --git a/openmp/runtime/unittests/Traits/TestOMPTraits.cpp b/openmp/runtime/unittests/Traits/TestOMPTraits.cpp new file mode 100644 index 000000000000..9af10db803cc --- /dev/null +++ b/openmp/runtime/unittests/Traits/TestOMPTraits.cpp @@ -0,0 +1,1132 @@ +//===- TestOMPTraits.cpp - Tests for OMP Trait classes -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "kmp_traits.h" +#include "gtest/gtest.h" + +using namespace kmp_traits; + +namespace { + +//===----------------------------------------------------------------------===// +// kmp_wildcard_trait Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_wildcard_trait_test, MatchesAnyDevice) { + kmp_wildcard_trait *trait = new kmp_wildcard_trait(); + + EXPECT_TRUE(trait->match(0)); + EXPECT_TRUE(trait->match(1)); + EXPECT_TRUE(trait->match(100)); + EXPECT_TRUE(trait->match(-1)); + + delete trait; +} + +TEST(kmp_wildcard_trait_test, Equality) { + kmp_wildcard_trait *t1 = new kmp_wildcard_trait(); + kmp_wildcard_trait *t2 = new kmp_wildcard_trait(); + + EXPECT_TRUE(*t1 == *t2); + + delete t1; + delete t2; +} + +//===----------------------------------------------------------------------===// +// kmp_literal_trait Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_literal_trait_test, MatchesExactDevice) { + kmp_literal_trait *trait = new kmp_literal_trait(5); + + EXPECT_TRUE(trait->match(5)); + EXPECT_FALSE(trait->match(0)); + EXPECT_FALSE(trait->match(4)); + EXPECT_FALSE(trait->match(6)); + + delete trait; +} + +TEST(kmp_literal_trait_test, MatchesZero) { + kmp_literal_trait *trait = new kmp_literal_trait(0); + + EXPECT_TRUE(trait->match(0)); + EXPECT_FALSE(trait->match(1)); + + delete trait; +} + +#ifndef NDEBUG +TEST(kmp_literal_trait_test, MatchesNegative) { + EXPECT_DEATH(new kmp_literal_trait(-1), "Device number must be non-negative"); +} +#endif + +TEST(kmp_literal_trait_test, EqualitySameValue) { + kmp_literal_trait *t1 = new kmp_literal_trait(42); + kmp_literal_trait *t2 = new kmp_literal_trait(42); + + EXPECT_TRUE(*t1 == *t2); + + delete t1; + delete t2; +} + +TEST(kmp_literal_trait_test, EqualityDifferentValue) { + kmp_literal_trait *t1 = new kmp_literal_trait(1); + kmp_literal_trait *t2 = new kmp_literal_trait(2); + + EXPECT_FALSE(*t1 == *t2); + + delete t1; + delete t2; +} + +//===----------------------------------------------------------------------===// +// kmp_uid_trait Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_uid_trait_test, Construction) { + kmp_uid_trait *trait = new kmp_uid_trait(kmp_str_ref("test-uid")); + + // Just verify it can be constructed without crashing + delete trait; +} + +TEST(kmp_uid_trait_test, MatchWithMock) { + kmp_uid_trait *trait = new kmp_uid_trait(kmp_str_ref("device-0")); + + // Uses the mock omp_get_uid_from_device + EXPECT_TRUE(trait->match(0)); // device-0 matches + EXPECT_FALSE(trait->match(1)); // device-1 doesn't match + EXPECT_FALSE(trait->match(2)); // device-2 doesn't match + + delete trait; +} + +TEST(kmp_uid_trait_test, MatchWithCustomMock) { + kmp_uid_trait *trait = new kmp_uid_trait(kmp_str_ref("custom-uid")); + + // Set a custom mock function + trait->set_uid_from_device([](int device) -> const char * { + return device == 2 ? "custom-uid" : "other"; + }); + + EXPECT_FALSE(trait->match(0)); + EXPECT_FALSE(trait->match(1)); + EXPECT_TRUE(trait->match(2)); // custom-uid matches device 2 + EXPECT_FALSE(trait->match(3)); + + delete trait; +} + +TEST(kmp_uid_trait_test, EqualitySameUID) { + kmp_uid_trait *t1 = new kmp_uid_trait(kmp_str_ref("my-device")); + kmp_uid_trait *t2 = new kmp_uid_trait(kmp_str_ref("my-device")); + + EXPECT_TRUE(*t1 == *t2); + + delete t1; + delete t2; +} + +TEST(kmp_uid_trait_test, EqualityDifferentUID) { + kmp_uid_trait *t1 = new kmp_uid_trait(kmp_str_ref("device-a")); + kmp_uid_trait *t2 = new kmp_uid_trait(kmp_str_ref("device-b")); + + EXPECT_FALSE(*t1 == *t2); + + delete t1; + delete t2; +} + +//===----------------------------------------------------------------------===// +// kmp_trait_expr_single Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_expr_single_test, CreateAndDestroy) { + kmp_trait_expr_single *expr = new kmp_trait_expr_single(); + EXPECT_NE(expr, nullptr); + delete expr; +} + +TEST(kmp_trait_expr_single_test, CreateWithTrait) { + kmp_trait_expr_single *expr = + new kmp_trait_expr_single(new kmp_literal_trait(2)); + + // Mock: 4 devices + expr->set_num_devices([]() { return 4; }); + + EXPECT_TRUE(expr->match(2)); + EXPECT_FALSE(expr->match(0)); + EXPECT_FALSE(expr->match(1)); + EXPECT_FALSE(expr->match(5)); // Out of range + + delete expr; +} + +TEST(kmp_trait_expr_single_test, SetTrait) { + kmp_trait_expr_single *expr = new kmp_trait_expr_single(); + expr->set_trait(new kmp_literal_trait(3)); + + // Mock: 4 devices + expr->set_num_devices([]() { return 4; }); + + EXPECT_TRUE(expr->match(3)); + EXPECT_FALSE(expr->match(0)); + + delete expr; +} + +TEST(kmp_trait_expr_single_test, DefaultNotNegated) { + kmp_trait_expr_single *expr = new kmp_trait_expr_single(); + + EXPECT_FALSE(expr->is_negated()); + + delete expr; +} + +TEST(kmp_trait_expr_single_test, SetNegated) { + kmp_trait_expr_single *expr = new kmp_trait_expr_single(); + + expr->set_negated(true); + EXPECT_TRUE(expr->is_negated()); + + expr->set_negated(false); + EXPECT_FALSE(expr->is_negated()); + + delete expr; +} + +TEST(kmp_trait_expr_single_test, MatchNegated) { + kmp_trait_expr_single *expr = + new kmp_trait_expr_single(new kmp_literal_trait(2)); + expr->set_negated(true); + + // Mock: 4 devices + expr->set_num_devices([]() { return 4; }); + + // Without negation: matches 2 + // With negation: matches everything in-range except 2 + EXPECT_FALSE(expr->match(2)); + EXPECT_TRUE(expr->match(0)); + EXPECT_TRUE(expr->match(1)); + EXPECT_TRUE(expr->match(3)); + // Out of range devices return false regardless of negation + EXPECT_FALSE(expr->match(5)); + + delete expr; +} + +TEST(kmp_trait_expr_single_test, MatchWildcard) { + kmp_trait_expr_single *expr = + new kmp_trait_expr_single(new kmp_wildcard_trait()); + + // Mock: 4 devices + expr->set_num_devices([]() { return 4; }); + + // Wildcard matches any in-range device + EXPECT_TRUE(expr->match(0)); + EXPECT_TRUE(expr->match(3)); + // Out of range devices return false + EXPECT_FALSE(expr->match(100)); + + delete expr; +} + +TEST(kmp_trait_expr_single_test, Equality) { + kmp_trait_expr_single *e1 = + new kmp_trait_expr_single(new kmp_literal_trait(1)); + kmp_trait_expr_single *e2 = + new kmp_trait_expr_single(new kmp_literal_trait(1)); + + EXPECT_TRUE(*e1 == *e2); + + delete e1; + delete e2; +} + +TEST(kmp_trait_expr_single_test, EqualityDifferentTrait) { + kmp_trait_expr_single *e1 = + new kmp_trait_expr_single(new kmp_literal_trait(1)); + kmp_trait_expr_single *e2 = + new kmp_trait_expr_single(new kmp_literal_trait(2)); + + EXPECT_FALSE(*e1 == *e2); + + delete e1; + delete e2; +} + +TEST(kmp_trait_expr_single_test, EqualityDifferentNegation) { + kmp_trait_expr_single *e1 = + new kmp_trait_expr_single(new kmp_literal_trait(1)); + kmp_trait_expr_single *e2 = + new kmp_trait_expr_single(new kmp_literal_trait(1)); + e2->set_negated(true); + + EXPECT_FALSE(*e1 == *e2); + + delete e1; + delete e2; +} + +//===----------------------------------------------------------------------===// +// kmp_trait_expr_group Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_expr_group_test, CreateAndDestroy) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + EXPECT_NE(group, nullptr); + delete group; +} + +TEST(kmp_trait_expr_group_test, DefaultTypeIsOR) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + EXPECT_EQ(group->get_group_type(), kmp_trait_expr_group::OR); + + delete group; +} + +TEST(kmp_trait_expr_group_test, SetTypeAND) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + group->set_group_type(kmp_trait_expr_group::AND); + EXPECT_EQ(group->get_group_type(), kmp_trait_expr_group::AND); + + delete group; +} + +TEST(kmp_trait_expr_group_test, DefaultNotNegated) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + EXPECT_FALSE(group->is_negated()); + + delete group; +} + +TEST(kmp_trait_expr_group_test, SetNegated) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + group->set_negated(true); + EXPECT_TRUE(group->is_negated()); + + group->set_negated(false); + EXPECT_FALSE(group->is_negated()); + + delete group; +} + +TEST(kmp_trait_expr_group_test, AddTraitDirectly) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + group->add_expr(new kmp_wildcard_trait()); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + // Wildcard matches any in-range device + EXPECT_TRUE(group->match(0)); + EXPECT_TRUE(group->match(3)); + // Out of range devices return false + EXPECT_FALSE(group->match(100)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, AddExpr) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + group->add_expr(new kmp_trait_expr_single(new kmp_literal_trait(2))); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + EXPECT_TRUE(group->match(2)); + EXPECT_FALSE(group->match(0)); + EXPECT_FALSE(group->match(5)); // Out of range + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchORSemantics) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::OR); + + group->add_expr(new kmp_literal_trait(1)); + group->add_expr(new kmp_literal_trait(2)); + group->add_expr(new kmp_literal_trait(3)); + + // Mock: 5 devices + group->set_num_devices([]() { return 5; }); + + // OR: matches if ANY trait matches + EXPECT_TRUE(group->match(1)); + EXPECT_TRUE(group->match(2)); + EXPECT_TRUE(group->match(3)); + EXPECT_FALSE(group->match(0)); + EXPECT_FALSE(group->match(4)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchANDSemantics) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::AND); + + // For AND to pass, ALL traits must match the same device + // A single literal only matches one device + group->add_expr(new kmp_literal_trait(2)); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + EXPECT_TRUE(group->match(2)); + EXPECT_FALSE(group->match(0)); + // Out of range + EXPECT_FALSE(group->match(5)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchANDWithWildcard) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::AND); + + group->add_expr(new kmp_wildcard_trait()); + group->add_expr(new kmp_literal_trait(2)); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + // Wildcard matches all, literal matches 2 + // AND: both must match + EXPECT_TRUE(group->match(2)); + EXPECT_FALSE(group->match(0)); + // Out of range + EXPECT_FALSE(group->match(5)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchNegated) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + group->add_expr(new kmp_literal_trait(2)); + group->set_negated(true); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + // Without negation: matches 2 + // With negation: matches everything in-range except 2 + EXPECT_FALSE(group->match(2)); + EXPECT_TRUE(group->match(0)); + EXPECT_TRUE(group->match(1)); + EXPECT_TRUE(group->match(3)); + // Out of range devices return false regardless of negation + EXPECT_FALSE(group->match(5)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchEmptyGroupOR) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::OR); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + // Empty OR: no traits match, so result is false + EXPECT_FALSE(group->match(0)); + EXPECT_FALSE(group->match(1)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, MatchEmptyGroupAND) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::AND); + + // Mock: 4 devices + group->set_num_devices([]() { return 4; }); + + // Empty AND: vacuously true (0 out of 0 traits match) + EXPECT_TRUE(group->match(0)); + EXPECT_TRUE(group->match(1)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, Equality) { + kmp_trait_expr_group *g1 = new kmp_trait_expr_group(); + kmp_trait_expr_group *g2 = new kmp_trait_expr_group(); + + g1->add_expr(new kmp_literal_trait(1)); + g2->add_expr(new kmp_literal_trait(1)); + + EXPECT_TRUE(*g1 == *g2); + + delete g1; + delete g2; +} + +TEST(kmp_trait_expr_group_test, EqualityDifferentNegation) { + kmp_trait_expr_group *g1 = new kmp_trait_expr_group(); + kmp_trait_expr_group *g2 = new kmp_trait_expr_group(); + + g1->add_expr(new kmp_literal_trait(1)); + g2->add_expr(new kmp_literal_trait(1)); + g2->set_negated(true); + + EXPECT_FALSE(*g1 == *g2); + + delete g1; + delete g2; +} + +TEST(kmp_trait_expr_group_test, NestedGroups) { + kmp_trait_expr_group *outer = new kmp_trait_expr_group(); + outer->set_group_type(kmp_trait_expr_group::OR); + + kmp_trait_expr_group *inner = new kmp_trait_expr_group(); + inner->set_group_type(kmp_trait_expr_group::AND); + inner->add_expr(new kmp_literal_trait(1)); + inner->add_expr(new kmp_wildcard_trait()); + + outer->add_expr(inner); + outer->add_expr(new kmp_literal_trait(2)); + + // Mock: 4 devices + outer->set_num_devices([]() { return 4; }); + + // Inner matches device 1 (literal 1 AND wildcard) + // Outer matches 1 OR 2 + EXPECT_TRUE(outer->match(1)); + EXPECT_TRUE(outer->match(2)); + EXPECT_FALSE(outer->match(0)); + EXPECT_FALSE(outer->match(3)); + + delete outer; +} + +//===----------------------------------------------------------------------===// +// kmp_trait_clause Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_clause_test, CreateAndDestroy) { + kmp_trait_clause *clause = new kmp_trait_clause(); + EXPECT_NE(clause, nullptr); + delete clause; +} + +TEST(kmp_trait_clause_test, SetExprWithTrait) { + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(2)); + + // The trait is wrapped in kmp_trait_expr_single internally + kmp_trait_expr *expr = clause->get_expr(); + EXPECT_NE(expr, nullptr); + + delete clause; +} + +TEST(kmp_trait_clause_test, SetExprWithExpr) { + kmp_trait_clause *clause = new kmp_trait_clause(); + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->add_expr(new kmp_literal_trait(1)); + clause->set_expr(group); + + EXPECT_EQ(clause->get_expr(), group); + + delete clause; +} + +TEST(kmp_trait_clause_test, Equality) { + kmp_trait_clause *c1 = new kmp_trait_clause(); + kmp_trait_clause *c2 = new kmp_trait_clause(); + + c1->set_expr(new kmp_literal_trait(1)); + c2->set_expr(new kmp_literal_trait(1)); + + EXPECT_TRUE(*c1 == *c2); + + delete c1; + delete c2; +} + +TEST(kmp_trait_clause_test, EqualityDifferentExprs) { + kmp_trait_clause *c1 = new kmp_trait_clause(); + kmp_trait_clause *c2 = new kmp_trait_clause(); + + c1->set_expr(new kmp_literal_trait(1)); + c2->set_expr(new kmp_literal_trait(2)); + + EXPECT_FALSE(*c1 == *c2); + + delete c1; + delete c2; +} + +//===----------------------------------------------------------------------===// +// kmp_trait_context Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_context_test, CreateAndDestroy) { + kmp_trait_context *context = new kmp_trait_context(); + EXPECT_NE(context, nullptr); + delete context; +} + +TEST(kmp_trait_context_test, AddClause) { + kmp_trait_context *context = new kmp_trait_context(); + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(2)); + context->add_clause(clause); + + // Mock: 4 devices + context->set_num_devices([]() { return 4; }); + + EXPECT_TRUE(context->match(2)); + EXPECT_FALSE(context->match(0)); + // Out of range + EXPECT_FALSE(context->match(5)); + + delete context; +} + +TEST(kmp_trait_context_test, MultipleClauses) { + kmp_trait_context *context = new kmp_trait_context(); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + context->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(2)); + context->add_clause(c2); + + kmp_trait_clause *c3 = new kmp_trait_clause(); + c3->set_expr(new kmp_literal_trait(3)); + context->add_clause(c3); + + // Mock: 5 devices + context->set_num_devices([]() { return 5; }); + + // Context uses OR semantics between clauses + EXPECT_TRUE(context->match(1)); + EXPECT_TRUE(context->match(2)); + EXPECT_TRUE(context->match(3)); + EXPECT_FALSE(context->match(0)); + EXPECT_FALSE(context->match(4)); + + delete context; +} + +TEST(kmp_trait_context_test, EmptyContextMatchesNothing) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 4 devices + context->set_num_devices([]() { return 4; }); + + EXPECT_FALSE(context->match(0)); + EXPECT_FALSE(context->match(1)); + + delete context; +} + +TEST(kmp_trait_context_test, WildcardClause) { + kmp_trait_context *context = new kmp_trait_context(); + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_wildcard_trait()); + context->add_clause(clause); + + // Mock: 4 devices + context->set_num_devices([]() { return 4; }); + + // In-range devices match + EXPECT_TRUE(context->match(0)); + EXPECT_TRUE(context->match(3)); + // Out of range devices return false + EXPECT_FALSE(context->match(100)); + EXPECT_FALSE(context->match(-1)); + + delete context; +} + +TEST(kmp_trait_context_test, EvaluateWithMock) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 5 devices + context->set_num_devices([]() { return 5; }); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + context->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(3)); + context->add_clause(c2); + + kmp_vector<int> result = context->evaluate(); + EXPECT_EQ(result.size(), 2u); + EXPECT_TRUE(result.contains(1)); + EXPECT_TRUE(result.contains(3)); + EXPECT_FALSE(result.contains(0)); + EXPECT_FALSE(result.contains(2)); + EXPECT_FALSE(result.contains(4)); + + delete context; +} + +TEST(kmp_trait_context_test, Equality) { + kmp_trait_context *ctx1 = new kmp_trait_context(); + kmp_trait_context *ctx2 = new kmp_trait_context(); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + ctx1->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(1)); + ctx2->add_clause(c2); + + EXPECT_TRUE(*ctx1 == *ctx2); + + delete ctx1; + delete ctx2; +} + +TEST(kmp_trait_context_test, EqualityDifferentClauses) { + kmp_trait_context *ctx1 = new kmp_trait_context(); + kmp_trait_context *ctx2 = new kmp_trait_context(); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + ctx1->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(2)); + ctx2->add_clause(c2); + + EXPECT_FALSE(*ctx1 == *ctx2); + + delete ctx1; + delete ctx2; +} + +//===----------------------------------------------------------------------===// +// kmp_trait_context Iterator Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_context_test, IteratorRangeBasedFor) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 5 devices + context->set_num_devices([]() { return 5; }); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + context->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(3)); + context->add_clause(c2); + + // Use range-based for loop (should auto-evaluate) + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 2u); + EXPECT_TRUE(collected.contains(1)); + EXPECT_TRUE(collected.contains(3)); + + delete context; +} + +TEST(kmp_trait_context_test, IteratorAutoEvaluates) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 4 devices + context->set_num_devices([]() { return 4; }); + + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_wildcard_trait()); + context->add_clause(clause); + + // Directly use begin()/end() without calling evaluate() first + int count = 0; + for (const int *it = context->begin(); it != context->end(); ++it) { + EXPECT_GE(*it, 0); + EXPECT_LT(*it, 4); + count++; + } + + EXPECT_EQ(count, 4); + + delete context; +} + +TEST(kmp_trait_context_test, IteratorEmptyContext) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 4 devices + context->set_num_devices([]() { return 4; }); + + // Empty context - no clauses added + int count = 0; + for (int d : *context) { + (void)d; + count++; + } + + EXPECT_EQ(count, 0); + EXPECT_EQ(context->begin(), context->end()); + + delete context; +} + +TEST(kmp_trait_context_test, IteratorBeginEnd) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 3 devices + context->set_num_devices([]() { return 3; }); + + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(2)); + context->add_clause(clause); + + // Test begin/end directly + const int *b = context->begin(); + const int *e = context->end(); + + EXPECT_EQ(e - b, 1); // Should have exactly 1 element + EXPECT_EQ(*b, 2); + + delete context; +} + +TEST(kmp_trait_context_test, IteratorMultipleDevices) { + kmp_trait_context *context = new kmp_trait_context(); + + // Add clauses for devices 0, 2, 4 + for (int i = 0; i < 6; i += 2) { + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(i)); + context->add_clause(clause); + } + + // Mock: 6 devices (must be set after adding clauses to propagate to them) + context->set_num_devices([]() { return 6; }); + + // Collect via iterator + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 3u); + EXPECT_TRUE(collected.contains(0)); + EXPECT_TRUE(collected.contains(2)); + EXPECT_TRUE(collected.contains(4)); + + delete context; +} + +TEST(kmp_trait_context_test, IteratorConsistentWithEvaluate) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 5 devices + context->set_num_devices([]() { return 5; }); + + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(1)); + context->add_clause(c1); + + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(4)); + context->add_clause(c2); + + // Get result via evaluate() + const kmp_vector<int> &eval_result = context->evaluate(); + + // Collect via iterator + kmp_vector<int> iter_result; + for (int d : *context) { + iter_result.push_back(d); + } + + // Both should give the same results + EXPECT_EQ(eval_result.size(), iter_result.size()); + for (size_t i = 0; i < eval_result.size(); i++) { + EXPECT_EQ(eval_result[i], iter_result[i]); + } + + delete context; +} + +TEST(kmp_trait_context_test, EvaluateReturnsByReference) { + kmp_trait_context *context = new kmp_trait_context(); + + // Mock: 3 devices + context->set_num_devices([]() { return 3; }); + + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_wildcard_trait()); + context->add_clause(clause); + + // Multiple calls to evaluate() should return reference to the same data + const kmp_vector<int> &result1 = context->evaluate(); + const kmp_vector<int> &result2 = context->evaluate(); + + EXPECT_EQ(&result1, &result2); + + delete context; +} + +//===----------------------------------------------------------------------===// +// get_num_devices Propagation Tests +//===----------------------------------------------------------------------===// + +TEST(kmp_trait_context_test, PropagationToClausesAddedAfterSetNumDevices) { + kmp_trait_context *context = new kmp_trait_context(); + + // Set mock BEFORE adding clauses - propagation should still work + context->set_num_devices([]() { return 6; }); + + // Add clauses for devices 0, 2, 4 (all require 6 devices to be in range) + for (int i = 0; i < 6; i += 2) { + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(i)); + context->add_clause(clause); + } + + // All three devices should match because propagation worked + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 3u); + EXPECT_TRUE(collected.contains(0)); + EXPECT_TRUE(collected.contains(2)); + EXPECT_TRUE(collected.contains(4)); + + delete context; +} + +TEST(kmp_trait_context_test, PropagationToClausesAddedBeforeSetNumDevices) { + kmp_trait_context *context = new kmp_trait_context(); + + // Add clauses BEFORE setting mock + for (int i = 0; i < 6; i += 2) { + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(new kmp_literal_trait(i)); + context->add_clause(clause); + } + + // Set mock AFTER adding clauses + context->set_num_devices([]() { return 6; }); + + // All three devices should match + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 3u); + EXPECT_TRUE(collected.contains(0)); + EXPECT_TRUE(collected.contains(2)); + EXPECT_TRUE(collected.contains(4)); + + delete context; +} + +TEST(kmp_trait_expr_group_test, PropagationToExprsAddedAfterSetNumDevices) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + // Set mock BEFORE adding expressions + group->set_num_devices([]() { return 8; }); + + // Add expressions for devices 5, 6, 7 (require 8 devices) + group->add_expr(new kmp_literal_trait(5)); + group->add_expr(new kmp_literal_trait(6)); + group->add_expr(new kmp_literal_trait(7)); + + // All should match + EXPECT_TRUE(group->match(5)); + EXPECT_TRUE(group->match(6)); + EXPECT_TRUE(group->match(7)); + EXPECT_FALSE(group->match(4)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, PropagationToExprsAddedBeforeSetNumDevices) { + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + + // Add expressions BEFORE setting mock + group->add_expr(new kmp_literal_trait(5)); + group->add_expr(new kmp_literal_trait(6)); + group->add_expr(new kmp_literal_trait(7)); + + // Set mock AFTER adding expressions + group->set_num_devices([]() { return 8; }); + + // All should match + EXPECT_TRUE(group->match(5)); + EXPECT_TRUE(group->match(6)); + EXPECT_TRUE(group->match(7)); + EXPECT_FALSE(group->match(4)); + + delete group; +} + +TEST(kmp_trait_expr_group_test, PropagationToNestedGroups) { + kmp_trait_expr_group *outer = new kmp_trait_expr_group(); + outer->set_group_type(kmp_trait_expr_group::OR); + + // Set mock on outer group FIRST + outer->set_num_devices([]() { return 10; }); + + // Create inner group and add high-numbered devices + kmp_trait_expr_group *inner = new kmp_trait_expr_group(); + inner->set_group_type(kmp_trait_expr_group::OR); + inner->add_expr(new kmp_literal_trait(8)); + inner->add_expr(new kmp_literal_trait(9)); + + // Add inner to outer - should propagate mock to inner and its children + outer->add_expr(inner); + + // Add another expression directly to outer + outer->add_expr(new kmp_literal_trait(7)); + + // All should match with 10 devices + EXPECT_TRUE(outer->match(7)); + EXPECT_TRUE(outer->match(8)); + EXPECT_TRUE(outer->match(9)); + EXPECT_FALSE(outer->match(10)); // Out of range + + delete outer; +} + +TEST(kmp_trait_expr_group_test, PropagationToDeeplyNestedGroups) { + // Create a deeply nested structure: outer -> middle -> inner + kmp_trait_expr_group *outer = new kmp_trait_expr_group(); + outer->set_group_type(kmp_trait_expr_group::OR); + + // Set mock on outer + outer->set_num_devices([]() { return 12; }); + + kmp_trait_expr_group *middle = new kmp_trait_expr_group(); + middle->set_group_type(kmp_trait_expr_group::OR); + + kmp_trait_expr_group *inner = new kmp_trait_expr_group(); + inner->set_group_type(kmp_trait_expr_group::OR); + inner->add_expr(new kmp_literal_trait(10)); + inner->add_expr(new kmp_literal_trait(11)); + + middle->add_expr(inner); + middle->add_expr(new kmp_literal_trait(9)); + + outer->add_expr(middle); + outer->add_expr(new kmp_literal_trait(8)); + + // All devices 8-11 should match (requires 12 devices) + EXPECT_TRUE(outer->match(8)); + EXPECT_TRUE(outer->match(9)); + EXPECT_TRUE(outer->match(10)); + EXPECT_TRUE(outer->match(11)); + EXPECT_FALSE(outer->match(12)); // Out of range + + delete outer; +} + +TEST(kmp_trait_context_test, PropagationToNestedGroupsInClauses) { + kmp_trait_context *context = new kmp_trait_context(); + + // Set mock on context FIRST + context->set_num_devices([]() { return 10; }); + + // Create a group with nested structure + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->set_group_type(kmp_trait_expr_group::OR); + + kmp_trait_expr_group *inner = new kmp_trait_expr_group(); + inner->set_group_type(kmp_trait_expr_group::OR); + inner->add_expr(new kmp_literal_trait(8)); + inner->add_expr(new kmp_literal_trait(9)); + + group->add_expr(inner); + group->add_expr(new kmp_literal_trait(7)); + + // Create clause with the group + kmp_trait_clause *clause = new kmp_trait_clause(); + clause->set_expr(group); + + // Add clause to context - should propagate to group and inner + context->add_clause(clause); + + // All should match + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 3u); + EXPECT_TRUE(collected.contains(7)); + EXPECT_TRUE(collected.contains(8)); + EXPECT_TRUE(collected.contains(9)); + + delete context; +} + +TEST(kmp_trait_context_test, PropagationMixedOrder) { + // Test a complex scenario with mixed ordering + kmp_trait_context *context = new kmp_trait_context(); + + // Add first clause before set_num_devices + kmp_trait_clause *c1 = new kmp_trait_clause(); + c1->set_expr(new kmp_literal_trait(5)); + context->add_clause(c1); + + // Set mock + context->set_num_devices([]() { return 8; }); + + // Add second clause after set_num_devices + kmp_trait_clause *c2 = new kmp_trait_clause(); + c2->set_expr(new kmp_literal_trait(6)); + context->add_clause(c2); + + // Add third clause with a group + kmp_trait_expr_group *group = new kmp_trait_expr_group(); + group->add_expr(new kmp_literal_trait(7)); + + kmp_trait_clause *c3 = new kmp_trait_clause(); + c3->set_expr(group); + context->add_clause(c3); + + // All three should match + kmp_vector<int> collected; + for (int d : *context) { + collected.push_back(d); + } + + EXPECT_EQ(collected.size(), 3u); + EXPECT_TRUE(collected.contains(5)); + EXPECT_TRUE(collected.contains(6)); + EXPECT_TRUE(collected.contains(7)); + + delete context; +} + +} // namespace |
