diff options
author | bors[bot] <26634292+bors[bot]@users.noreply.github.com> | 2022-07-07 20:00:19 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-07 20:00:19 +0000 |
commit | b1c9ac14a06a2c0dd4acb8758f3c88c2c610a31e (patch) | |
tree | 2592a09086138cd82b633cd1902f9c62cc52dcf5 | |
parent | a5c82a1d2b430585b915801a2ae1b73f1ea65f60 (diff) | |
parent | b7c4aa942fa828038106105524af2c568acca7ac (diff) | |
download | gcc-b1c9ac14a06a2c0dd4acb8758f3c88c2c610a31e.zip gcc-b1c9ac14a06a2c0dd4acb8758f3c88c2c610a31e.tar.gz gcc-b1c9ac14a06a2c0dd4acb8758f3c88c2c610a31e.tar.bz2 |
Merge #1367
1367: Initial support for matches on tuples r=dafaust a=dafaust
Initial work on supporting more complex match expressions, starting with tuples. It's proven to be quite a bit more complicated than I originally expected. Still this implementation is very naive, probably very poor performance, and skips all the nuanced stuff like bindings in patterns and pattern guards.
Overview
The general idea is that we lower `MatchExpr`s to `SWITCH_EXPR` trees, but a switch can only operate on one value at a time. When scrutinizing a tuple, we can't build a switch that checks all the elements at once. So instead, just before lowering to trees, we construct a new set of nested match expressions which tests the tuple one element at a time, and is (hopefully) semantically identical to the original match. These new nested matches are simplified, so that each one tests only one thing and therefore *can* lower directly to a `SWITCH_EXPR`.
To see an example, this
```text
match (tupA, tupB, tupC) {
(a1, b1, c1) => { blk1 }
(a2, b2, c2) => { blk2 }
(a1, b3, c3) => { blk3 }
...
_ => { blk4 }
}
```
after one step becomes
```text
match tupA {
a1 => {
match (tupB, tupC) {
(b1, c1) => { blk1 }
(b3, c3) => { blk3 }
...
_ => { blk4 }
}
a2 => {
match (tupB, tupC) {
(b2, c2) => { blk2 }
...
_ => { blk4 }
}
}
_ => { blk4 }
}
```
Repeat until all matches are fully simplified, then compile it all into trees.
Now, we might end up with a lot of duplicated basic blocks here as a result, but the optimization in GENERIC and GIMPLE (further on in gcc) is very good and knows how to take care of that for us.
This approach seems somewhat similar to how rustc lowers HIR to MIR, which is hopefully a good sign :)
Addresses: #1081
Co-authored-by: David Faust <david.faust@oracle.com>
-rw-r--r-- | gcc/rust/backend/rust-compile-expr.cc | 472 | ||||
-rw-r--r-- | gcc/rust/hir/tree/rust-hir.h | 6 | ||||
-rw-r--r-- | gcc/testsuite/rust/execute/torture/match_tuple1.rs | 41 |
3 files changed, 494 insertions, 25 deletions
diff --git a/gcc/rust/backend/rust-compile-expr.cc b/gcc/rust/backend/rust-compile-expr.cc index d12dc78..3740c80 100644 --- a/gcc/rust/backend/rust-compile-expr.cc +++ b/gcc/rust/backend/rust-compile-expr.cc @@ -181,6 +181,369 @@ CompileExpr::visit (HIR::DereferenceExpr &expr) known_valid, expr.get_locus ()); } +// Helper for sort_tuple_patterns. +// Determine whether Patterns a and b are really the same pattern. +// FIXME: This is a nasty hack to avoid properly implementing a comparison +// for Patterns, which we really probably do want at some point. +static bool +patterns_mergeable (HIR::Pattern *a, HIR::Pattern *b) +{ + if (!a || !b) + return false; + + HIR::Pattern::PatternType pat_type = a->get_pattern_type (); + if (b->get_pattern_type () != pat_type) + return false; + + switch (pat_type) + { + case HIR::Pattern::PatternType::PATH: { + // FIXME: this is far too naive + HIR::PathPattern &aref = *static_cast<HIR::PathPattern *> (a); + HIR::PathPattern &bref = *static_cast<HIR::PathPattern *> (b); + if (aref.get_num_segments () != bref.get_num_segments ()) + return false; + + const auto &asegs = aref.get_segments (); + const auto &bsegs = bref.get_segments (); + for (size_t i = 0; i < asegs.size (); i++) + { + if (asegs[i].as_string () != bsegs[i].as_string ()) + return false; + } + return true; + } + break; + case HIR::Pattern::PatternType::LITERAL: { + HIR::LiteralPattern &aref = *static_cast<HIR::LiteralPattern *> (a); + HIR::LiteralPattern &bref = *static_cast<HIR::LiteralPattern *> (b); + return aref.get_literal ().is_equal (bref.get_literal ()); + } + break; + case HIR::Pattern::PatternType::IDENTIFIER: { + // TODO + } + break; + case HIR::Pattern::PatternType::WILDCARD: + return true; + break; + + // TODO + + default:; + } + return false; +} + +// A little container for rearranging the patterns and cases in a match +// expression while simplifying. +struct PatternMerge +{ + std::unique_ptr<HIR::MatchCase> wildcard; + std::vector<std::unique_ptr<HIR::Pattern>> heads; + std::vector<std::vector<HIR::MatchCase>> cases; +}; + +// Helper for simplify_tuple_match. +// For each tuple pattern in a given match, pull out the first elt of the +// tuple and construct a new MatchCase with the remaining tuple elts as the +// pattern. Return a mapping from each _unique_ first tuple element to a +// vec of cases for a new match. +// +// FIXME: This used to be a std::map<Pattern, Vec<MatchCase>>, but it doesn't +// actually work like we want - the Pattern includes an HIR ID, which is unique +// per Pattern object. This means we don't have a good means for comparing +// Patterns. It would probably be best to actually implement a means of +// properly comparing patterns, and then use an actual map. +// +static struct PatternMerge +sort_tuple_patterns (HIR::MatchExpr &expr) +{ + rust_assert (expr.get_scrutinee_expr ()->get_expression_type () + == HIR::Expr::ExprType::Tuple); + + struct PatternMerge result; + result.wildcard = nullptr; + result.heads = std::vector<std::unique_ptr<HIR::Pattern>> (); + result.cases = std::vector<std::vector<HIR::MatchCase>> (); + + for (auto &match_case : expr.get_match_cases ()) + { + HIR::MatchArm &case_arm = match_case.get_arm (); + + // FIXME: Note we are only dealing with the first pattern in the arm. + // The patterns vector in the arm might hold many patterns, which are the + // patterns separated by the '|' token. Rustc abstracts these as "Or" + // patterns, and part of its simplification process is to get rid of them. + // We should get rid of the ORs too, maybe here or earlier than here? + auto pat = case_arm.get_patterns ()[0]->clone_pattern (); + + // Record wildcards so we can add them in inner matches. + if (pat->get_pattern_type () == HIR::Pattern::PatternType::WILDCARD) + { + // The *whole* pattern is a wild card (_). + result.wildcard + = std::unique_ptr<HIR::MatchCase> (new HIR::MatchCase (match_case)); + continue; + } + + rust_assert (pat->get_pattern_type () + == HIR::Pattern::PatternType::TUPLE); + + auto ref = *static_cast<HIR::TuplePattern *> (pat.get ()); + + rust_assert (ref.has_tuple_pattern_items ()); + + auto items + = HIR::TuplePattern (ref).get_items ()->clone_tuple_pattern_items (); + if (items->get_pattern_type () + == HIR::TuplePatternItems::TuplePatternItemType::MULTIPLE) + { + auto items_ref + = *static_cast<HIR::TuplePatternItemsMultiple *> (items.get ()); + + // Pop the first pattern out + auto patterns = std::vector<std::unique_ptr<HIR::Pattern>> (); + auto first = items_ref.get_patterns ()[0]->clone_pattern (); + for (auto p = items_ref.get_patterns ().begin () + 1; + p != items_ref.get_patterns ().end (); p++) + { + patterns.push_back ((*p)->clone_pattern ()); + } + + // if there is only one pattern left, don't make a tuple out of it + std::unique_ptr<HIR::Pattern> result_pattern; + if (patterns.size () == 1) + { + result_pattern = std::move (patterns[0]); + } + else + { + auto new_items = std::unique_ptr<HIR::TuplePatternItems> ( + new HIR::TuplePatternItemsMultiple (std::move (patterns))); + + // Construct a TuplePattern from the rest of the patterns + result_pattern = std::unique_ptr<HIR::Pattern> ( + new HIR::TuplePattern (ref.get_pattern_mappings (), + std::move (new_items), + ref.get_locus ())); + } + + // I don't know why we need to make foo separately here but + // using the { new_tuple } syntax in new_arm constructor does not + // compile. + auto foo = std::vector<std::unique_ptr<HIR::Pattern>> (); + foo.emplace_back (std::move (result_pattern)); + HIR::MatchArm new_arm (std::move (foo), Location (), nullptr, + AST::AttrVec ()); + + HIR::MatchCase new_case (match_case.get_mappings (), new_arm, + match_case.get_expr ()->clone_expr ()); + + bool pushed = false; + for (size_t i = 0; i < result.heads.size (); i++) + { + if (patterns_mergeable (result.heads[i].get (), first.get ())) + { + result.cases[i].push_back (new_case); + pushed = true; + } + } + + if (!pushed) + { + result.heads.push_back (std::move (first)); + result.cases.push_back ({new_case}); + } + } + else /* TuplePatternItemType::RANGED */ + { + // FIXME + gcc_unreachable (); + } + } + + return result; +} + +// Helper for CompileExpr::visit (HIR::MatchExpr). +// Given a MatchExpr where the scrutinee is some kind of tuple, build an +// equivalent match where only one element of the tuple is examined at a time. +// This resulting match can then be lowered to a SWITCH_EXPR tree directly. +// +// The approach is as follows: +// 1. Split the scrutinee and each pattern into the first (head) and the +// rest (tail). +// 2. Build a mapping of unique pattern heads to the cases (tail and expr) +// that shared that pattern head in the original match. +// (This is the job of sort_tuple_patterns ()). +// 3. For each unique pattern head, build a new MatchCase where the pattern +// is the unique head, and the expression is a new match where: +// - The scrutinee is the tail of the original scrutinee +// - The cases are are those built by the mapping in step 2, i.e. the +// tails of the patterns and the corresponing expressions from the +// original match expression. +// 4. Do this recursively for each inner match, until there is nothing more +// to simplify. +// 5. Build the resulting match which scrutinizes the head of the original +// scrutinee, using the cases built in step 3. +static HIR::MatchExpr +simplify_tuple_match (HIR::MatchExpr &expr) +{ + if (expr.get_scrutinee_expr ()->get_expression_type () + != HIR::Expr::ExprType::Tuple) + return expr; + + auto ref = *static_cast<HIR::TupleExpr *> (expr.get_scrutinee_expr ().get ()); + + auto &tail = ref.get_tuple_elems (); + rust_assert (tail.size () > 1); + + auto head = std::move (tail[0]); + tail.erase (tail.begin (), tail.begin () + 1); + + // e.g. + // match (tupA, tupB, tupC) { + // (a1, b1, c1) => { blk1 }, + // (a2, b2, c2) => { blk2 }, + // (a1, b3, c3) => { blk3 }, + // } + // tail = (tupB, tupC) + // head = tupA + + // Make sure the tail is only a tuple if it consists of at least 2 elements. + std::unique_ptr<HIR::Expr> remaining; + if (tail.size () == 1) + remaining = std::move (tail[0]); + else + remaining = std::unique_ptr<HIR::Expr> ( + new HIR::TupleExpr (ref.get_mappings (), std::move (tail), + AST::AttrVec (), ref.get_outer_attrs (), + ref.get_locus ())); + + // e.g. + // a1 -> [(b1, c1) => { blk1 }, + // (b3, c3) => { blk3 }] + // a2 -> [(b2, c2) => { blk2 }] + struct PatternMerge map = sort_tuple_patterns (expr); + + std::vector<HIR::MatchCase> cases; + // Construct the inner match for each unique first elt of the tuple + // patterns + for (size_t i = 0; i < map.heads.size (); i++) + { + auto inner_match_cases = map.cases[i]; + + // If there is a wildcard at the outer match level, then need to + // propegate the wildcard case into *every* inner match. + // FIXME: It is probably not correct to add this unconditionally, what if + // we have a pattern like (a, _, c)? Then there is already a wildcard in + // the inner matches, and having two will cause two 'default:' blocks + // which is an error. + if (map.wildcard != nullptr) + { + inner_match_cases.push_back (*(map.wildcard.get ())); + } + + // match (tupB, tupC) { + // (b1, c1) => { blk1 }, + // (b3, c3) => { blk3 } + // } + HIR::MatchExpr inner_match (expr.get_mappings (), + remaining->clone_expr (), inner_match_cases, + AST::AttrVec (), expr.get_outer_attrs (), + expr.get_locus ()); + + inner_match = simplify_tuple_match (inner_match); + + auto outer_arm_pat = std::vector<std::unique_ptr<HIR::Pattern>> (); + outer_arm_pat.emplace_back (map.heads[i]->clone_pattern ()); + + HIR::MatchArm outer_arm (std::move (outer_arm_pat), expr.get_locus ()); + + // Need to move the inner match to the heap and put it in a unique_ptr to + // build the actual match case of the outer expression + // auto inner_expr = std::unique_ptr<HIR::Expr> (new HIR::MatchExpr + // (inner_match)); + auto inner_expr = inner_match.clone_expr (); + + // a1 => match (tupB, tupC) { ... } + HIR::MatchCase outer_case (expr.get_mappings (), outer_arm, + std::move (inner_expr)); + + cases.push_back (outer_case); + } + + // If there was a wildcard, make sure to include it at the outer match level + // too. + if (map.wildcard != nullptr) + { + cases.push_back (*(map.wildcard.get ())); + } + + // match tupA { + // a1 => match (tupB, tupC) { + // (b1, c1) => { blk1 }, + // (b3, c3) => { blk3 } + // } + // a2 => match (tupB, tupC) { + // (b2, c2) => { blk2 } + // } + // } + HIR::MatchExpr outer_match (expr.get_mappings (), std::move (head), cases, + AST::AttrVec (), expr.get_outer_attrs (), + expr.get_locus ()); + + return outer_match; +} + +// Helper for CompileExpr::visit (HIR::MatchExpr). +// Check that the scrutinee of EXPR is a valid kind of expression to match on. +// Return the TypeKind of the scrutinee if it is valid, or TyTy::TypeKind::ERROR +// if not. +static TyTy::TypeKind +check_match_scrutinee (HIR::MatchExpr &expr, Context *ctx) +{ + TyTy::BaseType *scrutinee_expr_tyty = nullptr; + if (!ctx->get_tyctx ()->lookup_type ( + expr.get_scrutinee_expr ()->get_mappings ().get_hirid (), + &scrutinee_expr_tyty)) + { + return TyTy::TypeKind::ERROR; + } + + TyTy::TypeKind scrutinee_kind = scrutinee_expr_tyty->get_kind (); + rust_assert ((TyTy::is_primitive_type_kind (scrutinee_kind) + && scrutinee_kind != TyTy::TypeKind::NEVER) + || scrutinee_kind == TyTy::TypeKind::ADT + || scrutinee_kind == TyTy::TypeKind::TUPLE); + + if (scrutinee_kind == TyTy::TypeKind::ADT) + { + // this will need to change but for now the first pass implementation, + // lets assert this is the case + TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (scrutinee_expr_tyty); + rust_assert (adt->is_enum ()); + rust_assert (adt->number_of_variants () > 0); + } + else if (scrutinee_kind == TyTy::TypeKind::FLOAT) + { + // FIXME: CASE_LABEL_EXPR does not support floating point types. + // Find another way to compile these. + rust_sorry_at (expr.get_locus (), + "match on floating-point types is not yet supported"); + } + + TyTy::BaseType *expr_tyty = nullptr; + if (!ctx->get_tyctx ()->lookup_type (expr.get_mappings ().get_hirid (), + &expr_tyty)) + { + return TyTy::TypeKind::ERROR; + } + + return scrutinee_kind; +} + void CompileExpr::visit (HIR::MatchExpr &expr) { @@ -213,36 +576,13 @@ CompileExpr::visit (HIR::MatchExpr &expr) the control flow graph. */ // DEFTREECODE (CASE_LABEL_EXPR, "case_label_expr", tcc_statement, 4) - TyTy::BaseType *scrutinee_expr_tyty = nullptr; - if (!ctx->get_tyctx ()->lookup_type ( - expr.get_scrutinee_expr ()->get_mappings ().get_hirid (), - &scrutinee_expr_tyty)) + TyTy::TypeKind scrutinee_kind = check_match_scrutinee (expr, ctx); + if (scrutinee_kind == TyTy::TypeKind::ERROR) { translated = error_mark_node; return; } - TyTy::TypeKind scrutinee_kind = scrutinee_expr_tyty->get_kind (); - rust_assert ((TyTy::is_primitive_type_kind (scrutinee_kind) - && scrutinee_kind != TyTy::TypeKind::NEVER) - || scrutinee_kind == TyTy::TypeKind::ADT); - - if (scrutinee_kind == TyTy::TypeKind::ADT) - { - // this will need to change but for now the first pass implementation, - // lets assert this is the case - TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (scrutinee_expr_tyty); - rust_assert (adt->is_enum ()); - rust_assert (adt->number_of_variants () > 0); - } - else if (scrutinee_kind == TyTy::TypeKind::FLOAT) - { - // FIXME: CASE_LABEL_EXPR does not support floating point types. - // Find another way to compile these. - rust_sorry_at (expr.get_locus (), - "match on floating-point types is not yet supported"); - } - TyTy::BaseType *expr_tyty = nullptr; if (!ctx->get_tyctx ()->lookup_type (expr.get_mappings ().get_hirid (), &expr_tyty)) @@ -290,6 +630,88 @@ CompileExpr::visit (HIR::MatchExpr &expr) scrutinee_first_record_expr, 0, expr.get_scrutinee_expr ()->get_locus ()); } + else if (scrutinee_kind == TyTy::TypeKind::TUPLE) + { + // match on tuple becomes a series of nested switches, with one level + // for each element of the tuple from left to right. + auto exprtype = expr.get_scrutinee_expr ()->get_expression_type (); + switch (exprtype) + { + case HIR::Expr::ExprType::Tuple: { + // Build an equivalent expression which is nicer to lower. + HIR::MatchExpr outer_match = simplify_tuple_match (expr); + + // We've rearranged the match into something that lowers better + // to GENERIC trees. + // For actually doing the lowering we need to compile the match + // we've just made. But we're half-way through compiling the + // original one. + // ... + // For now, let's just replace the original with the rearranged one + // we just made, and compile that instead. What could go wrong? :) + // + // FIXME: What about when we decide a temporary is needed above? + // We might have already pushed a statement for it that + // we no longer need. Probably need to rearrange the order + // of these steps. + expr = outer_match; + + scrutinee_kind = check_match_scrutinee (expr, ctx); + if (scrutinee_kind == TyTy::TypeKind::ERROR) + { + translated = error_mark_node; + return; + } + + // Now compile the scrutinee of the simplified match. + // FIXME: this part is duplicated from above. + match_scrutinee_expr + = CompileExpr::Compile (expr.get_scrutinee_expr ().get (), ctx); + + if (TyTy::is_primitive_type_kind (scrutinee_kind)) + { + match_scrutinee_expr_qualifier_expr = match_scrutinee_expr; + } + else if (scrutinee_kind == TyTy::TypeKind::ADT) + { + // need to access qualifier the field, if we use QUAL_UNION_TYPE + // this would be DECL_QUALIFIER i think. For now this will just + // access the first record field and its respective qualifier + // because it will always be set because this is all a big + // special union + tree scrutinee_first_record_expr + = ctx->get_backend ()->struct_field_expression ( + match_scrutinee_expr, 0, + expr.get_scrutinee_expr ()->get_locus ()); + match_scrutinee_expr_qualifier_expr + = ctx->get_backend ()->struct_field_expression ( + scrutinee_first_record_expr, 0, + expr.get_scrutinee_expr ()->get_locus ()); + } + else + { + // FIXME: There are other cases, but it better not be a Tuple + gcc_unreachable (); + } + } + break; + + case HIR::Expr::ExprType::Ident: { + // FIXME + gcc_unreachable (); + } + break; + + case HIR::Expr::ExprType::Path: { + // FIXME + gcc_unreachable (); + } + break; + + default: + gcc_unreachable (); + } + } else { // FIXME: match on other types of expressions not yet implemented. diff --git a/gcc/rust/hir/tree/rust-hir.h b/gcc/rust/hir/tree/rust-hir.h index 8f94354..ff4c39e 100644 --- a/gcc/rust/hir/tree/rust-hir.h +++ b/gcc/rust/hir/tree/rust-hir.h @@ -122,6 +122,12 @@ public: // Returns whether literal is in an invalid state. bool is_error () const { return value_as_string == ""; } + + bool is_equal (Literal &other) + { + return value_as_string == other.value_as_string && type == other.type + && type_hint == other.type_hint; + } }; /* Base statement abstract class. Note that most "statements" are not allowed in diff --git a/gcc/testsuite/rust/execute/torture/match_tuple1.rs b/gcc/testsuite/rust/execute/torture/match_tuple1.rs new file mode 100644 index 0000000..1874062 --- /dev/null +++ b/gcc/testsuite/rust/execute/torture/match_tuple1.rs @@ -0,0 +1,41 @@ +// { dg-output "x:15\ny:20\n" } + +extern "C" { + fn printf(s: *const i8, ...); +} + +enum Foo { + A, + B, +} + +fn inspect(f: Foo, g: u8) -> i32 { + match (f, g) { + (Foo::A, 1) => { + return 5; + } + + (Foo::A, 2) => { + return 10; + } + + (Foo::B, 2) => { + return 15; + } + + _ => { + return 20; + } + } + return 25; +} + +fn main () -> i32 { + let x = inspect (Foo::B, 2); + let y = inspect (Foo::B, 1); + + printf ("x:%d\n" as *const str as *const i8, x); + printf ("y:%d\n" as *const str as *const i8, y); + + y - x - 5 +} |