// Copyright (C) 2025 Free Software Foundation, Inc. // This file is part of GCC. // GCC is free software; you can redistribute it and/or modify it under // the terms of the GNU General Public License as published by the Free // Software Foundation; either version 3, or (at your option) any later // version. // GCC is distributed in the hope that it will be useful, but WITHOUT ANY // WARRANTY; without even the implied warranty of MERCHANTABILITY or // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License // for more details. // You should have received a copy of the GNU General Public License // along with GCC; see the file COPYING3. If not see // . #include "rust-derive-ord.h" #include "rust-ast.h" #include "rust-derive-cmp-common.h" #include "rust-derive.h" #include "rust-item.h" #include "rust-system.h" namespace Rust { namespace AST { DeriveOrd::DeriveOrd (Ordering ordering, location_t loc) : DeriveVisitor (loc), ordering (ordering) {} std::unique_ptr DeriveOrd::go (Item &item) { item.accept_vis (*this); return std::move (expanded); } std::unique_ptr DeriveOrd::cmp_call (std::unique_ptr &&self_expr, std::unique_ptr &&other_expr) { auto cmp_fn_path = builder.path_in_expression ( {"core", "cmp", trait (ordering), fn (ordering)}, true); return builder.call (ptrify (cmp_fn_path), vec (builder.ref (std::move (self_expr)), builder.ref (std::move (other_expr)))); } std::unique_ptr DeriveOrd::cmp_impl ( std::unique_ptr &&fn_block, Identifier type_name, const std::vector> &type_generics) { auto fn = cmp_fn (std::move (fn_block), type_name); auto trait = ordering == Ordering::Partial ? "PartialOrd" : "Ord"; auto trait_path = builder.type_path ({"core", "cmp", trait}, true); auto trait_bound = builder.trait_bound (builder.type_path ({"core", "cmp", trait}, true)); auto trait_items = vec (std::move (fn)); auto cmp_generics = setup_impl_generics (type_name.as_string (), type_generics, std::move (trait_bound)); return builder.trait_impl (trait_path, std::move (cmp_generics.self_type), std::move (trait_items), std::move (cmp_generics.impl)); } std::unique_ptr DeriveOrd::cmp_fn (std::unique_ptr &&block, Identifier type_name) { // Ordering auto return_type = builder.type_path ({"core", "cmp", "Ordering"}, true); // In the case of PartialOrd, we return an Option if (ordering == Ordering::Partial) { auto generic = GenericArg::create_type (ptrify (return_type)); auto generic_seg = builder.type_path_segment_generic ( "Option", GenericArgs ({}, {generic}, {}, loc)); auto core = builder.type_path_segment ("core"); auto option = builder.type_path_segment ("option"); return_type = builder.type_path (vec (std::move (core), std::move (option), std::move (generic_seg)), true); } // &self, other: &Self auto params = vec ( builder.self_ref_param (), builder.function_param (builder.identifier_pattern ("other"), builder.reference_type (ptrify ( builder.type_path (type_name.as_string ()))))); auto function_name = fn (ordering); return builder.function (function_name, std::move (params), ptrify (return_type), std::move (block)); } std::unique_ptr DeriveOrd::make_equal () { std::unique_ptr equal = ptrify ( builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true)); // We need to wrap the pattern in Option::Some if we are doing partial // ordering if (ordering == Ordering::Partial) { auto pattern_items = std::unique_ptr ( new TupleStructItemsNoRange (vec (std::move (equal)))); equal = std::make_unique (builder.path_in_expression ( LangItem::Kind::OPTION_SOME), std::move (pattern_items)); } return equal; } std::pair DeriveOrd::make_cmp_arms () { // All comparison results other than Ordering::Equal auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal); auto equal = make_equal (); return {builder.match_arm (std::move (equal)), builder.match_arm (std::move (non_equal))}; } std::unique_ptr DeriveOrd::recursive_match (std::vector &&members) { if (members.empty ()) { std::unique_ptr value = ptrify ( builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true)); if (ordering == Ordering::Partial) value = builder.call (ptrify (builder.path_in_expression ( LangItem::Kind::OPTION_SOME)), std::move (value)); return value; } std::unique_ptr final_expr = nullptr; for (auto it = members.rbegin (); it != members.rend (); it++) { auto &member = *it; auto call = cmp_call (std::move (member.self_expr), std::move (member.other_expr)); // For the last member (so the first iterator), we just create a call // expression if (it == members.rbegin ()) { final_expr = std::move (call); continue; } // If we aren't dealing with the last member, then we need to wrap all of // that in a big match expression and keep going auto match_arms = make_cmp_arms (); auto match_cases = {builder.match_case (std::move (match_arms.first), std::move (final_expr)), builder.match_case (std::move (match_arms.second), builder.identifier (DeriveOrd::not_equal))}; final_expr = builder.match (std::move (call), std::move (match_cases)); } return final_expr; } // we need to do a recursive match expression for all of the fields used in a // struct so for something like struct Foo { a: i32, b: i32, c: i32 } we must // first compare each `a` field, then `b`, then `c`, like this: // // match cmp_fn(self., other.) { // Ordering::Equal => , // cmp => cmp, // } // // and the recurse will be the exact same expression, on the next field. so that // our result looks like this: // // match cmp_fn(self.a, other.a) { // Ordering::Equal => match cmp_fn(self.b, other.b) { // Ordering::Equal =>cmp_fn(self.c, other.c), // cmp => cmp, // } // cmp => cmp, // } // // the last field comparison needs not to be a match but just the function call. // this is going to be annoying lol void DeriveOrd::visit_struct (StructStruct &item) { auto fields = SelfOther::fields (builder, item.get_fields ()); auto match_expr = recursive_match (std::move (fields)); expanded = cmp_impl (builder.block (std::move (match_expr)), item.get_identifier (), item.get_generic_params ()); } // same as structs, but for each field index instead of each field name - // straightforward once we have `visit_struct` working void DeriveOrd::visit_tuple (TupleStruct &item) { auto fields = SelfOther::indexes (builder, item.get_fields ()); auto match_expr = recursive_match (std::move (fields)); expanded = cmp_impl (builder.block (std::move (match_expr)), item.get_identifier (), item.get_generic_params ()); } // for enums, we need to generate a match for each of the enum's variant that // contains data and then do the same thing as visit_struct or visit_enum. if // the two aren't the same variant, then compare the two discriminant values for // all the dataless enum variants and in the general case. // // so for enum Foo { A(i32, i32), B, C } we need to do the following // // match (self, other) { // (A(self_0, self_1), A(other_0, other_1)) => { // match cmp_fn(self_0, other_0) { // Ordering::Equal => cmp_fn(self_1, other_1), // cmp => cmp, // }, // _ => cmp_fn(discr_value(self), discr_value(other)) // } void DeriveOrd::visit_enum (Enum &item) { // NOTE: We can factor this even further with DerivePartialEq, but this is // getting out of scope for this PR surely auto cases = std::vector (); auto type_name = item.get_identifier ().as_string (); auto let_sd = builder.discriminant_value (DeriveOrd::self_discr, "self"); auto let_od = builder.discriminant_value (DeriveOrd::other_discr, "other"); auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr), builder.identifier (DeriveOrd::other_discr)); auto recursive_match_fn = [this] (std::vector &&fields) { return recursive_match (std::move (fields)); }; for (auto &variant : item.get_variants ()) { auto enum_builder = EnumMatchBuilder (type_name, variant->get_identifier ().as_string (), recursive_match_fn, builder); switch (variant->get_enum_item_kind ()) { case EnumItem::Kind::Struct: cases.emplace_back (enum_builder.strukt (*variant)); break; case EnumItem::Kind::Tuple: cases.emplace_back (enum_builder.tuple (*variant)); break; case EnumItem::Kind::Identifier: case EnumItem::Kind::Discriminant: // We don't need to do anything for these, as they are handled by the // discriminant value comparison break; } } // Add the last case which compares the discriminant values in case `self` and // `other` are actually different variants of the enum cases.emplace_back ( builder.match_case (builder.wildcard (), std::move (discr_cmp))); auto match = builder.match (builder.tuple (vec (builder.identifier ("self"), builder.identifier ("other"))), std::move (cases)); expanded = cmp_impl (builder.block (vec (std::move (let_sd), std::move (let_od)), std::move (match)), type_name, item.get_generic_params ()); } void DeriveOrd::visit_union (Union &item) { auto trait_name = trait (ordering); rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions", trait_name.c_str ()); } } // namespace AST } // namespace Rust