diff options
author | Philip Herron <philip.herron@embecosm.com> | 2021-10-29 17:38:29 +0100 |
---|---|---|
committer | Philip Herron <philip.herron@embecosm.com> | 2021-11-01 13:12:47 +0000 |
commit | faa1a005e92237a0188311b48455be88126e3e68 (patch) | |
tree | 40a10c0a05c8b3a10b8d52844cddca31bcf5069d /gcc | |
parent | 82f3dd40d2dcc9eb2ea261e42bf7bb365faaf0a0 (diff) | |
download | gcc-faa1a005e92237a0188311b48455be88126e3e68.zip gcc-faa1a005e92237a0188311b48455be88126e3e68.tar.gz gcc-faa1a005e92237a0188311b48455be88126e3e68.tar.bz2 |
Refactor ADTType to consist of multiple variants
Algebraic data types represent Structs, Tuple Structs, unit
structs and enums in rust. The key difference here is that
each of these are an ADT with a single variant and enums
are an ADT with multiple variants.
It adds indirection to where the fields of an ADT are
managed.
Co-authored-by: Mark Wielaard <mark@klomp.org>
Addresses #79
Diffstat (limited to 'gcc')
-rw-r--r-- | gcc/rust/backend/rust-compile-context.h | 9 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile-expr.h | 18 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile.cc | 6 | ||||
-rw-r--r-- | gcc/rust/hir/tree/rust-hir-item.h | 9 | ||||
-rw-r--r-- | gcc/rust/lint/rust-lint-marklive.cc | 14 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-expr.h | 21 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-stmt.h | 50 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-toplevel.h | 50 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check.cc | 42 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-cast.h | 36 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-cmp.h | 36 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-coercion.h | 36 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-rules.h | 36 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.cc | 74 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.h | 199 |
15 files changed, 452 insertions, 184 deletions
diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h index 2b2018f..551e041 100644 --- a/gcc/rust/backend/rust-compile-context.h +++ b/gcc/rust/backend/rust-compile-context.h @@ -431,10 +431,15 @@ public: if (ctx->lookup_compiled_types (type.get_ty_ref (), &translated, &type)) return; + // we dont support enums yet + rust_assert (!type.is_enum ()); + rust_assert (type.number_of_variants () == 1); + + TyTy::VariantDef &variant = *type.get_variants ().at (0); std::vector<Backend::Btyped_identifier> fields; - for (size_t i = 0; i < type.num_fields (); i++) + for (size_t i = 0; i < variant.num_fields (); i++) { - const TyTy::StructFieldType *field = type.get_field (i); + const TyTy::StructFieldType *field = variant.get_field_at_index (i); Btype *compiled_field_ty = TyTyResolveCompile::compile (ctx, field->get_field_type ()); diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h index 0512373..f43db50 100644 --- a/gcc/rust/backend/rust-compile-expr.h +++ b/gcc/rust/backend/rust-compile-expr.h @@ -698,7 +698,13 @@ public: if (receiver->get_kind () == TyTy::TypeKind::ADT) { TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (receiver); - adt->get_field (expr.get_field_name (), &field_index); + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + + TyTy::VariantDef *variant = adt->get_variants ().at (0); + bool ok = variant->lookup_field (expr.get_field_name (), nullptr, + &field_index); + rust_assert (ok); } else if (receiver->get_kind () == TyTy::TypeKind::REF) { @@ -707,9 +713,15 @@ public: rust_assert (b->get_kind () == TyTy::TypeKind::ADT); TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (b); - adt->get_field (expr.get_field_name (), &field_index); - Btype *adt_tyty = TyTyResolveCompile::compile (ctx, adt); + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + TyTy::VariantDef *variant = adt->get_variants ().at (0); + bool ok = variant->lookup_field (expr.get_field_name (), nullptr, + &field_index); + rust_assert (ok); + + Btype *adt_tyty = TyTyResolveCompile::compile (ctx, adt); Bexpression *indirect = ctx->get_backend ()->indirect_expression (adt_tyty, receiver_ref, true, expr.get_locus ()); diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc index c8254c1..a5d32b1 100644 --- a/gcc/rust/backend/rust-compile.cc +++ b/gcc/rust/backend/rust-compile.cc @@ -73,6 +73,10 @@ CompileExpr::visit (HIR::CallExpr &expr) TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (tyty); Btype *compiled_adt_type = TyTyResolveCompile::compile (ctx, tyty); + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + auto variant = adt->get_variants ().at (0); + // this assumes all fields are in order from type resolution and if a // base struct was specified those fields are filed via accesors std::vector<Bexpression *> vals; @@ -83,7 +87,7 @@ CompileExpr::visit (HIR::CallExpr &expr) // assignments are coercion sites so lets convert the rvalue if // necessary - auto respective_field = adt->get_field (i); + auto respective_field = variant->get_field_at_index (i); auto expected = respective_field->get_field_type (); TyTy::BaseType *actual = nullptr; diff --git a/gcc/rust/hir/tree/rust-hir-item.h b/gcc/rust/hir/tree/rust-hir-item.h index 12d0c20..c5a8d06 100644 --- a/gcc/rust/hir/tree/rust-hir-item.h +++ b/gcc/rust/hir/tree/rust-hir-item.h @@ -2005,14 +2005,7 @@ public: void accept_vis (HIRVisitor &vis) override; - void iterate (std::function<bool (StructField &)> cb) - { - for (auto &variant : variants) - { - if (!cb (variant)) - return; - } - } + std::vector<StructField> &get_variants () { return variants; } WhereClause &get_where_clause () { return where_clause; } diff --git a/gcc/rust/lint/rust-lint-marklive.cc b/gcc/rust/lint/rust-lint-marklive.cc index 4b095ab4..ef207bc 100644 --- a/gcc/rust/lint/rust-lint-marklive.cc +++ b/gcc/rust/lint/rust-lint-marklive.cc @@ -228,11 +228,17 @@ MarkLive::visit (HIR::FieldAccessExpr &expr) } rust_assert (adt != nullptr); + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + + TyTy::VariantDef *variant = adt->get_variants ().at (0); // get the field index - size_t index = 0; - adt->get_field (expr.get_field_name (), &index); - if (index >= adt->num_fields ()) + size_t index; + TyTy::StructFieldType *field; + bool ok = variant->lookup_field (expr.get_field_name (), &field, &index); + rust_assert (ok); + if (index >= variant->num_fields ()) { rust_error_at (expr.get_receiver_expr ()->get_locus (), "cannot access struct %s by index: %ld", @@ -241,7 +247,7 @@ MarkLive::visit (HIR::FieldAccessExpr &expr) } // get the field hir id - HirId field_id = adt->get_field (index)->get_ref (); + HirId field_id = field->get_ref (); mark_hir_id (field_id); } diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index e71d1e9..851407e 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -113,14 +113,18 @@ public: } TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (resolved); + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + + TyTy::VariantDef *variant = adt->get_variants ().at (0); TupleIndex index = expr.get_tuple_index (); - if ((size_t) index >= adt->num_fields ()) + if ((size_t) index >= variant->num_fields ()) { rust_error_at (expr.get_locus (), "unknown field at index %i", index); return; } - auto field_tyty = adt->get_field ((size_t) index); + auto field_tyty = variant->get_field_at_index ((size_t) index); if (field_tyty == nullptr) { rust_error_at (expr.get_locus (), @@ -984,8 +988,15 @@ public: } TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (struct_base); - auto resolved = adt->get_field (expr.get_field_name ()); - if (resolved == nullptr) + rust_assert (!adt->is_enum ()); + rust_assert (adt->number_of_variants () == 1); + + TyTy::VariantDef *vaiant = adt->get_variants ().at (0); + + TyTy::StructFieldType *lookup = nullptr; + bool found + = vaiant->lookup_field (expr.get_field_name (), &lookup, nullptr); + if (!found) { rust_error_at (expr.get_locus (), "unknown field [%s] for type [%s]", expr.get_field_name ().c_str (), @@ -993,7 +1004,7 @@ public: return; } - infered = resolved->get_field_type (); + infered = lookup->get_field_type (); } void visit (HIR::QualifiedPathInExpression &expr) override; diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h index 74bc037..2195968 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h +++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h @@ -155,7 +155,6 @@ public: } std::vector<TyTy::StructFieldType *> fields; - size_t idx = 0; for (auto &field : struct_decl.get_fields ()) { @@ -170,12 +169,19 @@ public: idx++; } + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (struct_decl.get_identifier (), + TyTy::VariantDef::VariantType::TUPLE, + std::move (fields))); + TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), struct_decl.get_identifier (), TyTy::ADTType::ADTKind::TUPLE_STRUCT, - std::move (fields), std::move (substitutions)); + std::move (variants), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); infered = type; @@ -222,12 +228,19 @@ public: ty_field->get_field_type ()); } + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (struct_decl.get_identifier (), + TyTy::VariantDef::VariantType::STRUCT, + std::move (fields))); + TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), struct_decl.get_identifier (), TyTy::ADTType::ADTKind::STRUCT_STRUCT, - std::move (fields), std::move (substitutions)); + std::move (variants), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); infered = type; @@ -261,18 +274,25 @@ public: } } - std::vector<TyTy::StructFieldType *> variants; - union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { - TyTy::BaseType *variant_type - = TypeCheckType::Resolve (variant.get_field_type ().get ()); - TyTy::StructFieldType *ty_variant - = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), - variant.get_field_name (), variant_type); - variants.push_back (ty_variant); - context->insert_type (variant.get_mappings (), - ty_variant->get_field_type ()); - return true; - }); + std::vector<TyTy::StructFieldType *> fields; + for (auto &variant : union_decl.get_variants ()) + { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + fields.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + } + + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (union_decl.get_identifier (), + TyTy::VariantDef::VariantType::STRUCT, + std::move (fields))); TyTy::BaseType *type = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), diff --git a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h index a32d4a4..4dae953 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h +++ b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h @@ -87,7 +87,6 @@ public: } std::vector<TyTy::StructFieldType *> fields; - size_t idx = 0; for (auto &field : struct_decl.get_fields ()) { @@ -102,12 +101,19 @@ public: idx++; } + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (struct_decl.get_identifier (), + TyTy::VariantDef::VariantType::TUPLE, + std::move (fields))); + TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), struct_decl.get_identifier (), TyTy::ADTType::ADTKind::TUPLE_STRUCT, - std::move (fields), std::move (substitutions)); + std::move (variants), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); } @@ -165,12 +171,19 @@ public: ty_field->get_field_type ()); } + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (struct_decl.get_identifier (), + TyTy::VariantDef::VariantType::STRUCT, + std::move (fields))); + TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), struct_decl.get_identifier (), TyTy::ADTType::ADTKind::STRUCT_STRUCT, - std::move (fields), std::move (substitutions)); + std::move (variants), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); } @@ -208,18 +221,25 @@ public: ResolveWhereClauseItem::Resolve (*where_clause_item.get ()); } - std::vector<TyTy::StructFieldType *> variants; - union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { - TyTy::BaseType *variant_type - = TypeCheckType::Resolve (variant.get_field_type ().get ()); - TyTy::StructFieldType *ty_variant - = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), - variant.get_field_name (), variant_type); - variants.push_back (ty_variant); - context->insert_type (variant.get_mappings (), - ty_variant->get_field_type ()); - return true; - }); + std::vector<TyTy::StructFieldType *> fields; + for (auto &variant : union_decl.get_variants ()) + { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + fields.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + } + + // there is only a single variant + std::vector<TyTy::VariantDef *> variants; + variants.push_back ( + new TyTy::VariantDef (union_decl.get_identifier (), + TyTy::VariantDef::VariantType::STRUCT, + std::move (fields))); TyTy::BaseType *type = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), diff --git a/gcc/rust/typecheck/rust-hir-type-check.cc b/gcc/rust/typecheck/rust-hir-type-check.cc index bd06473..339429f 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.cc +++ b/gcc/rust/typecheck/rust-hir-type-check.cc @@ -178,7 +178,11 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) } // check the arguments are all assigned and fix up the ordering - if (fields_assigned.size () != struct_path_resolved->num_fields ()) + rust_assert (!struct_path_resolved->is_enum ()); + rust_assert (struct_path_resolved->number_of_variants () == 1); + TyTy::VariantDef *variant = struct_path_resolved->get_variants ().at (0); + + if (fields_assigned.size () != variant->num_fields ()) { if (struct_def->is_union ()) { @@ -201,7 +205,7 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) // we have a struct base to assign the missing fields from. // the missing fields can be implicit FieldAccessExprs for the value std::set<std::string> missing_fields; - for (auto &field : struct_path_resolved->get_fields ()) + for (auto &field : variant->get_fields ()) { auto it = fields_assigned.find (field->get_name ()); if (it == fields_assigned.end ()) @@ -235,7 +239,7 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) struct_expr.struct_base->base_struct->get_locus ()); size_t field_index; - bool ok = struct_path_resolved->get_field (missing, &field_index); + bool ok = variant->lookup_field (missing, nullptr, &field_index); rust_assert (ok); adtFieldIndexToField[field_index] = implicit_field; @@ -291,10 +295,14 @@ TypeCheckStructExpr::visit (HIR::StructExprFieldIdentifierValue &field) return; } + rust_assert (!struct_path_resolved->is_enum ()); + rust_assert (struct_path_resolved->number_of_variants () == 1); + TyTy::VariantDef *variant = struct_path_resolved->get_variants ().at (0); + size_t field_index; - TyTy::StructFieldType *field_type - = struct_path_resolved->get_field (field.field_name, &field_index); - if (field_type == nullptr) + TyTy::StructFieldType *field_type; + bool ok = variant->lookup_field (field.field_name, &field_type, &field_index); + if (!ok) { rust_error_at (field.get_locus (), "unknown field"); return; @@ -320,11 +328,14 @@ TypeCheckStructExpr::visit (HIR::StructExprFieldIndexValue &field) return; } - size_t field_index; + rust_assert (!struct_path_resolved->is_enum ()); + rust_assert (struct_path_resolved->number_of_variants () == 1); + TyTy::VariantDef *variant = struct_path_resolved->get_variants ().at (0); - TyTy::StructFieldType *field_type - = struct_path_resolved->get_field (field_name, &field_index); - if (field_type == nullptr) + size_t field_index; + TyTy::StructFieldType *field_type; + bool ok = variant->lookup_field (field_name, &field_type, &field_index); + if (!ok) { rust_error_at (field.get_locus (), "unknown field"); return; @@ -349,10 +360,15 @@ TypeCheckStructExpr::visit (HIR::StructExprFieldIdentifier &field) return; } + rust_assert (!struct_path_resolved->is_enum ()); + rust_assert (struct_path_resolved->number_of_variants () == 1); + TyTy::VariantDef *variant = struct_path_resolved->get_variants ().at (0); + size_t field_index; - TyTy::StructFieldType *field_type - = struct_path_resolved->get_field (field.get_field_name (), &field_index); - if (field_type == nullptr) + TyTy::StructFieldType *field_type; + bool ok = variant->lookup_field (field.get_field_name (), &field_type, + &field_index); + if (!ok) { rust_error_at (field.get_locus (), "unknown field"); return; diff --git a/gcc/rust/typecheck/rust-tyty-cast.h b/gcc/rust/typecheck/rust-tyty-cast.h index 874fe2a..aaa589b 100644 --- a/gcc/rust/typecheck/rust-tyty-cast.h +++ b/gcc/rust/typecheck/rust-tyty-cast.h @@ -993,29 +993,47 @@ public: void visit (ADTType &type) override { + if (base->get_adt_kind () != type.get_adt_kind ()) + { + BaseCastRules::visit (type); + return; + } + if (base->get_identifier ().compare (type.get_identifier ()) != 0) { BaseCastRules::visit (type); return; } - if (base->num_fields () != type.num_fields ()) + if (base->number_of_variants () != type.number_of_variants ()) { BaseCastRules::visit (type); return; } - for (size_t i = 0; i < type.num_fields (); ++i) + for (size_t i = 0; i < type.number_of_variants (); ++i) { - TyTy::StructFieldType *base_field = base->get_field (i); - TyTy::StructFieldType *other_field = type.get_field (i); + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); - TyTy::BaseType *this_field_ty = base_field->get_field_type (); - TyTy::BaseType *other_field_ty = other_field->get_field_type (); + if (a->num_fields () != b->num_fields ()) + { + BaseCastRules::visit (type); + return; + } - BaseType *unified_ty = this_field_ty->unify (other_field_ty); - if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) - return; + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (i); + TyTy::StructFieldType *other_field = b->get_field_at_index (i); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + BaseType *unified_ty = this_field_ty->unify (other_field_ty); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + } } resolved = type.clone (); diff --git a/gcc/rust/typecheck/rust-tyty-cmp.h b/gcc/rust/typecheck/rust-tyty-cmp.h index 68b04fe..4ab3df2 100644 --- a/gcc/rust/typecheck/rust-tyty-cmp.h +++ b/gcc/rust/typecheck/rust-tyty-cmp.h @@ -990,32 +990,50 @@ public: void visit (const ADTType &type) override { - if (base->get_identifier ().compare (type.get_identifier ()) != 0) + if (base->get_adt_kind () != type.get_adt_kind ()) { BaseCmp::visit (type); return; } - if (base->num_fields () != type.num_fields ()) + if (base->get_identifier ().compare (type.get_identifier ()) != 0) { BaseCmp::visit (type); return; } - for (size_t i = 0; i < type.num_fields (); ++i) + if (base->number_of_variants () != type.number_of_variants ()) { - const TyTy::StructFieldType *base_field = base->get_imm_field (i); - const TyTy::StructFieldType *other_field = type.get_imm_field (i); + BaseCmp::visit (type); + return; + } - TyTy::BaseType *this_field_ty = base_field->get_field_type (); - TyTy::BaseType *other_field_ty = other_field->get_field_type (); + for (size_t i = 0; i < type.number_of_variants (); ++i) + { + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); - if (!this_field_ty->can_eq (other_field_ty, emit_error_flag, - autoderef_mode_flag)) + if (a->num_fields () != b->num_fields ()) { BaseCmp::visit (type); return; } + + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (i); + TyTy::StructFieldType *other_field = b->get_field_at_index (i); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + if (!this_field_ty->can_eq (other_field_ty, emit_error_flag, + autoderef_mode_flag)) + { + BaseCmp::visit (type); + return; + } + } } ok = true; diff --git a/gcc/rust/typecheck/rust-tyty-coercion.h b/gcc/rust/typecheck/rust-tyty-coercion.h index fc34ac8..6525da3 100644 --- a/gcc/rust/typecheck/rust-tyty-coercion.h +++ b/gcc/rust/typecheck/rust-tyty-coercion.h @@ -1000,29 +1000,47 @@ public: void visit (ADTType &type) override { + if (base->get_adt_kind () != type.get_adt_kind ()) + { + BaseCoercionRules::visit (type); + return; + } + if (base->get_identifier ().compare (type.get_identifier ()) != 0) { BaseCoercionRules::visit (type); return; } - if (base->num_fields () != type.num_fields ()) + if (base->number_of_variants () != type.number_of_variants ()) { BaseCoercionRules::visit (type); return; } - for (size_t i = 0; i < type.num_fields (); ++i) + for (size_t i = 0; i < type.number_of_variants (); ++i) { - TyTy::StructFieldType *base_field = base->get_field (i); - TyTy::StructFieldType *other_field = type.get_field (i); + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); - TyTy::BaseType *this_field_ty = base_field->get_field_type (); - TyTy::BaseType *other_field_ty = other_field->get_field_type (); + if (a->num_fields () != b->num_fields ()) + { + BaseCoercionRules::visit (type); + return; + } - BaseType *unified_ty = this_field_ty->unify (other_field_ty); - if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) - return; + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (i); + TyTy::StructFieldType *other_field = b->get_field_at_index (i); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + BaseType *unified_ty = this_field_ty->unify (other_field_ty); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + } } resolved = type.clone (); diff --git a/gcc/rust/typecheck/rust-tyty-rules.h b/gcc/rust/typecheck/rust-tyty-rules.h index b7b845a..db86de9 100644 --- a/gcc/rust/typecheck/rust-tyty-rules.h +++ b/gcc/rust/typecheck/rust-tyty-rules.h @@ -1016,29 +1016,47 @@ public: void visit (ADTType &type) override { + if (base->get_adt_kind () != type.get_adt_kind ()) + { + BaseRules::visit (type); + return; + } + if (base->get_identifier ().compare (type.get_identifier ()) != 0) { BaseRules::visit (type); return; } - if (base->num_fields () != type.num_fields ()) + if (base->number_of_variants () != type.number_of_variants ()) { BaseRules::visit (type); return; } - for (size_t i = 0; i < type.num_fields (); ++i) + for (size_t i = 0; i < type.number_of_variants (); ++i) { - TyTy::StructFieldType *base_field = base->get_field (i); - TyTy::StructFieldType *other_field = type.get_field (i); + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); - TyTy::BaseType *this_field_ty = base_field->get_field_type (); - TyTy::BaseType *other_field_ty = other_field->get_field_type (); + if (a->num_fields () != b->num_fields ()) + { + BaseRules::visit (type); + return; + } - BaseType *unified_ty = this_field_ty->unify (other_field_ty); - if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) - return; + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (i); + TyTy::StructFieldType *other_field = b->get_field_at_index (i); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + BaseType *unified_ty = this_field_ty->unify (other_field_ty); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + } } resolved = type.clone (); diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index d1db835..9f345c3 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -590,35 +590,16 @@ ADTType::accept_vis (TyConstVisitor &vis) const std::string ADTType::as_string () const { - if (num_fields () == 0) - return identifier; - - std::string fields_buffer; - for (size_t i = 0; i < num_fields (); ++i) + std::string variants_buffer; + for (size_t i = 0; i < number_of_variants (); ++i) { - fields_buffer += get_field (i)->as_string (); - if ((i + 1) < num_fields ()) - fields_buffer += ", "; + TyTy::VariantDef *variant = variants.at (i); + variants_buffer += variant->as_string (); + if ((i + 1) < number_of_variants ()) + variants_buffer += ", "; } - return identifier + subst_as_string () + "{" + fields_buffer + "}"; -} - -const StructFieldType * -ADTType::get_field (size_t index) const -{ - return fields.at (index); -} - -const BaseType * -ADTType::get_field_type (size_t index) const -{ - const StructFieldType *ref = get_field (index); - auto context = Resolver::TypeCheckContext::get (); - BaseType *lookup = nullptr; - bool ok = context->lookup_type (ref->get_field_type ()->get_ref (), &lookup); - rust_assert (ok); - return lookup; + return identifier + subst_as_string () + "{" + variants_buffer + "}"; } BaseType * @@ -657,7 +638,10 @@ ADTType::is_equal (const BaseType &other) const return false; auto other2 = static_cast<const ADTType &> (other); - if (num_fields () != other2.num_fields ()) + if (get_adt_kind () != other2.get_adt_kind ()) + return false; + + if (number_of_variants () != other2.number_of_variants ()) return false; if (has_subsititions_defined () != other2.has_subsititions_defined ()) @@ -683,9 +667,12 @@ ADTType::is_equal (const BaseType &other) const } else { - for (size_t i = 0; i < num_fields (); i++) + for (size_t i = 0; i < number_of_variants (); i++) { - if (!get_field (i)->is_equal (*other2.get_field (i))) + const TyTy::VariantDef *a = get_variants ().at (i); + const TyTy::VariantDef *b = other2.get_variants ().at (i); + + if (!a->is_equal (*b)) return false; } } @@ -696,12 +683,12 @@ ADTType::is_equal (const BaseType &other) const BaseType * ADTType::clone () const { - std::vector<StructFieldType *> cloned_fields; - for (auto &f : fields) - cloned_fields.push_back ((StructFieldType *) f->clone ()); + std::vector<VariantDef *> cloned_variants; + for (auto &variant : variants) + cloned_variants.push_back (variant->clone ()); return new ADTType (get_ref (), get_ty_ref (), identifier, get_adt_kind (), - cloned_fields, clone_substs (), used_arguments, + cloned_variants, clone_substs (), used_arguments, get_combined_refs ()); } @@ -772,11 +759,14 @@ ADTType::handle_substitions (SubstitutionArgumentMappings subst_mappings) sub.fill_param_ty (*arg.get_tyty (), subst_mappings.get_locus ()); } - for (auto &field : adt->fields) + for (auto &variant : adt->get_variants ()) { - bool ok = ::Rust::TyTy::handle_substitions (subst_mappings, field); - if (!ok) - return adt; + for (auto &field : variant->get_fields ()) + { + bool ok = ::Rust::TyTy::handle_substitions (subst_mappings, field); + if (!ok) + return adt; + } } return adt; @@ -2556,18 +2546,22 @@ TypeCheckCallExpr::visit (ADTType &type) return; } - if (call.num_params () != type.num_fields ()) + rust_assert (!type.is_enum ()); + rust_assert (type.number_of_variants () == 1); + TyTy::VariantDef *variant = type.get_variants ().at (0); + + if (call.num_params () != variant->num_fields ()) { rust_error_at (call.get_locus (), "unexpected number of arguments %lu expected %lu", - call.num_params (), type.num_fields ()); + call.num_params (), variant->num_fields ()); return; } size_t i = 0; for (auto &argument : call.get_arguments ()) { - StructFieldType *field = type.get_field (i); + StructFieldType *field = variant->get_field_at_index (i); BaseType *field_tyty = field->get_field_type (); BaseType *arg = Resolver::TypeCheckExpr::Resolve (argument.get (), false); diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index 3077281..909f210 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -1009,6 +1009,139 @@ protected: SubstitutionArgumentMappings used_arguments; }; +// https://doc.rust-lang.org/nightly/nightly-rustc/rustc_middle/ty/struct.VariantDef.html +class VariantDef +{ +public: + enum VariantType + { + NUM, + TUPLE, + STRUCT + }; + + VariantDef (std::string identifier, int discriminant) + : identifier (identifier), discriminant (discriminant) + { + type = VariantType::NUM; + fields = {}; + } + + VariantDef (std::string identifier, VariantType type, + std::vector<StructFieldType *> fields) + : identifier (identifier), type (type), fields (fields) + { + discriminant = 0; + rust_assert (type == VariantType::TUPLE || type == VariantType::STRUCT); + } + + VariantDef (std::string identifier, VariantType type, int discriminant, + std::vector<StructFieldType *> fields) + : identifier (identifier), type (type), discriminant (discriminant), + fields (fields) + { + rust_assert ((type == VariantType::NUM && fields.empty ()) + || (type == VariantType::TUPLE && discriminant == 0) + || (type == VariantType::STRUCT && discriminant == 0)); + } + + VariantType get_variant_type () const { return type; } + + std::string get_identifier () const { return identifier; } + int get_discriminant () const { return discriminant; } + + size_t num_fields () const { return fields.size (); } + StructFieldType *get_field_at_index (size_t index) + { + // FIXME this is not safe + return fields.at (index); + } + + std::vector<StructFieldType *> &get_fields () + { + rust_assert (type != NUM); + return fields; + } + + bool lookup_field (const std::string &lookup, StructFieldType **field_lookup, + size_t *index) const + { + size_t i = 0; + for (auto &field : fields) + { + if (field->get_name ().compare (lookup) == 0) + { + if (index != nullptr) + *index = i; + + if (field_lookup != nullptr) + *field_lookup = field; + + return true; + } + i++; + } + return false; + } + + std::string as_string () const + { + if (type == VariantType::NUM) + return identifier + " = " + std::to_string (discriminant); + + std::string buffer; + for (size_t i = 0; i < fields.size (); ++i) + { + buffer += fields.at (i)->as_string (); + if ((i + 1) < fields.size ()) + buffer += ", "; + } + + if (type == VariantType::TUPLE) + return identifier + " (" + buffer + ")"; + else + return identifier + " {" + buffer + "}"; + } + + bool is_equal (const VariantDef &other) const + { + if (type != other.type) + return false; + + if (identifier.compare (other.identifier) != 0) + return false; + + if (discriminant != other.discriminant) + return false; + + if (fields.size () != other.fields.size ()) + return false; + + for (size_t i = 0; i < fields.size (); i++) + { + if (!fields.at (i)->is_equal (*other.fields.at (i))) + return false; + } + + return true; + } + + VariantDef *clone () const + { + std::vector<StructFieldType *> cloned_fields; + for (auto &f : fields) + cloned_fields.push_back ((StructFieldType *) f->clone ()); + + return new VariantDef (identifier, type, discriminant, cloned_fields); + } + +private: + std::string identifier; + VariantType type; + int discriminant; /* Either discriminant or fields are valid. */ + std::vector<StructFieldType *> fields; +}; + class ADTType : public BaseType, public SubstitutionRef { public: @@ -1017,36 +1150,48 @@ public: STRUCT_STRUCT, TUPLE_STRUCT, UNION, - // ENUM ? + ENUM }; ADTType (HirId ref, std::string identifier, ADTKind adt_kind, - std::vector<StructFieldType *> fields, + std::vector<VariantDef *> variants, std::vector<SubstitutionParamMapping> subst_refs, SubstitutionArgumentMappings generic_arguments = SubstitutionArgumentMappings::error (), std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), adt_kind (adt_kind) + identifier (identifier), variants (variants), adt_kind (adt_kind) {} ADTType (HirId ref, HirId ty_ref, std::string identifier, ADTKind adt_kind, - std::vector<StructFieldType *> fields, + std::vector<VariantDef *> variants, std::vector<SubstitutionParamMapping> subst_refs, SubstitutionArgumentMappings generic_arguments = SubstitutionArgumentMappings::error (), std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ty_ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), adt_kind (adt_kind) + identifier (identifier), variants (variants), adt_kind (adt_kind) {} ADTKind get_adt_kind () const { return adt_kind; } + + bool is_struct_struct () const { return adt_kind == STRUCT_STRUCT; } bool is_tuple_struct () const { return adt_kind == TUPLE_STRUCT; } bool is_union () const { return adt_kind == UNION; } + bool is_enum () const { return adt_kind == ENUM; } - bool is_unit () const override { return this->fields.empty (); } + bool is_unit () const override + { + if (number_of_variants () == 0) + return true; + + if (number_of_variants () == 1) + return variants.at (0)->num_fields () == 0; + + return false; + } void accept_vis (TyVisitor &vis) override; void accept_vis (TyConstVisitor &vis) const override; @@ -1061,8 +1206,6 @@ public: bool is_equal (const BaseType &other) const override; - size_t num_fields () const { return fields.size (); } - std::string get_identifier () const { return identifier; } std::string get_name () const override final @@ -1070,41 +1213,8 @@ public: return identifier + subst_as_string (); } - BaseType *get_field_type (size_t index); - - const BaseType *get_field_type (size_t index) const; - - const StructFieldType *get_field (size_t index) const; - - StructFieldType *get_field (size_t index) { return fields.at (index); } - - const StructFieldType *get_imm_field (size_t index) const - { - return fields.at (index); - } - - StructFieldType *get_field (const std::string &lookup, - size_t *index = nullptr) const - { - size_t i = 0; - for (auto &field : fields) - { - if (field->get_name ().compare (lookup) == 0) - { - if (index != nullptr) - *index = i; - return field; - } - i++; - } - return nullptr; - } - BaseType *clone () const final override; - std::vector<StructFieldType *> &get_fields () { return fields; } - const std::vector<StructFieldType *> &get_fields () const { return fields; } - bool needs_generic_substitutions () const override final { return needs_substitution (); @@ -1117,12 +1227,17 @@ public: return has_substitutions (); } + size_t number_of_variants () const { return variants.size (); } + + std::vector<VariantDef *> &get_variants () { return variants; } + const std::vector<VariantDef *> &get_variants () const { return variants; } + ADTType * handle_substitions (SubstitutionArgumentMappings mappings) override final; private: std::string identifier; - std::vector<StructFieldType *> fields; + std::vector<VariantDef *> variants; ADTType::ADTKind adt_kind; }; |