diff options
author | bors[bot] <26634292+bors[bot]@users.noreply.github.com> | 2021-08-02 09:13:20 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-02 09:13:20 +0000 |
commit | b56c6fdfaad9ca1681714d288d1282cf08554462 (patch) | |
tree | d815a7d24660b24da5fcf7d05b5f2b3e8de99ae1 /gcc | |
parent | 06a65591eb09fbec25e4ee38c1cf751b416af5bf (diff) | |
parent | 389fd74a3f3e9422a965263b6961b51295c55976 (diff) | |
download | gcc-b56c6fdfaad9ca1681714d288d1282cf08554462.zip gcc-b56c6fdfaad9ca1681714d288d1282cf08554462.tar.gz gcc-b56c6fdfaad9ca1681714d288d1282cf08554462.tar.bz2 |
Merge #601
601: union support for hir type checking and gcc backend r=dkm a=dkm
From Mark Wielaard : https://gcc.gnu.org/pipermail/gcc-rust/2021-August/000107.html
>
> Treat a union as a Struct variant like a tuple struct. Add an
> iterator and get_identifier functions to the AST Union class. Same
> for the HIR Union class, plus a get_generics_params method. Add a new
> ADTKind enum and adt_kind field to the ADTType to select the
> underlying abstract data type (struct struct, tuple struct or union,
> with enum as possible future variant).
>
> An union constructor can have only one field. Add an union_index field
> to StructExprStruct which is set during type checking in the
> TypeCheckStructExpr HIR StructExprStructFields visitor.
>
> For the Gcc_backend class rename fill_in_struct to fill_in_fields and
> use it from a new union_type method. Handle union_index in
> constructor_expression (so only one field is initialized).
Fixes #157
Co-authored-by: Mark Wielaard <mark@klomp.org>
Diffstat (limited to 'gcc')
21 files changed, 529 insertions, 55 deletions
diff --git a/gcc/rust/ast/rust-item.h b/gcc/rust/ast/rust-item.h index 6d29c5b..5605b0b 100644 --- a/gcc/rust/ast/rust-item.h +++ b/gcc/rust/ast/rust-item.h @@ -2489,6 +2489,15 @@ public: std::vector<StructField> &get_variants () { return variants; } const std::vector<StructField> &get_variants () const { return variants; } + void iterate (std::function<bool (StructField &)> cb) + { + for (auto &variant : variants) + { + if (!cb (variant)) + return; + } + } + std::vector<std::unique_ptr<GenericParam> > &get_generic_params () { return generic_params; @@ -2505,6 +2514,8 @@ public: return where_clause; } + Identifier get_identifier () const { return union_name; } + protected: /* Use covariance to implement clone function as returning this object * rather than base */ diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h index 0aaf084..8007c2f 100644 --- a/gcc/rust/backend/rust-compile-context.h +++ b/gcc/rust/backend/rust-compile-context.h @@ -417,9 +417,13 @@ public: fields.push_back (std::move (f)); } - Btype *struct_type_record = ctx->get_backend ()->struct_type (fields); + Btype *type_record; + if (type.is_union ()) + type_record = ctx->get_backend ()->union_type (fields); + else + type_record = ctx->get_backend ()->struct_type (fields); Btype *named_struct - = ctx->get_backend ()->named_type (type.get_name (), struct_type_record, + = ctx->get_backend ()->named_type (type.get_name (), type_record, ctx->get_mappings ()->lookup_location ( type.get_ty_ref ())); diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h index 2a147ab..4658295 100644 --- a/gcc/rust/backend/rust-compile-expr.h +++ b/gcc/rust/backend/rust-compile-expr.h @@ -80,7 +80,7 @@ public: } translated - = ctx->get_backend ()->constructor_expression (tuple_type, vals, + = ctx->get_backend ()->constructor_expression (tuple_type, vals, -1, expr.get_locus ()); } @@ -595,6 +595,7 @@ public: translated = ctx->get_backend ()->constructor_expression (type, vals, + struct_expr.union_index, struct_expr.get_locus ()); } diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc index 5ffd11a..aa9aa2d 100644 --- a/gcc/rust/backend/rust-compile.cc +++ b/gcc/rust/backend/rust-compile.cc @@ -79,7 +79,7 @@ CompileExpr::visit (HIR::CallExpr &expr) }); translated - = ctx->get_backend ()->constructor_expression (type, vals, + = ctx->get_backend ()->constructor_expression (type, vals, -1, expr.get_locus ()); } else diff --git a/gcc/rust/hir/rust-ast-lower-item.h b/gcc/rust/hir/rust-ast-lower-item.h index 80ca298..f168c7d 100644 --- a/gcc/rust/hir/rust-ast-lower-item.h +++ b/gcc/rust/hir/rust-ast-lower-item.h @@ -193,6 +193,57 @@ public: struct_decl.get_locus ()); } + void visit (AST::Union &union_decl) override + { + std::vector<std::unique_ptr<HIR::GenericParam> > generic_params; + if (union_decl.has_generics ()) + { + generic_params + = lower_generic_params (union_decl.get_generic_params ()); + } + + std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items; + HIR::WhereClause where_clause (std::move (where_clause_items)); + HIR::Visibility vis = HIR::Visibility::create_public (); + + std::vector<HIR::StructField> variants; + union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool { + HIR::Visibility vis = HIR::Visibility::create_public (); + HIR::Type *type + = ASTLoweringType::translate (variant.get_field_type ().get ()); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, variant.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id ( + crate_num)); + + HIR::StructField translated_variant (mapping, variant.get_field_name (), + std::unique_ptr<HIR::Type> (type), + vis, variant.get_locus (), + variant.get_outer_attrs ()); + variants.push_back (std::move (translated_variant)); + return true; + }); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id (crate_num)); + + translated + = new HIR::Union (mapping, union_decl.get_identifier (), vis, + std::move (generic_params), std::move (where_clause), + std::move (variants), union_decl.get_outer_attrs (), + union_decl.get_locus ()); + + mappings->insert_defid_mapping (mapping.get_defid (), translated); + mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (), + translated); + mappings->insert_location (crate_num, mapping.get_hirid (), + union_decl.get_locus ()); + } + void visit (AST::StaticItem &var) override { HIR::Visibility vis = HIR::Visibility::create_public (); diff --git a/gcc/rust/hir/rust-ast-lower-stmt.h b/gcc/rust/hir/rust-ast-lower-stmt.h index 9df6b74..2e97ca6 100644 --- a/gcc/rust/hir/rust-ast-lower-stmt.h +++ b/gcc/rust/hir/rust-ast-lower-stmt.h @@ -215,6 +215,59 @@ public: struct_decl.get_locus ()); } + void visit (AST::Union &union_decl) override + { + std::vector<std::unique_ptr<HIR::GenericParam> > generic_params; + if (union_decl.has_generics ()) + { + generic_params + = lower_generic_params (union_decl.get_generic_params ()); + } + + std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items; + HIR::WhereClause where_clause (std::move (where_clause_items)); + HIR::Visibility vis = HIR::Visibility::create_public (); + + std::vector<HIR::StructField> variants; + union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool { + HIR::Visibility vis = HIR::Visibility::create_public (); + HIR::Type *type + = ASTLoweringType::translate (variant.get_field_type ().get ()); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, variant.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id ( + crate_num)); + + // FIXME + // AST::StructField is missing Location info + Location variant_locus; + HIR::StructField translated_variant (mapping, variant.get_field_name (), + std::unique_ptr<HIR::Type> (type), + vis, variant_locus, + variant.get_outer_attrs ()); + variants.push_back (std::move (translated_variant)); + return true; + }); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id (crate_num)); + + translated + = new HIR::Union (mapping, union_decl.get_identifier (), vis, + std::move (generic_params), std::move (where_clause), + std::move (variants), union_decl.get_outer_attrs (), + union_decl.get_locus ()); + + mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (), + translated); + mappings->insert_location (crate_num, mapping.get_hirid (), + union_decl.get_locus ()); + } + void visit (AST::EmptyStmt &empty) override { auto crate_num = mappings->get_current_crate (); diff --git a/gcc/rust/hir/tree/rust-hir-expr.h b/gcc/rust/hir/tree/rust-hir-expr.h index 65c40d6..8d815c5 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.h +++ b/gcc/rust/hir/tree/rust-hir-expr.h @@ -1449,6 +1449,10 @@ public: // FIXME make unique_ptr StructBase *struct_base; + // For unions there is just one field, the index + // is set when type checking + int union_index = -1; + std::string as_string () const override; bool has_struct_base () const { return struct_base != nullptr; } @@ -1467,7 +1471,8 @@ public: // copy constructor with vector clone StructExprStructFields (StructExprStructFields const &other) - : StructExprStruct (other), struct_base (other.struct_base) + : StructExprStruct (other), struct_base (other.struct_base), + union_index (other.union_index) { fields.reserve (other.fields.size ()); for (const auto &e : other.fields) @@ -1479,6 +1484,7 @@ public: { StructExprStruct::operator= (other); struct_base = other.struct_base; + union_index = other.union_index; fields.reserve (other.fields.size ()); for (const auto &e : other.fields) diff --git a/gcc/rust/hir/tree/rust-hir-item.h b/gcc/rust/hir/tree/rust-hir-item.h index 7d976c5..182fe87 100644 --- a/gcc/rust/hir/tree/rust-hir-item.h +++ b/gcc/rust/hir/tree/rust-hir-item.h @@ -1989,10 +1989,26 @@ public: Union (Union &&other) = default; Union &operator= (Union &&other) = default; + std::vector<std::unique_ptr<GenericParam> > &get_generic_params () + { + return generic_params; + } + + Identifier get_identifier () const { return union_name; } + Location get_locus () const { return locus; } void accept_vis (HIRVisitor &vis) override; + void iterate (std::function<bool (StructField &)> cb) + { + for (auto &variant : variants) + { + if (!cb (variant)) + return; + } + } + protected: /* Use covariance to implement clone function as returning this object * rather than base */ diff --git a/gcc/rust/resolve/rust-ast-resolve-item.h b/gcc/rust/resolve/rust-ast-resolve-item.h index 539229d..c67121d 100644 --- a/gcc/rust/resolve/rust-ast-resolve-item.h +++ b/gcc/rust/resolve/rust-ast-resolve-item.h @@ -260,6 +260,28 @@ public: resolver->get_type_scope ().pop (); } + void visit (AST::Union &union_decl) override + { + NodeId scope_node_id = union_decl.get_node_id (); + resolver->get_type_scope ().push (scope_node_id); + + if (union_decl.has_generics ()) + { + for (auto &generic : union_decl.get_generic_params ()) + { + ResolveGenericParam::go (generic.get (), union_decl.get_node_id ()); + } + } + + union_decl.iterate ([&] (AST::StructField &field) mutable -> bool { + ResolveType::go (field.get_field_type ().get (), + union_decl.get_node_id ()); + return true; + }); + + resolver->get_type_scope ().pop (); + } + void visit (AST::StaticItem &var) override { ResolveType::go (var.get_type ().get (), var.get_node_id ()); diff --git a/gcc/rust/resolve/rust-ast-resolve-stmt.h b/gcc/rust/resolve/rust-ast-resolve-stmt.h index 210a9fc..b604432 100644 --- a/gcc/rust/resolve/rust-ast-resolve-stmt.h +++ b/gcc/rust/resolve/rust-ast-resolve-stmt.h @@ -131,6 +131,38 @@ public: resolver->get_type_scope ().pop (); } + void visit (AST::Union &union_decl) override + { + auto path = CanonicalPath::new_seg (union_decl.get_node_id (), + union_decl.get_identifier ()); + resolver->get_type_scope ().insert ( + path, union_decl.get_node_id (), union_decl.get_locus (), false, + [&] (const CanonicalPath &, NodeId, Location locus) -> void { + RichLocation r (union_decl.get_locus ()); + r.add_range (locus); + rust_error_at (r, "redefined multiple times"); + }); + + NodeId scope_node_id = union_decl.get_node_id (); + resolver->get_type_scope ().push (scope_node_id); + + if (union_decl.has_generics ()) + { + for (auto &generic : union_decl.get_generic_params ()) + { + ResolveGenericParam::go (generic.get (), union_decl.get_node_id ()); + } + } + + union_decl.iterate ([&] (AST::StructField &field) mutable -> bool { + ResolveType::go (field.get_field_type ().get (), + union_decl.get_node_id ()); + return true; + }); + + resolver->get_type_scope ().pop (); + } + void visit (AST::Function &function) override { auto path = ResolveFunctionItemToCanonicalPath::resolve (function); diff --git a/gcc/rust/resolve/rust-ast-resolve-toplevel.h b/gcc/rust/resolve/rust-ast-resolve-toplevel.h index a042f5c..57a0534 100644 --- a/gcc/rust/resolve/rust-ast-resolve-toplevel.h +++ b/gcc/rust/resolve/rust-ast-resolve-toplevel.h @@ -81,6 +81,20 @@ public: }); } + void visit (AST::Union &union_decl) override + { + auto path + = prefix.append (CanonicalPath::new_seg (union_decl.get_node_id (), + union_decl.get_identifier ())); + resolver->get_type_scope ().insert ( + path, union_decl.get_node_id (), union_decl.get_locus (), false, + [&] (const CanonicalPath &, NodeId, Location locus) -> void { + RichLocation r (union_decl.get_locus ()); + r.add_range (locus); + rust_error_at (r, "redefined multiple times"); + }); + } + void visit (AST::StaticItem &var) override { auto path = prefix.append ( diff --git a/gcc/rust/rust-backend.h b/gcc/rust/rust-backend.h index be23fd3..4635796 100644 --- a/gcc/rust/rust-backend.h +++ b/gcc/rust/rust-backend.h @@ -178,6 +178,9 @@ public: // Get a struct type. virtual Btype *struct_type (const std::vector<Btyped_identifier> &fields) = 0; + // Get a union type. + virtual Btype *union_type (const std::vector<Btyped_identifier> &fields) = 0; + // Get an array type. virtual Btype *array_type (Btype *element_type, Bexpression *length) = 0; @@ -424,7 +427,7 @@ public: // corresponding fields in BTYPE. virtual Bexpression * constructor_expression (Btype *btype, const std::vector<Bexpression *> &vals, - Location) + int, Location) = 0; // Return an expression that constructs an array of BTYPE with INDEXES and diff --git a/gcc/rust/rust-gcc.cc b/gcc/rust/rust-gcc.cc index 44617a6..3e47a7c 100644 --- a/gcc/rust/rust-gcc.cc +++ b/gcc/rust/rust-gcc.cc @@ -265,6 +265,8 @@ public: Btype *struct_type (const std::vector<Btyped_identifier> &); + Btype *union_type (const std::vector<Btyped_identifier> &); + Btype *array_type (Btype *, Bexpression *); Btype *placeholder_pointer_type (const std::string &, Location, bool); @@ -377,7 +379,7 @@ public: Location); Bexpression *constructor_expression (Btype *, - const std::vector<Bexpression *> &, + const std::vector<Bexpression *> &, int, Location); Bexpression *array_constructor_expression (Btype *, @@ -531,7 +533,7 @@ private: Bfunction *make_function (tree t) { return new Bfunction (t); } - Btype *fill_in_struct (Btype *, const std::vector<Btyped_identifier> &); + Btype *fill_in_fields (Btype *, const std::vector<Btyped_identifier> &); Btype *fill_in_array (Btype *, Btype *, Bexpression *); @@ -1145,14 +1147,23 @@ Gcc_backend::function_ptr_type (Btype *result_type, Btype * Gcc_backend::struct_type (const std::vector<Btyped_identifier> &fields) { - return this->fill_in_struct (this->make_type (make_node (RECORD_TYPE)), + return this->fill_in_fields (this->make_type (make_node (RECORD_TYPE)), + fields); +} + +// Make a union type. + +Btype * +Gcc_backend::union_type (const std::vector<Btyped_identifier> &fields) +{ + return this->fill_in_fields (this->make_type (make_node (UNION_TYPE)), fields); } -// Fill in the fields of a struct type. +// Fill in the fields of a struct or union type. Btype * -Gcc_backend::fill_in_struct (Btype *fill, +Gcc_backend::fill_in_fields (Btype *fill, const std::vector<Btyped_identifier> &fields) { tree fill_tree = fill->get_tree (); @@ -1311,7 +1322,7 @@ Gcc_backend::set_placeholder_struct_type ( { tree t = placeholder->get_tree (); gcc_assert (TREE_CODE (t) == RECORD_TYPE && TYPE_FIELDS (t) == NULL_TREE); - Btype *r = this->fill_in_struct (placeholder, fields); + Btype *r = this->fill_in_fields (placeholder, fields); if (TYPE_NAME (t) != NULL_TREE) { @@ -1321,7 +1332,7 @@ Gcc_backend::set_placeholder_struct_type ( DECL_ORIGINAL_TYPE (TYPE_NAME (t)) = copy; TYPE_SIZE (copy) = NULL_TREE; Btype *bc = this->make_type (copy); - this->fill_in_struct (bc, fields); + this->fill_in_fields (bc, fields); delete bc; } @@ -1758,7 +1769,8 @@ Gcc_backend::struct_field_expression (Bexpression *bstruct, size_t index, if (struct_tree == error_mark_node || TREE_TYPE (struct_tree) == error_mark_node) return this->error_expression (); - gcc_assert (TREE_CODE (TREE_TYPE (struct_tree)) == RECORD_TYPE); + gcc_assert (TREE_CODE (TREE_TYPE (struct_tree)) == RECORD_TYPE + || TREE_CODE (TREE_TYPE (struct_tree)) == UNION_TYPE); tree field = TYPE_FIELDS (TREE_TYPE (struct_tree)); if (field == NULL_TREE) { @@ -2041,7 +2053,7 @@ Gcc_backend::lazy_boolean_expression (LazyBooleanOperator op, Bexpression *left, Bexpression * Gcc_backend::constructor_expression (Btype *btype, const std::vector<Bexpression *> &vals, - Location location) + int union_index, Location location) { tree type_tree = btype->get_tree (); if (type_tree == error_mark_node) @@ -2053,11 +2065,15 @@ Gcc_backend::constructor_expression (Btype *btype, tree sink = NULL_TREE; bool is_constant = true; tree field = TYPE_FIELDS (type_tree); - for (std::vector<Bexpression *>::const_iterator p = vals.begin (); - p != vals.end (); ++p, field = DECL_CHAIN (field)) + if (union_index != -1) { - gcc_assert (field != NULL_TREE); - tree val = (*p)->get_tree (); + gcc_assert (TREE_CODE (type_tree) == UNION_TYPE); + tree val = vals.front ()->get_tree (); + for (int i = 0; i < union_index; i++) + { + gcc_assert (field != NULL_TREE); + field = DECL_CHAIN (field); + } if (TREE_TYPE (field) == error_mark_node || val == error_mark_node || TREE_TYPE (val) == error_mark_node) return this->error_expression (); @@ -2070,17 +2086,49 @@ Gcc_backend::constructor_expression (Btype *btype, // would have been added as a map element for its // side-effects and construct an empty map. append_to_statement_list (val, &sink); - continue; } + else + { + constructor_elt empty = {NULL, NULL}; + constructor_elt *elt = init->quick_push (empty); + elt->index = field; + elt->value = this->convert_tree (TREE_TYPE (field), val, location); + if (!TREE_CONSTANT (elt->value)) + is_constant = false; + } + } + else + { + gcc_assert (TREE_CODE (type_tree) == RECORD_TYPE); + for (std::vector<Bexpression *>::const_iterator p = vals.begin (); + p != vals.end (); ++p, field = DECL_CHAIN (field)) + { + gcc_assert (field != NULL_TREE); + tree val = (*p)->get_tree (); + if (TREE_TYPE (field) == error_mark_node || val == error_mark_node + || TREE_TYPE (val) == error_mark_node) + return this->error_expression (); - constructor_elt empty = {NULL, NULL}; - constructor_elt *elt = init->quick_push (empty); - elt->index = field; - elt->value = this->convert_tree (TREE_TYPE (field), val, location); - if (!TREE_CONSTANT (elt->value)) - is_constant = false; + if (int_size_in_bytes (TREE_TYPE (field)) == 0) + { + // GIMPLE cannot represent indices of zero-sized types so + // trying to construct a map with zero-sized keys might lead + // to errors. Instead, we evaluate each expression that + // would have been added as a map element for its + // side-effects and construct an empty map. + append_to_statement_list (val, &sink); + continue; + } + + constructor_elt empty = {NULL, NULL}; + constructor_elt *elt = init->quick_push (empty); + elt->index = field; + elt->value = this->convert_tree (TREE_TYPE (field), val, location); + if (!TREE_CONSTANT (elt->value)) + is_constant = false; + } + gcc_assert (field == NULL_TREE); } - gcc_assert (field == NULL_TREE); tree ret = build_constructor (type_tree, init); if (is_constant) TREE_CONSTANT (ret) = 1; @@ -2781,6 +2829,7 @@ Gcc_backend::convert_tree (tree type_tree, tree expr_tree, Location location) || SCALAR_FLOAT_TYPE_P (type_tree) || COMPLEX_FLOAT_TYPE_P (type_tree)) return fold_convert_loc (location.gcc_location (), type_tree, expr_tree); else if (TREE_CODE (type_tree) == RECORD_TYPE + || TREE_CODE (type_tree) == UNION_TYPE || TREE_CODE (type_tree) == ARRAY_TYPE) { gcc_assert (int_size_in_bytes (type_tree) diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h index 1b6f47c..77cbc06 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h +++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h @@ -159,7 +159,8 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), true, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::TUPLE_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); @@ -209,13 +210,66 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), false, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::STRUCT_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); infered = type; } + void visit (HIR::Union &union_decl) override + { + std::vector<TyTy::SubstitutionParamMapping> substitutions; + if (union_decl.has_generics ()) + { + for (auto &generic_param : union_decl.get_generic_params ()) + { + switch (generic_param.get ()->get_kind ()) + { + case HIR::GenericParam::GenericKind::LIFETIME: + // Skipping Lifetime completely until better handling. + break; + + case HIR::GenericParam::GenericKind::TYPE: { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), + param_type); + + substitutions.push_back (TyTy::SubstitutionParamMapping ( + static_cast<HIR::TypeParam &> (*generic_param), + param_type)); + } + break; + } + } + } + + std::vector<TyTy::StructFieldType *> variants; + union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + variants.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + return true; + }); + + TyTy::BaseType *type + = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), + mappings->get_next_hir_id (), + union_decl.get_identifier (), + TyTy::ADTType::ADTKind::UNION, std::move (variants), + std::move (substitutions)); + + context->insert_type (union_decl.get_mappings (), type); + infered = type; + } + void visit (HIR::Function &function) override { std::vector<TyTy::SubstitutionParamMapping> substitutions; diff --git a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h index 18f3e72..5b9757f 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h +++ b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h @@ -94,7 +94,8 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), true, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::TUPLE_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); @@ -143,12 +144,64 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), false, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::STRUCT_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); } + void visit (HIR::Union &union_decl) override + { + std::vector<TyTy::SubstitutionParamMapping> substitutions; + if (union_decl.has_generics ()) + { + for (auto &generic_param : union_decl.get_generic_params ()) + { + switch (generic_param.get ()->get_kind ()) + { + case HIR::GenericParam::GenericKind::LIFETIME: + // Skipping Lifetime completely until better handling. + break; + + case HIR::GenericParam::GenericKind::TYPE: { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), + param_type); + + substitutions.push_back (TyTy::SubstitutionParamMapping ( + static_cast<HIR::TypeParam &> (*generic_param), + param_type)); + } + break; + } + } + } + + std::vector<TyTy::StructFieldType *> variants; + union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + variants.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + return true; + }); + + TyTy::BaseType *type + = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), + mappings->get_next_hir_id (), + union_decl.get_identifier (), + TyTy::ADTType::ADTKind::UNION, std::move (variants), + std::move (substitutions)); + + context->insert_type (union_decl.get_mappings (), type); + } + void visit (HIR::StaticItem &var) override { TyTy::BaseType *type = TypeCheckType::Resolve (var.get_type ()); diff --git a/gcc/rust/typecheck/rust-hir-type-check.cc b/gcc/rust/typecheck/rust-hir-type-check.cc index cb2896c..66adfcb 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.cc +++ b/gcc/rust/typecheck/rust-hir-type-check.cc @@ -180,7 +180,17 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) // check the arguments are all assigned and fix up the ordering if (fields_assigned.size () != struct_path_resolved->num_fields ()) { - if (!struct_expr.has_struct_base ()) + if (struct_def->is_union ()) + { + if (fields_assigned.size () != 1 || struct_expr.has_struct_base ()) + { + rust_error_at ( + struct_expr.get_locus (), + "union must have exactly one field variant assigned"); + return; + } + } + else if (!struct_expr.has_struct_base ()) { rust_error_at (struct_expr.get_locus (), "constructor is missing fields"); @@ -236,23 +246,40 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) } } - // everything is ok, now we need to ensure all field values are ordered - // correctly. The GIMPLE backend uses a simple algorithm that assumes each - // assigned field in the constructor is in the same order as the field in - // the type - - std::vector<std::unique_ptr<HIR::StructExprField> > expr_fields - = struct_expr.get_fields_as_owner (); - for (auto &f : expr_fields) - f.release (); - - std::vector<std::unique_ptr<HIR::StructExprField> > ordered_fields; - for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + if (struct_def->is_union ()) + { + // There is exactly one field in this constructor, we need to + // figure out the field index to make sure we initialize the + // right union field. + for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + { + if (adtFieldIndexToField[i]) + { + struct_expr.union_index = i; + break; + } + } + rust_assert (struct_expr.union_index != -1); + } + else { - ordered_fields.push_back ( - std::unique_ptr<HIR::StructExprField> (adtFieldIndexToField[i])); + // everything is ok, now we need to ensure all field values are ordered + // correctly. The GIMPLE backend uses a simple algorithm that assumes each + // assigned field in the constructor is in the same order as the field in + // the type + std::vector<std::unique_ptr<HIR::StructExprField> > expr_fields + = struct_expr.get_fields_as_owner (); + for (auto &f : expr_fields) + f.release (); + + std::vector<std::unique_ptr<HIR::StructExprField> > ordered_fields; + for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + { + ordered_fields.push_back ( + std::unique_ptr<HIR::StructExprField> (adtFieldIndexToField[i])); + } + struct_expr.set_fields_as_owner (std::move (ordered_fields)); } - struct_expr.set_fields_as_owner (std::move (ordered_fields)); resolved = struct_def; } diff --git a/gcc/rust/typecheck/rust-tycheck-dump.h b/gcc/rust/typecheck/rust-tycheck-dump.h index b80372b..cc2e3c0 100644 --- a/gcc/rust/typecheck/rust-tycheck-dump.h +++ b/gcc/rust/typecheck/rust-tycheck-dump.h @@ -48,6 +48,12 @@ public: + "\n"; } + void visit (HIR::Union &union_decl) override + { + dump + += indent () + "union " + type_string (union_decl.get_mappings ()) + "\n"; + } + void visit (HIR::ImplBlock &impl_block) override { dump += indent () + "impl " diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index 1ca28fa..6bac764 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -537,7 +537,7 @@ ADTType::clone () for (auto &f : fields) cloned_fields.push_back ((StructFieldType *) f->clone ()); - return new ADTType (get_ref (), get_ty_ref (), identifier, get_is_tuple (), + return new ADTType (get_ref (), get_ty_ref (), identifier, get_adt_kind (), cloned_fields, clone_substs (), used_arguments, get_combined_refs ()); } @@ -1999,7 +1999,7 @@ PlaceholderType::clone () void TypeCheckCallExpr::visit (ADTType &type) { - if (!type.get_is_tuple ()) + if (!type.is_tuple_struct ()) { rust_error_at ( call.get_locus (), diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index 336d42b..46110e4 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -855,7 +855,15 @@ protected: class ADTType : public BaseType, public SubstitutionRef { public: - ADTType (HirId ref, std::string identifier, bool is_tuple, + enum ADTKind + { + STRUCT_STRUCT, + TUPLE_STRUCT, + UNION, + // ENUM ? + }; + + ADTType (HirId ref, std::string identifier, ADTKind adt_kind, std::vector<StructFieldType *> fields, std::vector<SubstitutionParamMapping> subst_refs, SubstitutionArgumentMappings generic_arguments @@ -863,10 +871,10 @@ public: std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), is_tuple (is_tuple) + identifier (identifier), fields (fields), adt_kind (adt_kind) {} - ADTType (HirId ref, HirId ty_ref, std::string identifier, bool is_tuple, + ADTType (HirId ref, HirId ty_ref, std::string identifier, ADTKind adt_kind, std::vector<StructFieldType *> fields, std::vector<SubstitutionParamMapping> subst_refs, SubstitutionArgumentMappings generic_arguments @@ -874,10 +882,12 @@ public: std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ty_ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), is_tuple (is_tuple) + identifier (identifier), fields (fields), adt_kind (adt_kind) {} - bool get_is_tuple () { return is_tuple; } + ADTKind get_adt_kind () { return adt_kind; } + bool is_tuple_struct () { return adt_kind == TUPLE_STRUCT; } + bool is_union () { return adt_kind == UNION; } bool is_unit () const override { return this->fields.empty (); } @@ -964,7 +974,7 @@ public: private: std::string identifier; std::vector<StructFieldType *> fields; - bool is_tuple; + ADTType::ADTKind adt_kind; }; class FnType : public BaseType, public SubstitutionRef diff --git a/gcc/testsuite/rust/compile/torture/union.rs b/gcc/testsuite/rust/compile/torture/union.rs new file mode 100644 index 0000000..393e591 --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/union.rs @@ -0,0 +1,35 @@ +// { dg-do compile } +// { dg-options "-w" } + +union U +{ + f1: u8 +} + +union V +{ + f1: u8, + f2: u16, + f3: i32, +} + +struct S +{ + f1: U, + f2: V +} + +fn main () +{ + let u = U { f1: 16 }; + let v = V { f2: 365 }; + let s = S { f1: u, f2: v }; + let _v125 = unsafe + { let mut uv: u64; + uv = s.f1.f1 as u64; + uv += s.f2.f1 as u64; + uv += s.f2.f2 as u64; + uv -= s.f2.f3 as u64; + uv + }; +} diff --git a/gcc/testsuite/rust/compile/torture/union_union.rs b/gcc/testsuite/rust/compile/torture/union_union.rs new file mode 100644 index 0000000..9feb145 --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/union_union.rs @@ -0,0 +1,27 @@ +union union +{ + union: u32, + inion: i32, + u8ion: u8, + i64on: i64, + u64on: u64 +} + +pub fn main () +{ + let union = union { union: 2 }; + let inion = union { inion: -2 }; + let mut mnion = union { inion: -16 }; + let m1 = unsafe { mnion.union }; + unsafe { mnion.union = union.union }; + let m2 = unsafe { mnion.inion }; + let u1 = unsafe { union.union }; + let i1 = unsafe { union.inion }; + let u2 = unsafe { inion.union }; + let i2 = unsafe { inion.inion }; + let _r1 = u2 - u1 - m1; + let _r2 = i1 + i2 + m2; + let _u8 = unsafe { union.u8ion }; + let _i64 = unsafe { union.i64on }; + let _u64 = unsafe { union.u64on }; +} |