// Copyright (C) 2020-2024 Free Software Foundation, Inc. // This file is part of GCC. // GCC is free software; you can redistribute it and/or modify it under // the terms of the GNU General Public License as published by the Free // Software Foundation; either version 3, or (at your option) any later // version. // GCC is distributed in the hope that it will be useful, but WITHOUT ANY // WARRANTY; without even the implied warranty of MERCHANTABILITY or // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License // for more details. // You should have received a copy of the GNU General Public License // along with GCC; see the file COPYING3. If not see // . #include "rust-system.h" #include "rust-hir-pattern-analysis.h" #include "rust-diagnostics.h" #include "rust-hir-full-decls.h" #include "rust-hir-path.h" #include "rust-hir-pattern.h" #include "rust-hir.h" #include "rust-mapping-common.h" #include "rust-system.h" #include "rust-tyty.h" #include "rust-immutable-name-resolution-context.h" // for flag_name_resolution_2_0 #include "options.h" namespace Rust { namespace Analysis { PatternChecker::PatternChecker () : tyctx (*Resolver::TypeCheckContext::get ()), resolver (*Resolver::Resolver::get ()), mappings (Analysis::Mappings::get ()) {} void PatternChecker::go (HIR::Crate &crate) { rust_debug ("started pattern check"); for (auto &item : crate.get_items ()) item->accept_vis (*this); rust_debug ("finished pattern check"); } void PatternChecker::visit (Lifetime &) {} void PatternChecker::visit (LifetimeParam &) {} void PatternChecker::visit (PathInExpression &path) {} void PatternChecker::visit (TypePathSegment &) {} void PatternChecker::visit (TypePathSegmentGeneric &) {} void PatternChecker::visit (TypePathSegmentFunction &) {} void PatternChecker::visit (TypePath &) {} void PatternChecker::visit (QualifiedPathInExpression &) {} void PatternChecker::visit (QualifiedPathInType &) {} void PatternChecker::visit (LiteralExpr &) {} void PatternChecker::visit (BorrowExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (DereferenceExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (ErrorPropagationExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (NegationExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (ArithmeticOrLogicalExpr &expr) { expr.get_lhs ().accept_vis (*this); expr.get_rhs ().accept_vis (*this); } void PatternChecker::visit (ComparisonExpr &expr) { expr.get_lhs ().accept_vis (*this); expr.get_rhs ().accept_vis (*this); } void PatternChecker::visit (LazyBooleanExpr &expr) { expr.get_lhs ().accept_vis (*this); expr.get_rhs ().accept_vis (*this); } void PatternChecker::visit (TypeCastExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (AssignmentExpr &expr) { expr.get_lhs ().accept_vis (*this); expr.get_rhs ().accept_vis (*this); } void PatternChecker::visit (CompoundAssignmentExpr &expr) { expr.get_lhs ().accept_vis (*this); expr.get_rhs ().accept_vis (*this); } void PatternChecker::visit (GroupedExpr &expr) { expr.get_expr_in_parens ().accept_vis (*this); } void PatternChecker::visit (ArrayElemsValues &elems) { for (auto &elem : elems.get_values ()) elem->accept_vis (*this); } void PatternChecker::visit (ArrayElemsCopied &elems) { elems.get_elem_to_copy ().accept_vis (*this); } void PatternChecker::visit (ArrayExpr &expr) { expr.get_internal_elements ().accept_vis (*this); } void PatternChecker::visit (ArrayIndexExpr &expr) { expr.get_array_expr ().accept_vis (*this); expr.get_index_expr ().accept_vis (*this); } void PatternChecker::visit (TupleExpr &expr) { for (auto &elem : expr.get_tuple_elems ()) elem->accept_vis (*this); } void PatternChecker::visit (TupleIndexExpr &expr) { expr.get_tuple_expr ().accept_vis (*this); } void PatternChecker::visit (StructExprStruct &) {} void PatternChecker::visit (StructExprFieldIdentifier &) {} void PatternChecker::visit (StructExprFieldIdentifierValue &field) { field.get_value ().accept_vis (*this); } void PatternChecker::visit (StructExprFieldIndexValue &field) { field.get_value ().accept_vis (*this); } void PatternChecker::visit (StructExprStructFields &expr) { for (auto &field : expr.get_fields ()) field->accept_vis (*this); } void PatternChecker::visit (StructExprStructBase &) {} void PatternChecker::visit (CallExpr &expr) { if (!expr.has_fnexpr ()) return; NodeId ast_node_id = expr.get_fnexpr ().get_mappings ().get_nodeid (); NodeId ref_node_id; if (flag_name_resolution_2_0) { auto &nr_ctx = Resolver2_0::ImmutableNameResolutionContext::get ().resolver (); if (auto id = nr_ctx.lookup (ast_node_id)) ref_node_id = *id; else return; } else if (!resolver.lookup_resolved_name (ast_node_id, &ref_node_id)) return; if (auto definition_id = mappings.lookup_node_to_hir (ref_node_id)) { if (expr.has_params ()) for (auto &arg : expr.get_arguments ()) arg->accept_vis (*this); } else { rust_unreachable (); } } void PatternChecker::visit (MethodCallExpr &expr) { expr.get_receiver ().accept_vis (*this); for (auto &arg : expr.get_arguments ()) arg->accept_vis (*this); } void PatternChecker::visit (FieldAccessExpr &expr) { expr.get_receiver_expr ().accept_vis (*this); } void PatternChecker::visit (ClosureExpr &expr) { expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (BlockExpr &expr) { for (auto &stmt : expr.get_statements ()) stmt->accept_vis (*this); if (expr.has_expr ()) expr.get_final_expr ().accept_vis (*this); } void PatternChecker::visit (ContinueExpr &) {} void PatternChecker::visit (BreakExpr &expr) { if (expr.has_break_expr ()) expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (RangeFromToExpr &expr) { expr.get_from_expr ().accept_vis (*this); expr.get_to_expr ().accept_vis (*this); } void PatternChecker::visit (RangeFromExpr &expr) { expr.get_from_expr ().accept_vis (*this); } void PatternChecker::visit (RangeToExpr &expr) { expr.get_to_expr ().accept_vis (*this); } void PatternChecker::visit (RangeFullExpr &) {} void PatternChecker::visit (RangeFromToInclExpr &expr) { expr.get_from_expr ().accept_vis (*this); expr.get_to_expr ().accept_vis (*this); } void PatternChecker::visit (RangeToInclExpr &expr) { expr.get_to_expr ().accept_vis (*this); } void PatternChecker::visit (ReturnExpr &expr) { if (expr.has_return_expr ()) expr.get_expr ().accept_vis (*this); } void PatternChecker::visit (UnsafeBlockExpr &expr) { expr.get_block_expr ().accept_vis (*this); } void PatternChecker::visit (LoopExpr &expr) { expr.get_loop_block ().accept_vis (*this); } void PatternChecker::visit (WhileLoopExpr &expr) { expr.get_predicate_expr ().accept_vis (*this); expr.get_loop_block ().accept_vis (*this); } void PatternChecker::visit (WhileLetLoopExpr &expr) { expr.get_cond ().accept_vis (*this); expr.get_loop_block ().accept_vis (*this); } void PatternChecker::visit (IfExpr &expr) { expr.get_if_condition ().accept_vis (*this); expr.get_if_block ().accept_vis (*this); } void PatternChecker::visit (IfExprConseqElse &expr) { expr.get_if_condition ().accept_vis (*this); expr.get_if_block ().accept_vis (*this); expr.get_else_block ().accept_vis (*this); } void PatternChecker::visit (MatchExpr &expr) { expr.get_scrutinee_expr ().accept_vis (*this); for (auto &match_arm : expr.get_match_cases ()) match_arm.get_expr ().accept_vis (*this); // match expressions are only an entrypoint TyTy::BaseType *scrutinee_ty; bool ok = tyctx.lookup_type ( expr.get_scrutinee_expr ().get_mappings ().get_hirid (), &scrutinee_ty); rust_assert (ok); check_match_usefulness (&tyctx, scrutinee_ty, expr); } void PatternChecker::visit (AwaitExpr &) { // TODO: Visit expression } void PatternChecker::visit (AsyncBlockExpr &) { // TODO: Visit block expression } void PatternChecker::visit (InlineAsm &expr) {} void PatternChecker::visit (TypeParam &) {} void PatternChecker::visit (ConstGenericParam &) {} void PatternChecker::visit (LifetimeWhereClauseItem &) {} void PatternChecker::visit (TypeBoundWhereClauseItem &) {} void PatternChecker::visit (Module &module) { for (auto &item : module.get_items ()) item->accept_vis (*this); } void PatternChecker::visit (ExternCrate &) {} void PatternChecker::visit (UseTreeGlob &) {} void PatternChecker::visit (UseTreeList &) {} void PatternChecker::visit (UseTreeRebind &) {} void PatternChecker::visit (UseDeclaration &) {} void PatternChecker::visit (Function &function) { function.get_definition ().accept_vis (*this); } void PatternChecker::visit (TypeAlias &) {} void PatternChecker::visit (StructStruct &) {} void PatternChecker::visit (TupleStruct &) {} void PatternChecker::visit (EnumItem &) {} void PatternChecker::visit (EnumItemTuple &) {} void PatternChecker::visit (EnumItemStruct &) {} void PatternChecker::visit (EnumItemDiscriminant &) {} void PatternChecker::visit (Enum &) {} void PatternChecker::visit (Union &) {} void PatternChecker::visit (ConstantItem &const_item) { const_item.get_expr ().accept_vis (*this); } void PatternChecker::visit (StaticItem &static_item) { static_item.get_expr ().accept_vis (*this); } void PatternChecker::visit (TraitItemFunc &item) { if (item.has_definition ()) item.get_block_expr ().accept_vis (*this); } void PatternChecker::visit (TraitItemConst &item) { if (item.has_expr ()) item.get_expr ().accept_vis (*this); } void PatternChecker::visit (TraitItemType &) {} void PatternChecker::visit (Trait &trait) { for (auto &item : trait.get_trait_items ()) item->accept_vis (*this); } void PatternChecker::visit (ImplBlock &impl) { for (auto &item : impl.get_impl_items ()) item->accept_vis (*this); } void PatternChecker::visit (ExternalStaticItem &) {} void PatternChecker::visit (ExternalFunctionItem &) {} void PatternChecker::visit (ExternalTypeItem &) {} void PatternChecker::visit (ExternBlock &block) { // FIXME: Do we need to do this? for (auto &item : block.get_extern_items ()) item->accept_vis (*this); } void PatternChecker::visit (LiteralPattern &) {} void PatternChecker::visit (IdentifierPattern &) {} void PatternChecker::visit (WildcardPattern &) {} void PatternChecker::visit (RangePatternBoundLiteral &) {} void PatternChecker::visit (RangePatternBoundPath &) {} void PatternChecker::visit (RangePatternBoundQualPath &) {} void PatternChecker::visit (RangePattern &) {} void PatternChecker::visit (ReferencePattern &) {} void PatternChecker::visit (StructPatternFieldTuplePat &) {} void PatternChecker::visit (StructPatternFieldIdentPat &) {} void PatternChecker::visit (StructPatternFieldIdent &) {} void PatternChecker::visit (StructPattern &) {} void PatternChecker::visit (TupleStructItemsNoRange &) {} void PatternChecker::visit (TupleStructItemsRange &) {} void PatternChecker::visit (TupleStructPattern &) {} void PatternChecker::visit (TuplePatternItemsMultiple &) {} void PatternChecker::visit (TuplePatternItemsRanged &) {} void PatternChecker::visit (TuplePattern &) {} void PatternChecker::visit (SlicePattern &) {} void PatternChecker::visit (AltPattern &) {} void PatternChecker::visit (EmptyStmt &) {} void PatternChecker::visit (LetStmt &stmt) { if (stmt.has_init_expr ()) stmt.get_init_expr ().accept_vis (*this); } void PatternChecker::visit (ExprStmt &stmt) { stmt.get_expr ().accept_vis (*this); } void PatternChecker::visit (TraitBound &) {} void PatternChecker::visit (ImplTraitType &) {} void PatternChecker::visit (TraitObjectType &) {} void PatternChecker::visit (ParenthesisedType &) {} void PatternChecker::visit (TupleType &) {} void PatternChecker::visit (NeverType &) {} void PatternChecker::visit (RawPointerType &) {} void PatternChecker::visit (ReferenceType &) {} void PatternChecker::visit (ArrayType &) {} void PatternChecker::visit (SliceType &) {} void PatternChecker::visit (InferredType &) {} void PatternChecker::visit (BareFunctionType &) {} bool Constructor::is_covered_by (const Constructor &o) const { if (o.kind == ConstructorKind::WILDCARD) return true; switch (kind) { case ConstructorKind::VARIANT: { rust_assert (kind == ConstructorKind::VARIANT); return variant_idx == o.variant_idx; } break; case ConstructorKind::INT_RANGE: { rust_assert (kind == ConstructorKind::INT_RANGE); return int_range.lo >= o.int_range.lo && int_range.hi <= o.int_range.hi; } break; case ConstructorKind::WILDCARD: { // TODO: wildcard is covered by a variant of enum with a single // variant return false; } break; case ConstructorKind::STRUCT: { // Struct pattern is always covered by a other struct constructor. return true; } break; // TODO: support references case ConstructorKind::REFERENCE: default: rust_unreachable (); } } bool Constructor::operator< (const Constructor &o) const { if (kind != o.kind) return kind < o.kind; switch (kind) { case ConstructorKind::VARIANT: return variant_idx < o.variant_idx; case ConstructorKind::INT_RANGE: return int_range.lo < o.int_range.lo || (int_range.lo == o.int_range.lo && int_range.hi < o.int_range.hi); case ConstructorKind::STRUCT: case ConstructorKind::WILDCARD: case ConstructorKind::REFERENCE: return false; default: rust_unreachable (); } } std::string Constructor::to_string () const { switch (kind) { case ConstructorKind::STRUCT: return "STRUCT"; case ConstructorKind::VARIANT: return "VARIANT(" + std::to_string (variant_idx) + ")"; case ConstructorKind::INT_RANGE: return "RANGE" + std::to_string (int_range.lo) + ".." + std::to_string (int_range.hi); case ConstructorKind::WILDCARD: return "_"; case ConstructorKind::REFERENCE: return "REF"; default: rust_unreachable (); } } std::vector DeconstructedPat::specialize (const Constructor &other_ctor, int other_ctor_arity) const { rust_assert (other_ctor.is_covered_by (ctor)); if (ctor.is_wildcard ()) return std::vector ( other_ctor_arity, DeconstructedPat (Constructor::make_wildcard (), locus)); return fields; } std::string DeconstructedPat::to_string () const { std::string s = ctor.to_string () + "["; for (auto &f : fields) s += f.to_string () + ", "; s += "](arity=" + std::to_string (arity) + ")"; return s; } bool PatOrWild::is_covered_by (const Constructor &c) const { if (pat.has_value ()) return pat.value ().get_ctor ().is_covered_by (c); else return true; } std::vector PatOrWild::specialize (const Constructor &other_ctor, int other_ctor_arity) const { if (pat.has_value ()) { auto v = pat.value ().specialize (other_ctor, other_ctor_arity); std::vector ret; for (auto &pat : v) ret.push_back (PatOrWild::make_pattern (pat)); return ret; } else { return std::vector (other_ctor_arity, PatOrWild::make_wildcard ()); } } std::string PatOrWild::to_string () const { if (pat.has_value ()) return pat.value ().to_string (); else return "Wild"; } void PatStack::pop_head_constructor (const Constructor &other_ctor, int other_ctor_arity) { rust_assert (!pats.empty ()); rust_assert (other_ctor.is_covered_by (head ().ctor ())); PatOrWild &hd = head (); auto v = hd.specialize (other_ctor, other_ctor_arity); { std::string s = "["; for (auto &pat : v) s += pat.to_string () + ", "; s += "]"; rust_debug ("specialize %s with %s to %s", hd.to_string ().c_str (), other_ctor.to_string ().c_str (), s.c_str ()); } pop_head (); for (auto &pat : v) pats.push_back (pat); } std::string MatrixRow::to_string () const { std::string s; for (const PatOrWild &pat : pats.get_subpatterns ()) s += pat.to_string () + ", "; return s; } std::vector PlaceInfo::specialize (const Constructor &c) const { switch (c.get_kind ()) { case Constructor::ConstructorKind::WILDCARD: case Constructor::ConstructorKind::INT_RANGE: { return {}; } break; case Constructor::ConstructorKind::STRUCT: case Constructor::ConstructorKind::VARIANT: { rust_assert (ty->get_kind () == TyTy::TypeKind::ADT); TyTy::ADTType *adt = static_cast (ty); switch (adt->get_adt_kind ()) { case TyTy::ADTType::ADTKind::ENUM: case TyTy::ADTType::ADTKind::STRUCT_STRUCT: case TyTy::ADTType::ADTKind::TUPLE_STRUCT: { TyTy::VariantDef *variant = adt->get_variants ().at (c.get_variant_index ()); if (variant->get_variant_type () == TyTy::VariantDef::VariantType::NUM) return {}; std::vector new_place_infos; for (auto &field : variant->get_fields ()) new_place_infos.push_back (field->get_field_type ()); return new_place_infos; } break; case TyTy::ADTType::ADTKind::UNION: { // TODO: support unions rust_unreachable (); } } } break; default: { rust_unreachable (); } break; } rust_unreachable (); } Matrix Matrix::specialize (const Constructor &ctor) const { auto subfields_place_info = place_infos.at (0).specialize (ctor); std::vector new_rows; for (const MatrixRow &row : rows) { PatStack pats = row.get_pats_clone (); const PatOrWild &hd = pats.head (); if (ctor.is_covered_by (hd.ctor ())) { pats.pop_head_constructor (ctor, subfields_place_info.size ()); new_rows.push_back (MatrixRow (pats, row.is_under_guard ())); } } if (place_infos.empty ()) return Matrix (new_rows, {}); // push subfields of the first fields after specialization std::vector new_place_infos = subfields_place_info; // add place infos for the rest of the fields for (size_t i = 1; i < place_infos.size (); i++) new_place_infos.push_back (place_infos.at (i)); return Matrix (new_rows, new_place_infos); } std::string Matrix::to_string () const { std::string s = "[\n"; for (const MatrixRow &row : rows) s += "row: " + row.to_string () + "\n"; s += "](place_infos=["; for (const PlaceInfo &place_info : place_infos) s += place_info.get_type ()->as_string () + ", "; s += "])"; return s; } std::string WitnessPat::to_string () const { switch (ctor.get_kind ()) { case Constructor::ConstructorKind::STRUCT: { TyTy::ADTType *adt = static_cast (ty); TyTy::VariantDef *variant = adt->get_variants ().at (ctor.get_variant_index ()); std::string buf; buf += adt->get_identifier (); buf += " {"; if (!fields.empty ()) buf += " "; for (size_t i = 0; i < fields.size (); i++) { buf += variant->get_fields ().at (i)->get_name () + ": "; buf += fields.at (i).to_string (); if (i < fields.size () - 1) buf += ", "; } if (!fields.empty ()) buf += " "; buf += "}"; return buf; } break; case Constructor::ConstructorKind::VARIANT: { std::string buf; TyTy::ADTType *adt = static_cast (ty); buf += adt->get_identifier (); TyTy::VariantDef *variant = adt->get_variants ().at (ctor.get_variant_index ()); buf += "::" + variant->get_identifier (); switch (variant->get_variant_type ()) { case TyTy::VariantDef::VariantType::NUM: { return buf; } break; case TyTy::VariantDef::VariantType::TUPLE: { buf += "("; for (size_t i = 0; i < fields.size (); i++) { buf += fields.at (i).to_string (); if (i < fields.size () - 1) buf += ", "; } buf += ")"; return buf; } break; case TyTy::VariantDef::VariantType::STRUCT: { buf += " {"; if (!fields.empty ()) buf += " "; for (size_t i = 0; i < fields.size (); i++) { buf += variant->get_fields ().at (i)->get_name () + ": "; buf += fields.at (i).to_string (); if (i < fields.size () - 1) buf += ", "; } if (!fields.empty ()) buf += " "; buf += "}"; } break; default: { rust_unreachable (); } break; } return buf; } break; case Constructor::ConstructorKind::INT_RANGE: { // TODO: implement rust_unreachable (); } break; case Constructor::ConstructorKind::WILDCARD: { return "_"; } break; case Constructor::ConstructorKind::REFERENCE: { // TODO: implement rust_unreachable (); } break; default: { rust_unreachable (); } break; } rust_unreachable (); } void WitnessMatrix::apply_constructor (const Constructor &ctor, const std::set &missings, TyTy::BaseType *ty) { int arity = 0; // TODO: only support struct and variant ctor for now. switch (ctor.get_kind ()) { case Constructor::ConstructorKind::WILDCARD: { arity = 0; } break; case Constructor::ConstructorKind::STRUCT: case Constructor::ConstructorKind::VARIANT: { if (ty->get_kind () == TyTy::TypeKind::ADT) { TyTy::ADTType *adt = static_cast (ty); TyTy::VariantDef *variant = adt->get_variants ().at (ctor.get_variant_index ()); if (variant->get_variant_type () == TyTy::VariantDef::NUM) arity = 0; else arity = variant->get_fields ().size (); } } break; default: { rust_unreachable (); } } std::string buf; for (auto &stack : patstacks) { buf += "["; for (auto &pat : stack) buf += pat.to_string () + ", "; buf += "]\n"; } rust_debug ("witness pats:\n%s", buf.c_str ()); for (auto &stack : patstacks) { std::vector subfield; for (int i = 0; i < arity; i++) { if (stack.empty ()) subfield.push_back (WitnessPat::make_wildcard (ty)); else { subfield.push_back (stack.back ()); stack.pop_back (); } } stack.push_back (WitnessPat (ctor, subfield, ty)); } } void WitnessMatrix::extend (const WitnessMatrix &other) { patstacks.insert (patstacks.end (), other.patstacks.begin (), other.patstacks.end ()); } // forward declarations static DeconstructedPat lower_pattern (Resolver::TypeCheckContext *ctx, HIR::Pattern &pattern, TyTy::BaseType *scrutinee_ty); static DeconstructedPat lower_tuple_pattern (Resolver::TypeCheckContext *ctx, HIR::TupleStructPattern &pattern, TyTy::VariantDef *variant, Constructor &ctor) { int arity = variant->get_fields ().size (); HIR::TupleStructItems &elems = pattern.get_items (); std::vector fields; switch (elems.get_item_type ()) { case HIR::TupleStructItems::ItemType::MULTIPLE: { HIR::TupleStructItemsNoRange &multiple = static_cast (elems); rust_assert (variant->get_fields ().size () == multiple.get_patterns ().size ()); for (size_t i = 0; i < multiple.get_patterns ().size (); i++) { fields.push_back ( lower_pattern (ctx, *multiple.get_patterns ().at (i), variant->get_fields ().at (i)->get_field_type ())); } return DeconstructedPat (ctor, arity, fields, pattern.get_locus ()); } break; case HIR::TupleStructItems::ItemType::RANGED: { // TODO: ranged tuple struct items rust_unreachable (); } break; default: { rust_unreachable (); } } } static DeconstructedPat lower_struct_pattern (Resolver::TypeCheckContext *ctx, HIR::StructPattern &pattern, TyTy::VariantDef *variant, Constructor ctor) { int arity = variant->get_fields ().size (); // Initialize all field patterns to wildcard. std::vector fields = std::vector (arity, DeconstructedPat::make_wildcard ( pattern.get_locus ())); std::map field_map; for (int i = 0; i < arity; i++) { auto &f = variant->get_fields ().at (i); field_map[f->get_name ()] = i; } // Fill in the fields with the present patterns. HIR::StructPatternElements elems = pattern.get_struct_pattern_elems (); for (auto &elem : elems.get_struct_pattern_fields ()) { switch (elem->get_item_type ()) { case HIR::StructPatternField::ItemType::IDENT: { HIR::StructPatternFieldIdent *ident = static_cast (elem.get ()); int field_idx = field_map.at (ident->get_identifier ().as_string ()); fields.at (field_idx) = DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::StructPatternField::ItemType::IDENT_PAT: { HIR::StructPatternFieldIdentPat *ident_pat = static_cast (elem.get ()); int field_idx = field_map.at (ident_pat->get_identifier ().as_string ()); fields.at (field_idx) = lower_pattern ( ctx, ident_pat->get_pattern (), variant->get_fields ().at (field_idx)->get_field_type ()); } break; case HIR::StructPatternField::ItemType::TUPLE_PAT: { // TODO: tuple: pat rust_unreachable (); } break; default: { rust_unreachable (); } } } return DeconstructedPat{ctor, arity, fields, pattern.get_locus ()}; }; static DeconstructedPat lower_pattern (Resolver::TypeCheckContext *ctx, HIR::Pattern &pattern, TyTy::BaseType *scrutinee_ty) { HIR::Pattern::PatternType pat_type = pattern.get_pattern_type (); switch (pat_type) { case HIR::Pattern::PatternType::WILDCARD: case HIR::Pattern::PatternType::IDENTIFIER: { return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::PATH: { // TODO: support constants, associated constants, enum variants and // structs // https://doc.rust-lang.org/reference/patterns.html#path-patterns // unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::REFERENCE: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::STRUCT: case HIR::Pattern::PatternType::TUPLE_STRUCT: { HirId path_id = UNKNOWN_HIRID; if (pat_type == HIR::Pattern::PatternType::STRUCT) { HIR::StructPattern &struct_pattern = static_cast (pattern); path_id = struct_pattern.get_path ().get_mappings ().get_hirid (); } else { HIR::TupleStructPattern &tuple_pattern = static_cast (pattern); path_id = tuple_pattern.get_path ().get_mappings ().get_hirid (); } rust_assert (scrutinee_ty->get_kind () == TyTy::TypeKind::ADT); TyTy::ADTType *adt = static_cast (scrutinee_ty); Constructor ctor = Constructor::make_struct (); TyTy::VariantDef *variant; if (adt->is_struct_struct () || adt->is_tuple_struct ()) variant = adt->get_variants ().at (0); else if (adt->is_enum ()) { HirId variant_id = UNKNOWN_HIRID; bool ok = ctx->lookup_variant_definition (path_id, &variant_id); rust_assert (ok); int variant_idx; ok = adt->lookup_variant_by_id (variant_id, &variant, &variant_idx); rust_assert (ok); ctor = Constructor::make_variant (variant_idx); } else { rust_unreachable (); } rust_assert (variant->get_variant_type () == TyTy::VariantDef::VariantType::TUPLE || variant->get_variant_type () == TyTy::VariantDef::VariantType::STRUCT); if (pat_type == HIR::Pattern::PatternType::STRUCT) { HIR::StructPattern &struct_pattern = static_cast (pattern); return lower_struct_pattern (ctx, struct_pattern, variant, ctor); } else { HIR::TupleStructPattern &tuple_pattern = static_cast (pattern); return lower_tuple_pattern (ctx, tuple_pattern, variant, ctor); } } break; case HIR::Pattern::PatternType::TUPLE: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::SLICE: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::ALT: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::LITERAL: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::RANGE: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; case HIR::Pattern::PatternType::GROUPED: { // TODO: unimplemented. Treat this pattern as wildcard for now. return DeconstructedPat::make_wildcard (pattern.get_locus ()); } break; default: { rust_unreachable (); } } } static MatchArm lower_arm (Resolver::TypeCheckContext *ctx, HIR::MatchCase &arm, TyTy::BaseType *scrutinee_ty) { rust_assert (arm.get_arm ().get_patterns ().size () > 0); DeconstructedPat pat = lower_pattern (ctx, *arm.get_arm ().get_patterns ().at (0), scrutinee_ty); return MatchArm (pat, arm.get_arm ().has_match_arm_guard ()); } std::pair, std::set> split_constructors (std::vector &ctors, PlaceInfo &place_info) { bool all_wildcard = true; for (auto &ctor : ctors) { if (!ctor.is_wildcard ()) all_wildcard = false; } // first pass for the case that all patterns are wildcard if (all_wildcard) return std::make_pair (std::set ( {Constructor::make_wildcard ()}), std::set ()); // TODO: only support enums and structs for now. TyTy::BaseType *ty = place_info.get_type (); rust_assert (ty->get_kind () == TyTy::TypeKind::ADT); TyTy::ADTType *adt = static_cast (ty); rust_assert (adt->is_enum () || adt->is_struct_struct () || adt->is_tuple_struct ()); std::set universe; if (adt->is_enum ()) { for (size_t i = 0; i < adt->get_variants ().size (); i++) universe.insert (Constructor::make_variant (i)); } else if (adt->is_struct_struct () || adt->is_tuple_struct ()) { universe.insert (Constructor::make_struct ()); } std::set present; for (auto &ctor : ctors) { if (ctor.is_wildcard ()) return std::make_pair (universe, std::set ()); else present.insert (ctor); } std::set missing; std::set_difference (universe.begin (), universe.end (), present.begin (), present.end (), std::inserter (missing, missing.end ())); return std::make_pair (universe, missing); } // The core of the algorithm. It computes the usefulness and exhaustiveness of a // given matrix recursively. // TODO: calculate usefulness static WitnessMatrix compute_exhaustiveness_and_usefulness (Resolver::TypeCheckContext *ctx, Matrix &matrix) { rust_debug ("call compute_exhaustiveness_and_usefulness"); rust_debug ("matrix: %s", matrix.to_string ().c_str ()); if (matrix.get_rows ().empty ()) { // no rows left. This means a non-exhaustive pattern. rust_debug ("non-exhaustive subpattern found"); return WitnessMatrix::make_unit (); } // Base case: there are no columns in matrix. if (matrix.get_place_infos ().empty ()) return WitnessMatrix::make_empty (); std::vector heads; for (auto head : matrix.heads ()) heads.push_back (head.ctor ()); // TODO: not sure missing ctors need to be calculated auto ctors_and_missings = split_constructors (heads, matrix.get_place_infos ().at (0)); std::set ctors = ctors_and_missings.first; std::set missings = ctors_and_missings.second; WitnessMatrix ret = WitnessMatrix::make_empty (); for (auto &ctor : ctors) { rust_debug ("specialize with %s", ctor.to_string ().c_str ()); // TODO: Instead of creating new matrix, we can change the original matrix // and use it for sub-pattern matching. It will significantly reduce // memory usage. Matrix spec_matrix = matrix.specialize (ctor); WitnessMatrix witness = compute_exhaustiveness_and_usefulness (ctx, spec_matrix); TyTy::BaseType *ty = matrix.get_place_infos ().at (0).get_type (); witness.apply_constructor (ctor, missings, ty); ret.extend (witness); } return ret; } static void emit_exhaustiveness_error (Resolver::TypeCheckContext *ctx, HIR::MatchExpr &expr, WitnessMatrix &witness) { TyTy::BaseType *scrutinee_ty; bool ok = ctx->lookup_type (expr.get_scrutinee_expr ().get_mappings ().get_hirid (), &scrutinee_ty); rust_assert (ok); if (!witness.empty ()) { std::stringstream buf; for (size_t i = 0; i < witness.get_stacks ().size (); i++) { auto &stack = witness.get_stacks ().at (i); WitnessPat w = WitnessPat::make_wildcard (scrutinee_ty); if (!stack.empty ()) w = stack.at (0); rust_debug ("Witness[%d]: %s", (int) i, w.to_string ().c_str ()); buf << "'" << w.to_string () << "'"; if (i != witness.get_stacks ().size () - 1) buf << " and "; } rust_error_at (expr.get_scrutinee_expr ().get_locus (), "non-exhaustive patterns: %s not covered", buf.str ().c_str ()); } else { rust_debug ("no witness found"); } } // Entry point for computing match usefulness and check exhaustiveness void check_match_usefulness (Resolver::TypeCheckContext *ctx, TyTy::BaseType *scrutinee_ty, HIR::MatchExpr &expr) { if (!expr.has_match_arms ()) return; // Lower the arms to a more convenient representation. std::vector rows; for (auto &arm : expr.get_match_cases ()) { PatStack pats; MatchArm lowered = lower_arm (ctx, arm, scrutinee_ty); PatOrWild pat = PatOrWild::make_pattern (lowered.get_pat ()); pats.push (pat); rows.push_back (MatrixRow (pats, lowered.has_guard ())); } std::vector place_infos = {{PlaceInfo (scrutinee_ty)}}; Matrix matrix{rows, place_infos}; WitnessMatrix witness = compute_exhaustiveness_and_usefulness (ctx, matrix); emit_exhaustiveness_error (ctx, expr, witness); } } // namespace Analysis } // namespace Rust