diff options
author | Philip Herron <philip.herron@embecosm.com> | 2021-03-10 18:10:06 +0000 |
---|---|---|
committer | Philip Herron <herron.philip@googlemail.com> | 2021-03-22 20:01:30 +0000 |
commit | 280ac5bd99b4d66ea10d0aa8d4edba53bf460b10 (patch) | |
tree | e67968afbe4668e6a2679ed961d137cdb409166d /gcc | |
parent | eb33139efa7bbbd09ad26403c36a5dcf31e1b14e (diff) | |
download | gcc-280ac5bd99b4d66ea10d0aa8d4edba53bf460b10.zip gcc-280ac5bd99b4d66ea10d0aa8d4edba53bf460b10.tar.gz gcc-280ac5bd99b4d66ea10d0aa8d4edba53bf460b10.tar.bz2 |
Generics continued this adds more type resolution to ADT and Functions
Adds recursive generic argument handling for structs and functions. With a
new substitution mapper class to coerce the HIR::GenericArgs appropriately.
This is the building block to work on impl blocks with generics and
better Monomorphization support to handle duplicate functions etc.
Fixes: #236 #234 #235
Addresses: #237
Diffstat (limited to 'gcc')
23 files changed, 1082 insertions, 302 deletions
diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h index 3f4a9ac..6f45e57 100644 --- a/gcc/rust/backend/rust-compile-context.h +++ b/gcc/rust/backend/rust-compile-context.h @@ -61,10 +61,26 @@ public: } } - ~Context () {} - - bool lookup_compiled_types (HirId id, ::Btype **type) + bool lookup_compiled_types (HirId id, ::Btype **type, + const TyTy::BaseType *ref = nullptr) { + if (ref != nullptr && ref->has_subsititions_defined ()) + { + for (auto it = mono.begin (); it != mono.end (); it++) + { + std::pair<HirId, ::Btype *> &val = it->second; + const TyTy::BaseType *r = it->first; + + if (ref->is_equal (*r)) + { + *type = val.second; + + return true; + } + } + return false; + } + auto it = compiled_type_map.find (id); if (it == compiled_type_map.end ()) return false; @@ -73,9 +89,15 @@ public: return true; } - void insert_compiled_type (HirId id, ::Btype *type) + void insert_compiled_type (HirId id, ::Btype *type, + const TyTy::BaseType *ref = nullptr) { compiled_type_map[id] = type; + if (ref != nullptr) + { + std::pair<HirId, ::Btype *> elem (id, type); + mono[ref] = std::move (elem); + } } ::Backend *get_backend () { return backend; } @@ -250,6 +272,7 @@ private: std::vector< ::Bblock *> scope_stack; std::vector< ::Bvariable *> loop_value_stack; std::vector< ::Blabel *> loop_begin_labels; + std::map<const TyTy::BaseType *, std::pair<HirId, ::Btype *> > mono; // To GCC middle-end std::vector< ::Btype *> type_decls; @@ -274,12 +297,8 @@ public: void visit (TyTy::ParamType ¶m) override { - rust_assert (param.get_ref () != param.get_ty_ref ()); - - TyTy::BaseType *lookup = nullptr; - bool ok = ctx->get_tyctx ()->lookup_type (param.get_ty_ref (), &lookup); - rust_assert (ok); - lookup->accept_vis (*this); + TyTy::TyVar var (param.get_ty_ref ()); + var.get_tyty ()->accept_vis (*this); } void visit (TyTy::FnType &type) override @@ -339,8 +358,7 @@ public: void visit (TyTy::ADTType &type) override { - bool ok = ctx->lookup_compiled_types (type.get_ty_ref (), &translated); - if (ok) + if (ctx->lookup_compiled_types (type.get_ty_ref (), &translated, &type)) return; // create implicit struct @@ -361,11 +379,12 @@ public: Btype *named_struct = ctx->get_backend ()->named_type (type.get_name (), struct_type_record, ctx->get_mappings ()->lookup_location ( - type.get_ref ())); + type.get_ty_ref ())); ctx->push_type (named_struct); - ctx->insert_compiled_type (type.get_ty_ref (), named_struct); translated = named_struct; + + ctx->insert_compiled_type (type.get_ty_ref (), named_struct, &type); } void visit (TyTy::TupleType &type) override @@ -485,7 +504,7 @@ public: } private: - TyTyResolveCompile (Context *ctx) : ctx (ctx) {} + TyTyResolveCompile (Context *ctx) : ctx (ctx), translated (nullptr) {} Context *ctx; ::Btype *translated; diff --git a/gcc/rust/backend/rust-compile-implitem.h b/gcc/rust/backend/rust-compile-implitem.h index 202b868..0817424 100644 --- a/gcc/rust/backend/rust-compile-implitem.h +++ b/gcc/rust/backend/rust-compile-implitem.h @@ -91,7 +91,7 @@ public: return; } - TyTy::FnType *fntype = (TyTy::FnType *) fntype_tyty; + TyTy::FnType *fntype = static_cast<TyTy::FnType *> (fntype_tyty); // convert to the actual function type ::Btype *compiled_fn_type = TyTyResolveCompile::compile (ctx, fntype); @@ -108,7 +108,7 @@ public: Bfunction *fndecl = ctx->get_backend ()->function (compiled_fn_type, fn_identifier, asm_name, flags, function.get_locus ()); - ctx->insert_function_decl (function.get_mappings ().get_hirid (), fndecl); + ctx->insert_function_decl (fntype->get_ty_ref (), fndecl); // setup the params @@ -256,7 +256,7 @@ public: return; } - TyTy::FnType *fntype = (TyTy::FnType *) fntype_tyty; + TyTy::FnType *fntype = static_cast<TyTy::FnType *> (fntype_tyty); // convert to the actual function type ::Btype *compiled_fn_type = TyTyResolveCompile::compile (ctx, fntype); @@ -273,7 +273,7 @@ public: Bfunction *fndecl = ctx->get_backend ()->function (compiled_fn_type, fn_identifier, asm_name, flags, method.get_locus ()); - ctx->insert_function_decl (method.get_mappings ().get_hirid (), fndecl); + ctx->insert_function_decl (fntype->get_ty_ref (), fndecl); // setup the params TyTy::BaseType *tyret = fntype->get_return_type (); diff --git a/gcc/rust/backend/rust-compile-item.h b/gcc/rust/backend/rust-compile-item.h index 1d4fcda..2bbfe4c 100644 --- a/gcc/rust/backend/rust-compile-item.h +++ b/gcc/rust/backend/rust-compile-item.h @@ -35,9 +35,10 @@ class CompileItem : public HIRCompileBase using Rust::Compile::HIRCompileBase::visit; public: - static void compile (HIR::Item *item, Context *ctx, bool compile_fns = true) + static void compile (HIR::Item *item, Context *ctx, bool compile_fns = true, + TyTy::BaseType *concrete = nullptr) { - CompileItem compiler (ctx, compile_fns); + CompileItem compiler (ctx, compile_fns, concrete); item->accept_vis (compiler); } @@ -118,8 +119,22 @@ public: return; } - TyTy::FnType *fntype = (TyTy::FnType *) fntype_tyty; - // convert to the actual function type + TyTy::FnType *fntype = static_cast<TyTy::FnType *> (fntype_tyty); + if (fntype->has_subsititions_defined ()) + { + // we cant do anything for this only when it is used + if (concrete == nullptr) + return; + else + { + rust_assert (concrete->get_kind () == TyTy::TypeKind::FNDEF); + fntype = static_cast<TyTy::FnType *> (concrete); + + // override the Hir Lookups for the substituions in this context + fntype->override_context (); + } + } + ::Btype *compiled_fn_type = TyTyResolveCompile::compile (ctx, fntype); unsigned int flags = 0; @@ -130,21 +145,30 @@ public: if (is_main_fn || function.has_visibility ()) flags |= Backend::function_is_visible; + std::string ir_symbol_name = function.get_function_name (); std::string asm_name = function.get_function_name (); if (!is_main_fn) { // FIXME need name mangling - asm_name = "__" + function.get_function_name (); + if (concrete == nullptr) + asm_name = "__" + function.get_function_name (); + else + { + ir_symbol_name + = function.get_function_name () + fntype->subst_as_string (); + + asm_name = "__" + function.get_function_name (); + for (auto &sub : fntype->get_substs ()) + asm_name += "G" + sub.as_string (); + } } Bfunction *fndecl - = ctx->get_backend ()->function (compiled_fn_type, - function.get_function_name (), asm_name, - flags, function.get_locus ()); - ctx->insert_function_decl (function.get_mappings ().get_hirid (), fndecl); + = ctx->get_backend ()->function (compiled_fn_type, ir_symbol_name, + asm_name, flags, function.get_locus ()); + ctx->insert_function_decl (fntype->get_ty_ref (), fndecl); // setup the params - TyTy::BaseType *tyret = fntype->get_return_type (); std::vector<Bvariable *> param_vars; @@ -274,11 +298,12 @@ public: } private: - CompileItem (Context *ctx, bool compile_fns) - : HIRCompileBase (ctx), compile_fns (compile_fns) + CompileItem (Context *ctx, bool compile_fns, TyTy::BaseType *concrete) + : HIRCompileBase (ctx), compile_fns (compile_fns), concrete (concrete) {} bool compile_fns; + TyTy::BaseType *concrete; }; } // namespace Compile diff --git a/gcc/rust/backend/rust-compile-resolve-path.cc b/gcc/rust/backend/rust-compile-resolve-path.cc index 1a798ee..4fbaae3 100644 --- a/gcc/rust/backend/rust-compile-resolve-path.cc +++ b/gcc/rust/backend/rust-compile-resolve-path.cc @@ -67,24 +67,35 @@ ResolvePathRef::visit (HIR::PathInExpression &expr) return; } - // must be a function call + // must be a function call but it might be a generic function which needs to + // be compiled first + TyTy::BaseType *lookup = nullptr; + bool ok = ctx->get_tyctx ()->lookup_type (expr.get_mappings ().get_hirid (), + &lookup); + rust_assert (ok); + rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); + Bfunction *fn = nullptr; - if (!ctx->lookup_function_decl (ref, &fn)) + if (!ctx->lookup_function_decl (lookup->get_ty_ref (), &fn)) { - // this might fail because its a forward decl so we can attempt to - // resolve it now + // it must resolve to some kind of HIR::Item HIR::Item *resolved_item = ctx->get_mappings ()->lookup_hir_item ( expr.get_mappings ().get_crate_num (), ref); if (resolved_item == nullptr) { - rust_error_at (expr.get_locus (), "failed to lookup forward decl"); + rust_error_at (expr.get_locus (), "failed to lookup definition decl"); return; } - CompileItem::compile (resolved_item, ctx); - if (!ctx->lookup_function_decl (ref, &fn)) + if (!lookup->has_subsititions_defined ()) + CompileItem::compile (resolved_item, ctx); + else + CompileItem::compile (resolved_item, ctx, true, lookup); + + if (!ctx->lookup_function_decl (lookup->get_ty_ref (), &fn)) { - rust_error_at (expr.get_locus (), "forward decl was not compiled 1"); + rust_fatal_error (expr.get_locus (), + "forward decl was not compiled 1"); return; } } diff --git a/gcc/rust/backend/rust-compile-tyty.h b/gcc/rust/backend/rust-compile-tyty.h index 815ebd5..774fd2e 100644 --- a/gcc/rust/backend/rust-compile-tyty.h +++ b/gcc/rust/backend/rust-compile-tyty.h @@ -102,7 +102,7 @@ public: void visit (TyTy::IntType &type) override { - switch (type.get_kind ()) + switch (type.get_int_kind ()) { case TyTy::IntType::I8: translated @@ -139,7 +139,7 @@ public: void visit (TyTy::UintType &type) override { - switch (type.get_kind ()) + switch (type.get_uint_kind ()) { case TyTy::UintType::U8: translated = backend->named_type ("u8", backend->integer_type (true, 8), @@ -175,7 +175,7 @@ public: void visit (TyTy::FloatType &type) override { - switch (type.get_kind ()) + switch (type.get_float_kind ()) { case TyTy::FloatType::F32: translated = backend->named_type ("f32", backend->float_type (32), diff --git a/gcc/rust/rust-backend.h b/gcc/rust/rust-backend.h index e42081b..7c0ac6e 100644 --- a/gcc/rust/rust-backend.h +++ b/gcc/rust/rust-backend.h @@ -785,6 +785,9 @@ public: // is like a C99 function marked inline but not extern. static const unsigned int function_only_inline = 1 << 6; + // const function + static const unsigned int function_read_only = 1 << 7; + // Declare or define a function of FNTYPE. // NAME is the Go name of the function. ASM_NAME, if not the empty // string, is the name that should be used in the symbol table; this diff --git a/gcc/rust/rust-gcc.cc b/gcc/rust/rust-gcc.cc index cf800c7..d1ab3a6 100644 --- a/gcc/rust/rust-gcc.cc +++ b/gcc/rust/rust-gcc.cc @@ -1254,6 +1254,7 @@ Gcc_backend::named_type (const std::string &name, Btype *btype, tree decl = build_decl (location.gcc_location (), TYPE_DECL, get_identifier_from_string (name), type); + TYPE_NAME (type) = decl; return this->make_type (type); } @@ -3207,43 +3208,8 @@ Gcc_backend::function (Btype *fntype, const std::string &name, DECL_EXTERNAL (decl) = 1; DECL_DECLARED_INLINE_P (decl) = 1; } - - // Optimize thunk functions for size. A thunk created for a defer - // statement that may call recover looks like: - // if runtime.setdeferretaddr(L1) { - // goto L1 - // } - // realfn() - // L1: - // The idea is that L1 should be the address to which realfn - // returns. This only works if this little function is not over - // optimized. At some point GCC started duplicating the epilogue in - // the basic-block reordering pass, breaking this assumption. - // Optimizing the function for size avoids duplicating the epilogue. - // This optimization shouldn't matter for any thunk since all thunks - // are small. - size_t pos = name.find ("..thunk"); - if (pos != std::string::npos) - { - for (pos += 7; pos < name.length (); ++pos) - { - if (name[pos] < '0' || name[pos] > '9') - break; - } - if (pos == name.length ()) - { - struct cl_optimization cur_opts; - cl_optimization_save (&cur_opts, &global_options, - &global_options_set); - global_options.x_optimize_size = 1; - global_options.x_optimize_fast = 0; - global_options.x_optimize_debug = 0; - DECL_FUNCTION_SPECIFIC_OPTIMIZATION (decl) - = build_optimization_node (&global_options, &global_options_set); - cl_optimization_restore (&global_options, &global_options_set, - &cur_opts); - } - } + if ((flags & function_read_only) != 0) + TREE_READONLY (decl) = 1; rust_preserve_from_gc (decl); return new Bfunction (decl); diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index d7f2bdb..c2c0160 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -25,6 +25,7 @@ #include "rust-tyty-call.h" #include "rust-hir-type-check-struct-field.h" #include "rust-hir-method-resolve.h" +#include "rust-substitution-mapper.h" namespace Rust { namespace Resolver { @@ -716,7 +717,7 @@ public: return; } - TyTy::ADTType *adt = (TyTy::ADTType *) struct_base; + TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (struct_base); auto resolved = adt->get_field (expr.get_field_name ()); if (resolved == nullptr) { @@ -789,17 +790,18 @@ public: if (infered->has_subsititions_defined ()) { - if (infered->get_kind () != TyTy::TypeKind::ADT) + if (!infered->can_substitute ()) { rust_error_at (expr.get_locus (), - "substitutions only support on ADT types so far"); + "substitutions not supported for %s", + infered->as_string ().c_str ()); return; } - TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (infered); - infered = seg.has_generic_args () - ? adt->handle_substitutions (seg.get_generic_args ()) - : adt->infer_substitutions (); + infered = SubstMapper::Resolve (infered, expr.get_locus (), + seg.has_generic_args () + ? &seg.get_generic_args () + : nullptr); } } diff --git a/gcc/rust/typecheck/rust-hir-type-check-implitem.h b/gcc/rust/typecheck/rust-hir-type-check-implitem.h index 9473eda..b8df74f 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-implitem.h +++ b/gcc/rust/typecheck/rust-hir-type-check-implitem.h @@ -50,6 +50,20 @@ public: void visit (HIR::Function &function) override { + std::vector<TyTy::SubstitutionParamMapping> substitions; + if (function.has_generics ()) + { + for (auto &generic_param : function.get_generic_params ()) + { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), param_type); + + substitions.push_back ( + TyTy::SubstitutionParamMapping (generic_param, param_type)); + } + } + TyTy::BaseType *ret_type = nullptr; if (!function.has_function_return_type ()) ret_type = new TyTy::UnitType (function.get_mappings ().get_hirid ()); @@ -82,12 +96,27 @@ public: } auto fnType = new TyTy::FnType (function.get_mappings ().get_hirid (), - params, ret_type); + std::move (params), ret_type, + std::move (substitions)); context->insert_type (function.get_mappings (), fnType); } void visit (HIR::Method &method) override { + std::vector<TyTy::SubstitutionParamMapping> substitions; + if (method.has_generics ()) + { + for (auto &generic_param : method.get_generic_params ()) + { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), param_type); + + substitions.push_back ( + TyTy::SubstitutionParamMapping (generic_param, param_type)); + } + } + TyTy::BaseType *ret_type = nullptr; if (!method.has_function_return_type ()) ret_type = new TyTy::UnitType (method.get_mappings ().get_hirid ()); @@ -133,8 +162,9 @@ public: context->insert_type (param.get_mappings (), param_tyty); } - auto fnType = new TyTy::FnType (method.get_mappings ().get_hirid (), params, - ret_type); + auto fnType = new TyTy::FnType (method.get_mappings ().get_hirid (), + std::move (params), ret_type, + std::move (substitions)); context->insert_type (method.get_mappings (), fnType); } diff --git a/gcc/rust/typecheck/rust-hir-type-check-item.h b/gcc/rust/typecheck/rust-hir-type-check-item.h index 441a1e3..44fe943 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-item.h +++ b/gcc/rust/typecheck/rust-hir-type-check-item.h @@ -73,8 +73,8 @@ public: } // need to get the return type from this - TyTy::FnType *resolve_fn_type = (TyTy::FnType *) lookup; - auto expected_ret_tyty = resolve_fn_type->get_return_type (); + TyTy::FnType *resolved_fn_type = static_cast<TyTy::FnType *> (lookup); + auto expected_ret_tyty = resolved_fn_type->get_return_type (); context->push_return_type (expected_ret_tyty); auto block_expr_ty diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h index 7f7e625..abe7a55 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h +++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h @@ -105,6 +105,10 @@ public: TyTy::InferType::InferTypeKind::GENERAL)); } } + + TyTy::BaseType *lookup = nullptr; + bool ok = context->lookup_type (stmt.get_mappings ().get_hirid (), &lookup); + rust_assert (ok); } private: diff --git a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h index 63c0f42..4f800b3 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h +++ b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h @@ -42,7 +42,7 @@ public: void visit (HIR::TupleStruct &struct_decl) override { - std::vector<TyTy::SubstitutionMapping> substitutions; + std::vector<TyTy::SubstitutionParamMapping> substitutions; if (struct_decl.has_generics ()) { for (auto &generic_param : struct_decl.get_generic_params ()) @@ -52,7 +52,7 @@ public: context->insert_type (generic_param->get_mappings (), param_type); substitutions.push_back ( - TyTy::SubstitutionMapping (generic_param, param_type)); + TyTy::SubstitutionParamMapping (generic_param, param_type)); } } @@ -82,7 +82,7 @@ public: void visit (HIR::StructStruct &struct_decl) override { - std::vector<TyTy::SubstitutionMapping> substitutions; + std::vector<TyTy::SubstitutionParamMapping> substitutions; if (struct_decl.has_generics ()) { for (auto &generic_param : struct_decl.get_generic_params ()) @@ -92,7 +92,7 @@ public: context->insert_type (generic_param->get_mappings (), param_type); substitutions.push_back ( - TyTy::SubstitutionMapping (generic_param, param_type)); + TyTy::SubstitutionParamMapping (generic_param, param_type)); } } @@ -136,7 +136,7 @@ public: void visit (HIR::Function &function) override { - std::vector<TyTy::SubstitutionMapping> substitutions; + std::vector<TyTy::SubstitutionParamMapping> substitutions; if (function.has_generics ()) { for (auto &generic_param : function.get_generic_params ()) @@ -146,7 +146,7 @@ public: context->insert_type (generic_param->get_mappings (), param_type); substitutions.push_back ( - TyTy::SubstitutionMapping (generic_param, param_type)); + TyTy::SubstitutionParamMapping (generic_param, param_type)); } } @@ -182,7 +182,8 @@ public: } auto fnType = new TyTy::FnType (function.get_mappings ().get_hirid (), - params, ret_type); + std::move (params), ret_type, + std::move (substitutions)); context->insert_type (function.get_mappings (), fnType); } diff --git a/gcc/rust/typecheck/rust-hir-type-check-type.h b/gcc/rust/typecheck/rust-hir-type-check-type.h index ec9cbda..de3dfa7 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-type.h +++ b/gcc/rust/typecheck/rust-hir-type-check-type.h @@ -21,6 +21,7 @@ #include "rust-hir-type-check-base.h" #include "rust-hir-full.h" +#include "rust-substitution-mapper.h" namespace Rust { namespace Resolver { @@ -182,20 +183,9 @@ public: { if (translated->has_subsititions_defined ()) { - // so far we only support ADT so lets just handle it here - // for now - if (translated->get_kind () != TyTy::TypeKind::ADT) - { - rust_error_at ( - path.get_locus (), - "unsupported type for generic substitution: %s", - translated->as_string ().c_str ()); - return; - } - - TyTy::ADTType *adt - = static_cast<TyTy::ADTType *> (translated); - translated = adt->handle_substitutions (args); + translated + = SubstMapper::Resolve (translated, path.get_locus (), + &args); } else { @@ -208,23 +198,11 @@ public: return; } } - else if (translated->supports_substitutions ()) + else if (translated->has_subsititions_defined ()) { - // so far we only support ADT so lets just handle it here - // for now - if (translated->get_kind () != TyTy::TypeKind::ADT) - { - rust_error_at ( - path.get_locus (), - "unsupported type for generic substitution: %s", - translated->as_string ().c_str ()); - return; - } - - TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (translated); - translated = adt->infer_substitutions (); + translated + = SubstMapper::Resolve (translated, path.get_locus ()); } - return; } } diff --git a/gcc/rust/typecheck/rust-hir-type-check.cc b/gcc/rust/typecheck/rust-hir-type-check.cc index d5cb0d5..9c25bfb 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.cc +++ b/gcc/rust/typecheck/rust-hir-type-check.cc @@ -305,10 +305,19 @@ TypeCheckStructExpr::visit (HIR::PathInExpression &expr) if (struct_path_resolved->has_substitutions ()) { HIR::PathExprSegment seg = expr.get_final_segment (); - struct_path_resolved = seg.has_generic_args () - ? struct_path_resolved->handle_substitutions ( - seg.get_generic_args ()) - : struct_path_resolved->infer_substitutions (); + + TyTy::BaseType *subst + = SubstMapper::Resolve (struct_path_resolved, expr.get_locus (), + seg.has_generic_args () + ? &seg.get_generic_args () + : nullptr); + if (subst == nullptr || subst->get_kind () != TyTy::TypeKind::ADT) + { + rust_fatal_error (mappings->lookup_location (ref), + "expected a substituted ADT type"); + return; + } + struct_path_resolved = static_cast<TyTy::ADTType *> (subst); } } diff --git a/gcc/rust/typecheck/rust-substitution-mapper.h b/gcc/rust/typecheck/rust-substitution-mapper.h new file mode 100644 index 0000000..516b0b7 --- /dev/null +++ b/gcc/rust/typecheck/rust-substitution-mapper.h @@ -0,0 +1,170 @@ +// Copyright (C) 2020 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// <http://www.gnu.org/licenses/>. + +#ifndef RUST_SUBSTITUTION_MAPPER_H +#define RUST_SUBSTITUTION_MAPPER_H + +#include "rust-tyty.h" +#include "rust-tyty-visitor.h" + +namespace Rust { +namespace Resolver { + +class SubstMapper : public TyTy::TyVisitor +{ +public: + static TyTy::BaseType *Resolve (TyTy::BaseType *base, Location locus, + HIR::GenericArgs *generics = nullptr) + { + SubstMapper mapper (base->get_ref (), generics, locus); + base->accept_vis (mapper); + rust_assert (mapper.resolved != nullptr); + return mapper.resolved; + } + + bool have_generic_args () const { return generics != nullptr; } + + void visit (TyTy::FnType &type) override + { + TyTy::FnType *concrete = nullptr; + if (!have_generic_args ()) + { + TyTy::BaseType *substs = type.infer_substitions (locus); + rust_assert (substs->get_kind () == TyTy::TypeKind::FNDEF); + concrete = static_cast<TyTy::FnType *> (substs); + } + else + { + TyTy::SubstitutionArgumentMappings mappings + = type.get_mappings_from_generic_args (*generics); + concrete = type.handle_substitions (mappings); + } + + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::ADTType &type) override + { + TyTy::ADTType *concrete = nullptr; + if (!have_generic_args ()) + { + TyTy::BaseType *substs = type.infer_substitions (locus); + rust_assert (substs->get_kind () == TyTy::TypeKind::ADT); + concrete = static_cast<TyTy::ADTType *> (substs); + } + else + { + TyTy::SubstitutionArgumentMappings mappings + = type.get_mappings_from_generic_args (*generics); + concrete = type.handle_substitions (mappings); + } + + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::UnitType &) override { gcc_unreachable (); } + void visit (TyTy::InferType &) override { gcc_unreachable (); } + void visit (TyTy::TupleType &) override { gcc_unreachable (); } + void visit (TyTy::FnPtr &) override { gcc_unreachable (); } + void visit (TyTy::ArrayType &) override { gcc_unreachable (); } + void visit (TyTy::BoolType &) override { gcc_unreachable (); } + void visit (TyTy::IntType &) override { gcc_unreachable (); } + void visit (TyTy::UintType &) override { gcc_unreachable (); } + void visit (TyTy::FloatType &) override { gcc_unreachable (); } + void visit (TyTy::USizeType &) override { gcc_unreachable (); } + void visit (TyTy::ISizeType &) override { gcc_unreachable (); } + void visit (TyTy::ErrorType &) override { gcc_unreachable (); } + void visit (TyTy::CharType &) override { gcc_unreachable (); } + void visit (TyTy::ReferenceType &) override { gcc_unreachable (); } + void visit (TyTy::ParamType &) override { gcc_unreachable (); } + void visit (TyTy::StrType &) override { gcc_unreachable (); } + +private: + SubstMapper (HirId ref, HIR::GenericArgs *generics, Location locus) + : resolved (new TyTy::ErrorType (ref)), generics (generics), locus (locus) + {} + + TyTy::BaseType *resolved; + HIR::GenericArgs *generics; + Location locus; +}; + +class SubstMapperInternal : public TyTy::TyVisitor +{ +public: + static TyTy::BaseType *Resolve (TyTy::BaseType *base, + TyTy::SubstitutionArgumentMappings &mappings) + { + SubstMapperInternal mapper (base->get_ref (), mappings); + base->accept_vis (mapper); + rust_assert (mapper.resolved != nullptr); + return mapper.resolved; + } + + void visit (TyTy::FnType &type) override + { + TyTy::SubstitutionArgumentMappings adjusted + = type.adjust_mappings_for_this (mappings); + + TyTy::BaseType *concrete = type.handle_substitions (adjusted); + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::ADTType &type) override + { + TyTy::SubstitutionArgumentMappings adjusted + = type.adjust_mappings_for_this (mappings); + + TyTy::BaseType *concrete = type.handle_substitions (adjusted); + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::UnitType &) override { gcc_unreachable (); } + void visit (TyTy::InferType &) override { gcc_unreachable (); } + void visit (TyTy::TupleType &) override { gcc_unreachable (); } + void visit (TyTy::FnPtr &) override { gcc_unreachable (); } + void visit (TyTy::ArrayType &) override { gcc_unreachable (); } + void visit (TyTy::BoolType &) override { gcc_unreachable (); } + void visit (TyTy::IntType &) override { gcc_unreachable (); } + void visit (TyTy::UintType &) override { gcc_unreachable (); } + void visit (TyTy::FloatType &) override { gcc_unreachable (); } + void visit (TyTy::USizeType &) override { gcc_unreachable (); } + void visit (TyTy::ISizeType &) override { gcc_unreachable (); } + void visit (TyTy::ErrorType &) override { gcc_unreachable (); } + void visit (TyTy::CharType &) override { gcc_unreachable (); } + void visit (TyTy::ReferenceType &) override { gcc_unreachable (); } + void visit (TyTy::ParamType &) override { gcc_unreachable (); } + void visit (TyTy::StrType &) override { gcc_unreachable (); } + +private: + SubstMapperInternal (HirId ref, TyTy::SubstitutionArgumentMappings &mappings) + : resolved (new TyTy::ErrorType (ref)), mappings (mappings) + {} + + TyTy::BaseType *resolved; + TyTy::SubstitutionArgumentMappings &mappings; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_SUBSTITUTION_MAPPER_H diff --git a/gcc/rust/typecheck/rust-tyty-rules.h b/gcc/rust/typecheck/rust-tyty-rules.h index c6eb357..e17f89d 100644 --- a/gcc/rust/typecheck/rust-tyty-rules.h +++ b/gcc/rust/typecheck/rust-tyty-rules.h @@ -240,7 +240,6 @@ public: rust_error_at (ref_locus, "expected [%s] got [ParamTy <%s>]", get_base ()->as_string ().c_str (), type.as_string ().c_str ()); - gcc_unreachable (); } virtual void visit (StrType &type) override @@ -487,6 +486,19 @@ public: BaseRules::visit (type); } + void visit (ParamType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + private: BaseType *get_base () override { return base; } @@ -714,6 +726,20 @@ public: resolved = new BoolType (type.get_ref (), type.get_ty_ref ()); } + void visit (InferType &type) override + { + switch (type.get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + resolved = base->clone (); + break; + + default: + BaseRules::visit (type); + break; + } + } + private: BaseType *get_base () override { return base; } @@ -742,14 +768,14 @@ public: void visit (IntType &type) override { - if (type.get_kind () != base->get_kind ()) + if (type.get_int_kind () != base->get_int_kind ()) { BaseRules::visit (type); return; } resolved - = new IntType (type.get_ref (), type.get_ty_ref (), type.get_kind ()); + = new IntType (type.get_ref (), type.get_ty_ref (), type.get_int_kind ()); } private: @@ -780,14 +806,14 @@ public: void visit (UintType &type) override { - if (type.get_kind () != base->get_kind ()) + if (type.get_uint_kind () != base->get_uint_kind ()) { BaseRules::visit (type); return; } - resolved - = new UintType (type.get_ref (), type.get_ty_ref (), type.get_kind ()); + resolved = new UintType (type.get_ref (), type.get_ty_ref (), + type.get_uint_kind ()); } private: @@ -817,14 +843,14 @@ public: void visit (FloatType &type) override { - if (type.get_kind () != base->get_kind ()) + if (type.get_float_kind () != base->get_float_kind ()) { BaseRules::visit (type); return; } - resolved - = new FloatType (type.get_ref (), type.get_ty_ref (), type.get_kind ()); + resolved = new FloatType (type.get_ref (), type.get_ty_ref (), + type.get_float_kind ()); } private: @@ -864,7 +890,7 @@ public: } } - resolved = base->clone (); + resolved = type.clone (); } private: @@ -1029,6 +1055,8 @@ private: class ParamRules : public BaseRules { + using Rust::TyTy::BaseRules::visit; + public: ParamRules (ParamType *base) : BaseRules (base), base (base) {} @@ -1044,15 +1072,7 @@ public: BaseType *unify (BaseType *other) override final { if (base->get_ref () == base->get_ty_ref ()) - { - Location locus = mappings->lookup_location (base->get_ref ()); - rust_fatal_error (locus, - "invalid use of unify with ParamTy [%s] and [%s]", - base->as_string ().c_str (), - other->as_string ().c_str ()); - return nullptr; - } - + return BaseRules::unify (other); auto context = Resolver::TypeCheckContext::get (); BaseType *lookup = nullptr; bool ok = context->lookup_type (base->get_ty_ref (), &lookup); @@ -1061,6 +1081,17 @@ public: return lookup->unify (other); } + void visit (ParamType &type) override + { + if (base->get_symbol ().compare (type.get_symbol ()) != 0) + { + BaseRules::visit (type); + return; + } + + resolved = type.clone (); + } + private: BaseType *get_base () override { return base; } diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index 034bc2d..e5bd96a 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -23,6 +23,7 @@ #include "rust-hir-type-check-type.h" #include "rust-tyty-rules.h" #include "rust-hir-map.h" +#include "rust-substitution-mapper.h" namespace Rust { namespace TyTy { @@ -46,6 +47,22 @@ TyVar::get_tyty () const return lookup; } +TyVar +TyVar::get_implict_infer_var () +{ + auto mappings = Analysis::Mappings::get (); + auto context = Resolver::TypeCheckContext::get (); + + InferType *infer = new InferType (mappings->get_next_hir_id (), + InferType::InferTypeKind::GENERAL); + context->insert_type (Analysis::NodeMapping (mappings->get_current_crate (), + UNKNOWN_NODEID, + infer->get_ref (), + UNKNOWN_LOCAL_DEFID), + infer); + return TyVar (infer->get_ref ()); +} + void UnitType::accept_vis (TyVisitor &vis) { @@ -156,14 +173,24 @@ ErrorType::clone () std::string StructFieldType::as_string () const { - return name + ":" + ty->as_string (); + return name + ":" + get_field_type ()->debug_str (); } bool StructFieldType::is_equal (const StructFieldType &other) const { - return get_name ().compare (other.get_name ()) == 0 - && get_field_type ()->is_equal (*other.get_field_type ()); + bool names_eq = get_name ().compare (other.get_name ()) == 0; + + TyTy::BaseType &o = *other.get_field_type (); + if (o.get_kind () == TypeKind::PARAM) + { + ParamType &op = static_cast<ParamType &> (o); + o = *op.resolve (); + } + + bool types_eq = get_field_type ()->is_equal (o); + + return names_eq && types_eq; } StructFieldType * @@ -174,6 +201,87 @@ StructFieldType::clone () const } void +SubstitutionParamMapping::override_context () +{ + rust_assert (param->can_resolve ()); + + auto mappings = Analysis::Mappings::get (); + auto context = Resolver::TypeCheckContext::get (); + context->insert_type (Analysis::NodeMapping (mappings->get_current_crate (), + UNKNOWN_NODEID, + param->get_ref (), + UNKNOWN_LOCAL_DEFID), + param->resolve ()); +} + +SubstitutionArgumentMappings +SubstitutionRef::get_mappings_from_generic_args (HIR::GenericArgs &args) +{ + if (args.get_type_args ().size () != substitutions.size ()) + { + rust_error_at (args.get_locus (), + "Invalid number of generic arguments to generic type"); + return SubstitutionArgumentMappings::error (); + } + + std::vector<SubstitutionArg> mappings; + + // FIXME does not support binding yet + for (auto &arg : args.get_type_args ()) + { + BaseType *resolved = Resolver::TypeCheckType::Resolve (arg.get ()); + if (resolved == nullptr || resolved->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (args.get_locus (), "failed to resolve type arguments"); + return SubstitutionArgumentMappings::error (); + } + + SubstitutionArg subst_arg (&substitutions.at (mappings.size ()), + resolved); + mappings.push_back (std::move (subst_arg)); + } + + return SubstitutionArgumentMappings (mappings, args.get_locus ()); +} + +SubstitutionArgumentMappings +SubstitutionRef::adjust_mappings_for_this ( + SubstitutionArgumentMappings &mappings) +{ + if (substitutions.size () > mappings.size ()) + { + rust_error_at (mappings.get_locus (), + "not enough type arguments: subs %s vs mappings %s", + subst_as_string ().c_str (), + mappings.as_string ().c_str ()); + return SubstitutionArgumentMappings::error (); + } + + Analysis::Mappings *mappings_table = Analysis::Mappings::get (); + + std::vector<SubstitutionArg> resolved_mappings; + for (auto &subst : substitutions) + { + SubstitutionArg arg = SubstitutionArg::error (); + bool ok = mappings.get_argument_for_symbol (subst.get_param_ty (), &arg); + if (!ok) + { + rust_error_at (mappings_table->lookup_location ( + subst.get_param_ty ()->get_ref ()), + "failed to find parameter type: %s", + subst.get_param_ty ()->as_string ().c_str ()); + return SubstitutionArgumentMappings::error (); + } + + SubstitutionArg adjusted (&subst, arg.get_tyty ()); + resolved_mappings.push_back (std::move (adjusted)); + } + + return SubstitutionArgumentMappings (resolved_mappings, + mappings.get_locus ()); +} + +void ADTType::accept_vis (TyVisitor &vis) { vis.visit (*this); @@ -230,10 +338,34 @@ ADTType::is_equal (const BaseType &other) const if (num_fields () != other2.num_fields ()) return false; - for (size_t i = 0; i < num_fields (); i++) + if (has_subsititions_defined () != other2.has_subsititions_defined ()) + return false; + + if (has_subsititions_defined ()) { - if (!get_field (i)->is_equal (*other2.get_field (i))) + if (get_num_substitutions () != other2.get_num_substitutions ()) return false; + + for (size_t i = 0; i < get_num_substitutions (); i++) + { + const SubstitutionParamMapping &a = substitutions.at (i); + const SubstitutionParamMapping &b = other2.substitutions.at (i); + + const ParamType *aa = a.get_param_ty (); + const ParamType *bb = b.get_param_ty (); + BaseType *aaa = aa->resolve (); + BaseType *bbb = bb->resolve (); + if (!aaa->is_equal (*bbb)) + return false; + } + } + else + { + for (size_t i = 0; i < num_fields (); i++) + { + if (!get_field (i)->is_equal (*other2.get_field (i))) + return false; + } } return true; @@ -251,88 +383,83 @@ ADTType::clone () } ADTType * -ADTType::infer_substitutions () +ADTType::handle_substitions (SubstitutionArgumentMappings subst_mappings) { - auto context = Resolver::TypeCheckContext::get (); - ADTType *adt = static_cast<ADTType *> (clone ()); + if (subst_mappings.size () != get_num_substitutions ()) - for (auto &sub : adt->get_substs ()) { - // generate an new inference variable - InferType *infer = new InferType (mappings->get_next_hir_id (), - InferType::InferTypeKind::GENERAL); - context->insert_type ( - Analysis::NodeMapping (mappings->get_current_crate (), UNKNOWN_NODEID, - infer->get_ref (), UNKNOWN_LOCAL_DEFID), - infer); - - sub.fill_param_ty (infer); - adt->fill_in_params_for (sub, infer); - } - - // generate new ty ref id since this is an instantiate of the generic - adt->set_ty_ref (mappings->get_next_hir_id ()); - - return adt; -} - -ADTType * -ADTType::handle_substitutions (HIR::GenericArgs &generic_args) -{ - if (generic_args.get_type_args ().size () != get_num_substitutions ()) - { - rust_error_at (generic_args.get_locus (), + rust_error_at (subst_mappings.get_locus (), "invalid number of generic arguments to generic ADT type"); return nullptr; } ADTType *adt = static_cast<ADTType *> (clone ()); - size_t index = 0; - for (auto &arg : generic_args.get_type_args ()) - { - BaseType *resolved = Resolver::TypeCheckType::Resolve (arg.get ()); - if (resolved == nullptr) - { - rust_error_at (generic_args.get_locus (), - "failed to resolve type arguments"); - return nullptr; - } - - adt->fill_in_at (index, resolved); - index++; - } - - // generate new ty ref id since this is an instantiate of the generic adt->set_ty_ref (mappings->get_next_hir_id ()); - return adt; -} - -void -ADTType::fill_in_at (size_t index, BaseType *type) -{ - SubstitutionMapping sub = get_substitution_mapping_at (index); - SubstitutionRef<ADTType>::fill_in_at (index, type); - fill_in_params_for (sub, type); -} - -void -ADTType::fill_in_params_for (SubstitutionMapping sub, BaseType *type) -{ - iterate_fields ([&] (StructFieldType *field) mutable -> bool { - bool is_param_ty = field->get_field_type ()->get_kind () == TypeKind::PARAM; - if (!is_param_ty) - return true; - - const ParamType *pp = sub.get_param_ty (); - ParamType *p = static_cast<ParamType *> (field->get_field_type ()); + for (auto &sub : adt->get_substs ()) + { + SubstitutionArg arg = SubstitutionArg::error (); + bool ok + = subst_mappings.get_argument_for_symbol (sub.get_param_ty (), &arg); + rust_assert (ok); + sub.fill_param_ty (arg.get_tyty ()); + } - // for now let just see what symbols match up for the substitution - if (p->get_symbol ().compare (pp->get_symbol ()) == 0) - p->set_ty_ref (type->get_ref ()); + adt->iterate_fields ([&] (StructFieldType *field) mutable -> bool { + auto fty = field->get_field_type (); + bool is_param_ty = fty->get_kind () == TypeKind::PARAM; + if (is_param_ty) + { + ParamType *p = static_cast<ParamType *> (fty); + + SubstitutionArg arg = SubstitutionArg::error (); + bool ok = subst_mappings.get_argument_for_symbol (p, &arg); + if (!ok) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve parameter type: %s", + p->as_string ().c_str ()); + return false; + } + + auto argt = arg.get_tyty (); + bool arg_is_param = argt->get_kind () == TyTy::TypeKind::PARAM; + bool arg_is_concrete = argt->get_kind () != TyTy::TypeKind::INFER; + + if (arg_is_param || arg_is_concrete) + { + auto new_field = argt->clone (); + new_field->set_ref (fty->get_ref ()); + field->set_field_type (new_field); + } + else + { + field->get_field_type ()->set_ty_ref (argt->get_ref ()); + } + } + else if (fty->has_subsititions_defined ()) + { + BaseType *concrete + = Resolver::SubstMapperInternal::Resolve (fty, subst_mappings); + + if (concrete == nullptr + || concrete->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve field substitution type: %s", + fty->as_string ().c_str ()); + return false; + } + + auto new_field = concrete->clone (); + new_field->set_ref (fty->get_ref ()); + field->set_field_type (new_field); + } return true; }); + + return adt; } void @@ -452,7 +579,135 @@ FnType::clone () std::pair<HIR::Pattern *, BaseType *> (p.first, p.second->clone ())); return new FnType (get_ref (), get_ty_ref (), std::move (cloned_params), - get_return_type ()->clone (), get_combined_refs ()); + get_return_type ()->clone (), clone_substs (), + get_combined_refs ()); +} + +FnType * +FnType::handle_substitions (SubstitutionArgumentMappings subst_mappings) +{ + if (subst_mappings.size () != get_num_substitutions ()) + { + rust_error_at (subst_mappings.get_locus (), + "invalid number of generic arguments to generic ADT type"); + return nullptr; + } + + FnType *fn = static_cast<FnType *> (clone ()); + fn->set_ty_ref (mappings->get_next_hir_id ()); + + for (auto &sub : fn->get_substs ()) + { + SubstitutionArg arg = SubstitutionArg::error (); + bool ok + = subst_mappings.get_argument_for_symbol (sub.get_param_ty (), &arg); + rust_assert (ok); + sub.fill_param_ty (arg.get_tyty ()); + } + + auto fty = fn->get_return_type (); + bool is_param_ty = fty->get_kind () == TypeKind::PARAM; + if (is_param_ty) + { + ParamType *p = static_cast<ParamType *> (fty); + + SubstitutionArg arg = SubstitutionArg::error (); + bool ok = subst_mappings.get_argument_for_symbol (p, &arg); + if (!ok) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve parameter type: %s", + p->as_string ().c_str ()); + return nullptr; + } + + auto argt = arg.get_tyty (); + bool arg_is_param = argt->get_kind () == TyTy::TypeKind::PARAM; + bool arg_is_concrete = argt->get_kind () != TyTy::TypeKind::INFER; + + if (arg_is_param || arg_is_concrete) + { + auto new_field = argt->clone (); + new_field->set_ref (fty->get_ref ()); + fn->type = new_field; + } + else + { + fty->set_ty_ref (argt->get_ref ()); + } + } + else if (fty->has_subsititions_defined ()) + { + BaseType *concrete + = Resolver::SubstMapperInternal::Resolve (fty, subst_mappings); + + if (concrete == nullptr || concrete->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve field substitution type: %s", + fty->as_string ().c_str ()); + return nullptr; + } + + auto new_field = concrete->clone (); + new_field->set_ref (fty->get_ref ()); + fn->type = new_field; + } + + for (auto ¶m : fn->get_params ()) + { + auto fty = param.second; + bool is_param_ty = fty->get_kind () == TypeKind::PARAM; + if (is_param_ty) + { + ParamType *p = static_cast<ParamType *> (fty); + + SubstitutionArg arg = SubstitutionArg::error (); + bool ok = subst_mappings.get_argument_for_symbol (p, &arg); + if (!ok) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve parameter type: %s", + p->as_string ().c_str ()); + return nullptr; + } + + auto argt = arg.get_tyty (); + bool arg_is_param = argt->get_kind () == TyTy::TypeKind::PARAM; + bool arg_is_concrete = argt->get_kind () != TyTy::TypeKind::INFER; + + if (arg_is_param || arg_is_concrete) + { + auto new_field = argt->clone (); + new_field->set_ref (fty->get_ref ()); + param.second = new_field; + } + else + { + fty->set_ty_ref (argt->get_ref ()); + } + } + else if (fty->has_subsititions_defined ()) + { + BaseType *concrete + = Resolver::SubstMapperInternal::Resolve (fty, subst_mappings); + + if (concrete == nullptr + || concrete->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (subst_mappings.get_locus (), + "Failed to resolve field substitution type: %s", + fty->as_string ().c_str ()); + return nullptr; + } + + auto new_field = concrete->clone (); + new_field->set_ref (fty->get_ref ()); + param.second = new_field; + } + } + + return fn; } void @@ -623,10 +878,20 @@ IntType::unify (BaseType *other) BaseType * IntType::clone () { - return new IntType (get_ref (), get_ty_ref (), get_kind (), + return new IntType (get_ref (), get_ty_ref (), get_int_kind (), get_combined_refs ()); } +bool +IntType::is_equal (const BaseType &other) const +{ + if (!BaseType::is_equal (other)) + return false; + + const IntType &o = static_cast<const IntType &> (other); + return get_int_kind () == o.get_int_kind (); +} + void UintType::accept_vis (TyVisitor &vis) { @@ -663,10 +928,20 @@ UintType::unify (BaseType *other) BaseType * UintType::clone () { - return new UintType (get_ref (), get_ty_ref (), get_kind (), + return new UintType (get_ref (), get_ty_ref (), get_uint_kind (), get_combined_refs ()); } +bool +UintType::is_equal (const BaseType &other) const +{ + if (!BaseType::is_equal (other)) + return false; + + const UintType &o = static_cast<const UintType &> (other); + return get_uint_kind () == o.get_uint_kind (); +} + void FloatType::accept_vis (TyVisitor &vis) { @@ -697,10 +972,20 @@ FloatType::unify (BaseType *other) BaseType * FloatType::clone () { - return new FloatType (get_ref (), get_ty_ref (), get_kind (), + return new FloatType (get_ref (), get_ty_ref (), get_float_kind (), get_combined_refs ()); } +bool +FloatType::is_equal (const BaseType &other) const +{ + if (!BaseType::is_equal (other)) + return false; + + const FloatType &o = static_cast<const FloatType &> (other); + return get_float_kind () == o.get_float_kind (); +} + void USizeType::accept_vis (TyVisitor &vis) { @@ -828,7 +1113,9 @@ std::string ParamType::as_string () const { if (get_ref () == get_ty_ref ()) - return get_symbol (); + { + return get_symbol () + " REF: " + std::to_string (get_ref ()); + } auto context = Resolver::TypeCheckContext::get (); BaseType *lookup = nullptr; @@ -859,14 +1146,33 @@ ParamType::get_symbol () const } BaseType * -ParamType::resolve () +ParamType::resolve () const { - auto context = Resolver::TypeCheckContext::get (); - BaseType *lookup = nullptr; - bool ok = context->lookup_type (get_ty_ref (), &lookup); - rust_assert (ok); + rust_assert (can_resolve ()); - return lookup; + TyVar var (get_ty_ref ()); + BaseType *r = var.get_tyty (); + + while (r->get_kind () == TypeKind::PARAM) + { + ParamType *rr = static_cast<ParamType *> (r); + if (!rr->can_resolve ()) + break; + + TyVar v (rr->get_ty_ref ()); + r = v.get_tyty (); + } + + return TyVar (r->get_ty_ref ()).get_tyty (); +} + +bool +ParamType::is_equal (const BaseType &other) const +{ + if (!can_resolve ()) + return BaseType::is_equal (other); + + return resolve ()->is_equal (other); } BaseType * @@ -936,7 +1242,9 @@ TypeCheckCallExpr::visit (ADTType &type) auto res = field_tyty->unify (arg); if (res == nullptr) - return false; + { + return false; + } delete res; i++; diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index f0a450e..27cfe46 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -20,6 +20,7 @@ #define RUST_TYTY #include "rust-hir-map.h" +#include "rust-hir-full.h" namespace Rust { namespace TyTy { @@ -106,6 +107,11 @@ public: virtual bool has_subsititions_defined () const { return false; } + virtual bool can_substitute () const + { + return supports_substitutions () && has_subsititions_defined (); + } + std::string mappings_str () const { std::string buffer = "Ref: " + std::to_string (get_ref ()) @@ -149,6 +155,8 @@ public: BaseType *get_tyty () const; + static TyVar get_implict_infer_var (); + private: HirId ref; }; @@ -241,6 +249,46 @@ public: std::string get_name () const override final { return as_string (); } }; +class ParamType : public BaseType +{ +public: + ParamType (std::string symbol, HirId ref, HIR::GenericParam ¶m, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ref, TypeKind::PARAM, refs), symbol (symbol), param (param) + {} + + ParamType (std::string symbol, HirId ref, HirId ty_ref, + HIR::GenericParam ¶m, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ty_ref, TypeKind::PARAM, refs), symbol (symbol), + param (param) + {} + + void accept_vis (TyVisitor &vis) override; + + std::string as_string () const override; + + BaseType *unify (BaseType *other) override; + + BaseType *clone () final override; + + std::string get_symbol () const; + + HIR::GenericParam &get_generic_param () { return param; } + + bool can_resolve () const { return get_ref () != get_ty_ref (); } + + BaseType *resolve () const; + + std::string get_name () const override final { return as_string (); } + + bool is_equal (const BaseType &other) const override; + +private: + std::string symbol; + HIR::GenericParam ¶m; +}; + class StructFieldType { public: @@ -258,8 +306,12 @@ public: BaseType *get_field_type () const { return ty; } + void set_field_type (BaseType *fty) { ty = fty; } + StructFieldType *clone () const; + void debug () const { printf ("%s\n", as_string ().c_str ()); } + private: HirId ref; std::string name; @@ -310,72 +362,142 @@ private: std::vector<TyVar> fields; }; -class ParamType : public BaseType +class SubstitutionParamMapping { public: - ParamType (std::string symbol, HirId ref, HIR::GenericParam ¶m, - std::set<HirId> refs = std::set<HirId> ()) - : BaseType (ref, ref, TypeKind::PARAM), symbol (symbol), param (param) + SubstitutionParamMapping (std::unique_ptr<HIR::GenericParam> &generic, + ParamType *param) + + : generic (generic), param (param) {} - ParamType (std::string symbol, HirId ref, HirId ty_ref, - HIR::GenericParam ¶m, - std::set<HirId> refs = std::set<HirId> ()) - : BaseType (ref, ty_ref, TypeKind::PARAM), symbol (symbol), param (param) + SubstitutionParamMapping (const SubstitutionParamMapping &other) + : generic (other.generic), param (other.param) {} - void accept_vis (TyVisitor &vis) override; + std::string as_string () const { return param->as_string (); } - std::string as_string () const override; + void fill_param_ty (BaseType *type) + { + if (type->get_kind () == TypeKind::PARAM) + { + delete param; + param = static_cast<ParamType *> (type->clone ()); + } + else + { + param->set_ty_ref (type->get_ref ()); + } + } - BaseType *unify (BaseType *other) override; + SubstitutionParamMapping clone () + { + return SubstitutionParamMapping (generic, static_cast<ParamType *> ( + param->clone ())); + } - BaseType *clone () final override; + const ParamType *get_param_ty () const { return param; } - std::string get_symbol () const; + std::unique_ptr<HIR::GenericParam> &get_generic_param () { return generic; }; - HIR::GenericParam &get_generic_param () { return param; } + void override_context (); - bool can_resolve () const { return get_ref () != get_ty_ref (); } +private: + std::unique_ptr<HIR::GenericParam> &generic; + ParamType *param; +}; - BaseType *resolve (); +class SubstitutionArg +{ +public: + SubstitutionArg (SubstitutionParamMapping *param, BaseType *argument) + : param (std::move (param)), argument (argument) + {} - std::string get_name () const override final { return as_string (); } + SubstitutionArg (const SubstitutionArg &other) + : param (other.param), argument (other.argument) + {} + + SubstitutionArg &operator= (const SubstitutionArg &other) + { + param = other.param; + argument = other.argument; + return *this; + } + + BaseType *get_tyty () { return argument; } + + SubstitutionParamMapping *get_param_mapping () { return param; } + + static SubstitutionArg error () { return SubstitutionArg (nullptr, nullptr); } + + std::string as_string () const + { + return param->as_string () + ":" + argument->as_string (); + } private: - std::string symbol; - HIR::GenericParam ¶m; + SubstitutionParamMapping *param; + BaseType *argument; }; -class SubstitutionMapping +class SubstitutionArgumentMappings { public: - SubstitutionMapping (std::unique_ptr<HIR::GenericParam> &generic, - ParamType *param) - : generic (generic), param (param) + SubstitutionArgumentMappings (std::vector<SubstitutionArg> mappings, + Location locus) + : mappings (mappings), locus (locus) {} - std::string as_string () const { return param->as_string (); } + static SubstitutionArgumentMappings error () + { + return SubstitutionArgumentMappings ({}, Location ()); + } - void fill_param_ty (BaseType *type) { param->set_ty_ref (type->get_ref ()); } + bool is_error () const { return mappings.size () == 0; } - SubstitutionMapping clone () + bool get_argument_for_symbol (const ParamType *param_to_find, + SubstitutionArg *argument) { - return SubstitutionMapping (generic, - static_cast<ParamType *> (param->clone ())); + for (auto &mapping : mappings) + { + SubstitutionParamMapping *param = mapping.get_param_mapping (); + const ParamType *p = param->get_param_ty (); + + if (p->get_symbol ().compare (param_to_find->get_symbol ()) == 0) + { + *argument = mapping; + return true; + } + } + return false; } - const ParamType *get_param_ty () const { return param; } + Location get_locus () { return locus; } + + size_t size () const { return mappings.size (); } + + std::vector<SubstitutionArg> &get_mappings () { return mappings; } + + std::string as_string () const + { + std::string buffer; + for (auto &mapping : mappings) + { + buffer += mapping.as_string () + ", "; + } + return "<" + buffer + ">"; + } private: - std::unique_ptr<HIR::GenericParam> &generic; - ParamType *param; + std::vector<SubstitutionArg> mappings; + Location locus; }; -template <class T> class SubstitutionRef +class SubstitutionRef { public: - SubstitutionRef (std::vector<SubstitutionMapping> substitutions) + SubstitutionRef (std::vector<SubstitutionParamMapping> substitutions) : substitutions (substitutions) {} @@ -386,7 +508,7 @@ public: std::string buffer; for (size_t i = 0; i < substitutions.size (); i++) { - const SubstitutionMapping &sub = substitutions.at (i); + const SubstitutionParamMapping &sub = substitutions.at (i); buffer += sub.as_string (); if ((i + 1) < substitutions.size ()) @@ -398,42 +520,69 @@ public: size_t get_num_substitutions () const { return substitutions.size (); } - std::vector<SubstitutionMapping> &get_substs () { return substitutions; } + std::vector<SubstitutionParamMapping> &get_substs () { return substitutions; } - std::vector<SubstitutionMapping> clone_substs () + std::vector<SubstitutionParamMapping> clone_substs () { - std::vector<SubstitutionMapping> clone; + std::vector<SubstitutionParamMapping> clone; + for (auto &sub : substitutions) clone.push_back (sub.clone ()); return clone; } - virtual T *infer_substitutions () = 0; - - virtual T *handle_substitutions (HIR::GenericArgs &generic_args) = 0; - -protected: - virtual void fill_in_at (size_t index, BaseType *type) + void override_context () { - substitutions.at (index).fill_param_ty (type); + for (auto &sub : substitutions) + { + sub.override_context (); + } } - SubstitutionMapping get_substitution_mapping_at (size_t index) + // We are trying to subst <i32, f32> into Struct Foo<X,Y> {} + // in the case of Foo<i32,f32>{...} + // + // the substitions we have here define X,Y but the arguments have no bindings + // so its a matter of ordering + SubstitutionArgumentMappings + get_mappings_from_generic_args (HIR::GenericArgs &args); + + // Recursive substitutions + // Foo <A,B> { a:A, b: B}; Bar <X,Y,Z>{a:X, b: Foo<Y,Z>} + // + // we have bindings for X Y Z and need to propagate the binding Y,Z into Foo + // Which binds to A,B + SubstitutionArgumentMappings + adjust_mappings_for_this (SubstitutionArgumentMappings &mappings); + + BaseType *infer_substitions (Location locus) { - return substitutions.at (index); + std::vector<SubstitutionArg> args; + for (auto &sub : get_substs ()) + { + TyVar infer_var = TyVar::get_implict_infer_var (); + args.push_back (SubstitutionArg (&sub, infer_var.get_tyty ())); + } + + SubstitutionArgumentMappings infer_arguments (std::move (args), locus); + return handle_substitions (std::move (infer_arguments)); } -private: - std::vector<SubstitutionMapping> substitutions; + virtual BaseType *handle_substitions (SubstitutionArgumentMappings mappings) + = 0; + +protected: + std::vector<SubstitutionParamMapping> substitutions; }; -class ADTType : public BaseType, public SubstitutionRef<ADTType> +class ADTType : public BaseType, public SubstitutionRef + { public: ADTType (HirId ref, std::string identifier, bool is_tuple, std::vector<StructFieldType *> fields, - std::vector<SubstitutionMapping> subst_refs, + std::vector<SubstitutionParamMapping> subst_refs, std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs)), identifier (identifier), @@ -442,7 +591,7 @@ public: ADTType (HirId ref, HirId ty_ref, std::string identifier, bool is_tuple, std::vector<StructFieldType *> fields, - std::vector<SubstitutionMapping> subst_refs, + std::vector<SubstitutionParamMapping> subst_refs, std::set<HirId> refs = std::set<HirId> ()) : BaseType (ref, ty_ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs)), identifier (identifier), @@ -514,13 +663,8 @@ public: return has_substitutions (); } - ADTType *infer_substitutions () override final; - - ADTType *handle_substitutions (HIR::GenericArgs &generic_args) override final; - - void fill_in_at (size_t index, BaseType *type) override final; - - void fill_in_params_for (SubstitutionMapping sub, BaseType *type); + ADTType * + handle_substitions (SubstitutionArgumentMappings mappings) override final; private: std::string identifier; @@ -528,20 +672,23 @@ private: bool is_tuple; }; -class FnType : public BaseType +class FnType : public BaseType, public SubstitutionRef { public: FnType (HirId ref, std::vector<std::pair<HIR::Pattern *, BaseType *> > params, - BaseType *type, std::set<HirId> refs = std::set<HirId> ()) - : BaseType (ref, ref, TypeKind::FNDEF, refs), params (std::move (params)), + BaseType *type, std::vector<SubstitutionParamMapping> subst_refs, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ref, TypeKind::FNDEF, refs), + SubstitutionRef (std::move (subst_refs)), params (std::move (params)), type (type) {} FnType (HirId ref, HirId ty_ref, std::vector<std::pair<HIR::Pattern *, BaseType *> > params, - BaseType *type, std::set<HirId> refs = std::set<HirId> ()) - : BaseType (ref, ty_ref, TypeKind::FNDEF, refs), params (params), - type (type) + BaseType *type, std::vector<SubstitutionParamMapping> subst_refs, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ty_ref, TypeKind::FNDEF, refs), + SubstitutionRef (std::move (subst_refs)), params (params), type (type) {} void accept_vis (TyVisitor &vis) override; @@ -580,6 +727,16 @@ public: BaseType *clone () final override; + bool supports_substitutions () const override final { return true; } + + bool has_subsititions_defined () const override final + { + return has_substitutions (); + } + + FnType * + handle_substitions (SubstitutionArgumentMappings mappings) override final; + private: std::vector<std::pair<HIR::Pattern *, BaseType *> > params; BaseType *type; @@ -719,10 +876,12 @@ public: BaseType *unify (BaseType *other) override; - IntKind get_kind () const { return int_kind; } + IntKind get_int_kind () const { return int_kind; } BaseType *clone () final override; + bool is_equal (const BaseType &other) const override; + private: IntKind int_kind; }; @@ -756,10 +915,12 @@ public: BaseType *unify (BaseType *other) override; - UintKind get_kind () const { return uint_kind; } + UintKind get_uint_kind () const { return uint_kind; } BaseType *clone () final override; + bool is_equal (const BaseType &other) const override; + private: UintKind uint_kind; }; @@ -791,10 +952,12 @@ public: BaseType *unify (BaseType *other) override; - FloatKind get_kind () const { return float_kind; } + FloatKind get_float_kind () const { return float_kind; } BaseType *clone () final override; + bool is_equal (const BaseType &other) const override; + private: FloatKind float_kind; }; diff --git a/gcc/testsuite/rust.test/compilable/generics3.rs b/gcc/testsuite/rust.test/compilable/generics3.rs new file mode 100644 index 0000000..0dc41c3 --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/generics3.rs @@ -0,0 +1,13 @@ +fn test<T>(a: T) -> T { + a +} + +fn main() { + let a; + a = test(123); + let aa: i32 = a; + + let b; + b = test::<u32>(456); + let bb: u32 = b; +} diff --git a/gcc/testsuite/rust.test/compilable/generics4.rs b/gcc/testsuite/rust.test/compilable/generics4.rs new file mode 100644 index 0000000..81ac4e6 --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/generics4.rs @@ -0,0 +1,13 @@ +struct Foo<T> { + a: T, + b: bool, +} + +fn test<T>(a: T) -> Foo<T> { + Foo { a: a, b: true } +} + +fn main() { + let a: Foo<i32> = test(123); + let b: Foo<u32> = test(456); +} diff --git a/gcc/testsuite/rust.test/compilable/generics5.rs b/gcc/testsuite/rust.test/compilable/generics5.rs new file mode 100644 index 0000000..3d7f70d --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/generics5.rs @@ -0,0 +1,8 @@ +fn test<T>(a: T) -> T { + a +} + +fn main() { + let a: i32 = test(123); + let b: i32 = test(456); +} diff --git a/gcc/testsuite/rust.test/compilable/generics6.rs b/gcc/testsuite/rust.test/compilable/generics6.rs new file mode 100644 index 0000000..da9f167 --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/generics6.rs @@ -0,0 +1,14 @@ +struct Foo<T>(T); + +struct Bar<T> { + a: Foo<T>, + b: bool, +} + +fn main() { + let a: Bar<i32> = Bar::<i32> { + a: Foo::<i32>(123), + b: true, + }; + let b: i32 = a.a.0; +} diff --git a/gcc/testsuite/rust.test/compilable/generics7.rs b/gcc/testsuite/rust.test/compilable/generics7.rs new file mode 100644 index 0000000..b534708 --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/generics7.rs @@ -0,0 +1,12 @@ +struct Foo<T>(T); + +struct Bar { + a: Foo<i32>, + b: bool, +} + +fn main() { + let a = Foo::<i32>(123); + let b: Bar = Bar { a: a, b: true }; + let c: i32 = b.a.0; +} |