diff options
Diffstat (limited to 'gcc')
-rw-r--r-- | gcc/rust/backend/rust-compile-context.h | 17 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile-implitem.h | 4 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile-item.h | 2 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile-tyty.h | 2 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile.cc | 4 | ||||
-rw-r--r-- | gcc/rust/rust-backend.h | 5 | ||||
-rw-r--r-- | gcc/rust/rust-gcc.cc | 33 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-expr.h | 3 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-implitem.h | 4 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-item.h | 2 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-type.h | 19 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-call.h | 4 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-rules.h | 102 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-visitor.h | 1 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.cc | 137 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.h | 58 | ||||
-rw-r--r-- | gcc/testsuite/rust.test/compilable/function_reference4.rs | 8 |
17 files changed, 359 insertions, 46 deletions
diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h index c2ed0bb..3f4a9ac 100644 --- a/gcc/rust/backend/rust-compile-context.h +++ b/gcc/rust/backend/rust-compile-context.h @@ -315,6 +315,23 @@ public: ctx->get_mappings ()->lookup_location (type.get_ref ())); } + void visit (TyTy::FnPtr &type) override + { + Btype *result_type + = TyTyResolveCompile::compile (ctx, type.get_return_type ()); + + std::vector<Btype *> parameters; + type.iterate_params ([&] (TyTy::BaseType *p) mutable -> bool { + Btype *pty = TyTyResolveCompile::compile (ctx, p); + parameters.push_back (pty); + return true; + }); + + translated = ctx->get_backend ()->function_ptr_type ( + result_type, parameters, + ctx->get_mappings ()->lookup_location (type.get_ref ())); + } + void visit (TyTy::UnitType &) override { translated = ctx->get_backend ()->void_type (); diff --git a/gcc/rust/backend/rust-compile-implitem.h b/gcc/rust/backend/rust-compile-implitem.h index 76cc608..1b6651a 100644 --- a/gcc/rust/backend/rust-compile-implitem.h +++ b/gcc/rust/backend/rust-compile-implitem.h @@ -110,7 +110,7 @@ public: // setup the params - TyTy::BaseType *tyret = fntype->return_type (); + TyTy::BaseType *tyret = fntype->get_return_type (); std::vector<Bvariable *> param_vars; size_t i = 0; @@ -273,7 +273,7 @@ public: ctx->insert_function_decl (method.get_mappings ().get_hirid (), fndecl); // setup the params - TyTy::BaseType *tyret = fntype->return_type (); + TyTy::BaseType *tyret = fntype->get_return_type (); std::vector<Bvariable *> param_vars; // insert self diff --git a/gcc/rust/backend/rust-compile-item.h b/gcc/rust/backend/rust-compile-item.h index c6b135b..5279218 100644 --- a/gcc/rust/backend/rust-compile-item.h +++ b/gcc/rust/backend/rust-compile-item.h @@ -141,7 +141,7 @@ public: // setup the params - TyTy::BaseType *tyret = fntype->return_type (); + TyTy::BaseType *tyret = fntype->get_return_type (); std::vector<Bvariable *> param_vars; size_t i = 0; diff --git a/gcc/rust/backend/rust-compile-tyty.h b/gcc/rust/backend/rust-compile-tyty.h index e043a50..815ebd5 100644 --- a/gcc/rust/backend/rust-compile-tyty.h +++ b/gcc/rust/backend/rust-compile-tyty.h @@ -58,6 +58,8 @@ public: void visit (TyTy::ParamType &) override { gcc_unreachable (); } + void visit (TyTy::FnPtr &type) override { gcc_unreachable (); } + void visit (TyTy::UnitType &) override { translated = backend->void_type (); } void visit (TyTy::FnType &type) override diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc index 2c83527..204cce7 100644 --- a/gcc/rust/backend/rust-compile.cc +++ b/gcc/rust/backend/rust-compile.cc @@ -62,7 +62,9 @@ CompileExpr::visit (HIR::CallExpr &expr) } // must be a tuple constructor - if (tyty->get_kind () != TyTy::TypeKind::FNDEF) + bool is_fn = tyty->get_kind () == TyTy::TypeKind::FNDEF + || tyty->get_kind () == TyTy::TypeKind::FNPTR; + if (!is_fn) { Btype *type = TyTyResolveCompile::compile (ctx, tyty); diff --git a/gcc/rust/rust-backend.h b/gcc/rust/rust-backend.h index 700a376..3edb455 100644 --- a/gcc/rust/rust-backend.h +++ b/gcc/rust/rust-backend.h @@ -155,6 +155,11 @@ public: Btype *result_struct, Location location) = 0; + virtual Btype *function_ptr_type (Btype *result, + const std::vector<Btype *> &praameters, + Location location) + = 0; + // Get a struct type. virtual Btype *struct_type (const std::vector<Btyped_identifier> &fields) = 0; diff --git a/gcc/rust/rust-gcc.cc b/gcc/rust/rust-gcc.cc index 9a13332..cf800c7 100644 --- a/gcc/rust/rust-gcc.cc +++ b/gcc/rust/rust-gcc.cc @@ -206,6 +206,8 @@ public: const std::vector<Btyped_identifier> &, Btype *, const Location); + Btype *function_ptr_type (Btype *, const std::vector<Btype *> &, Location); + Btype *struct_type (const std::vector<Btyped_identifier> &); Btype *array_type (Btype *, Bexpression *); @@ -990,6 +992,37 @@ Gcc_backend::function_type (const Btyped_identifier &receiver, return this->make_type (build_pointer_type (fntype)); } +Btype * +Gcc_backend::function_ptr_type (Btype *result_type, + const std::vector<Btype *> ¶meters, + Location locus) +{ + tree args = NULL_TREE; + tree *pp = &args; + + for (auto ¶m : parameters) + { + tree t = param->get_tree (); + if (t == error_mark_node) + return this->error_type (); + + *pp = tree_cons (NULL_TREE, t, NULL_TREE); + pp = &TREE_CHAIN (*pp); + } + + *pp = void_list_node; + + tree result = result_type->get_tree (); + if (result != void_type_node && int_size_in_bytes (result) == 0) + result = void_type_node; + + tree fntype = build_function_type (result, args); + if (fntype == error_mark_node) + return this->error_type (); + + return this->make_type (build_pointer_type (fntype)); +} + // Make a struct type. Btype * diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index 7e94b981..d9e8d8b 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -169,7 +169,8 @@ public: return; bool valid_tyty = function_tyty->get_kind () == TyTy::TypeKind::ADT - || function_tyty->get_kind () == TyTy::TypeKind::FNDEF; + || function_tyty->get_kind () == TyTy::TypeKind::FNDEF + || function_tyty->get_kind () == TyTy::TypeKind::FNPTR; if (!valid_tyty) { rust_error_at (expr.get_locus (), diff --git a/gcc/rust/typecheck/rust-hir-type-check-implitem.h b/gcc/rust/typecheck/rust-hir-type-check-implitem.h index fa87bee..0354055 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-implitem.h +++ b/gcc/rust/typecheck/rust-hir-type-check-implitem.h @@ -170,7 +170,7 @@ public: // need to get the return type from this TyTy::FnType *resolve_fn_type = (TyTy::FnType *) lookup; - auto expected_ret_tyty = resolve_fn_type->return_type (); + auto expected_ret_tyty = resolve_fn_type->get_return_type (); context->push_return_type (expected_ret_tyty); auto result = TypeCheckExpr::Resolve (function.function_body.get (), false); @@ -202,7 +202,7 @@ public: // need to get the return type from this TyTy::FnType *resolve_fn_type = (TyTy::FnType *) lookup; - auto expected_ret_tyty = resolve_fn_type->return_type (); + auto expected_ret_tyty = resolve_fn_type->get_return_type (); context->push_return_type (expected_ret_tyty); auto result diff --git a/gcc/rust/typecheck/rust-hir-type-check-item.h b/gcc/rust/typecheck/rust-hir-type-check-item.h index c66be57..54fc3df 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-item.h +++ b/gcc/rust/typecheck/rust-hir-type-check-item.h @@ -72,7 +72,7 @@ public: // need to get the return type from this TyTy::FnType *resolve_fn_type = (TyTy::FnType *) lookup; - auto expected_ret_tyty = resolve_fn_type->return_type (); + auto expected_ret_tyty = resolve_fn_type->get_return_type (); context->push_return_type (expected_ret_tyty); auto result = TypeCheckExpr::Resolve (function.function_body.get (), false); diff --git a/gcc/rust/typecheck/rust-hir-type-check-type.h b/gcc/rust/typecheck/rust-hir-type-check-type.h index af58d60..0d7d07b 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-type.h +++ b/gcc/rust/typecheck/rust-hir-type-check-type.h @@ -107,26 +107,17 @@ public: ? TypeCheckType::Resolve (fntype.get_return_type ().get ()) : new TyTy::UnitType (fntype.get_mappings ().get_hirid ()); - std::vector<std::pair<HIR::Pattern *, TyTy::BaseType *> > params; + std::vector<TyTy::TyCtx> params; for (auto ¶m : fntype.get_function_params ()) { - std::unique_ptr<HIR::Pattern> to_bind; - - bool is_ref = false; - bool is_mut = false; - - HIR::Pattern *pattern - = new HIR::IdentifierPattern (param.get_name (), param.get_locus (), - is_ref, is_mut, std::move (to_bind)); - TyTy::BaseType *ptype = TypeCheckType::Resolve (param.get_type ().get ()); - params.push_back ( - std::pair<HIR::Pattern *, TyTy::BaseType *> (pattern, ptype)); + params.push_back (TyTy::TyCtx (ptype->get_ref ())); } - translated = new TyTy::FnType (fntype.get_mappings ().get_hirid (), - std::move (params), return_type); + translated = new TyTy::FnPtr (fntype.get_mappings ().get_hirid (), + std::move (params), + TyTy::TyCtx (return_type->get_ref ())); } void visit (HIR::TupleType &tuple) diff --git a/gcc/rust/typecheck/rust-tyty-call.h b/gcc/rust/typecheck/rust-tyty-call.h index 2e65244..eac9868 100644 --- a/gcc/rust/typecheck/rust-tyty-call.h +++ b/gcc/rust/typecheck/rust-tyty-call.h @@ -60,6 +60,7 @@ public: // call fns void visit (FnType &type) override; + void visit (FnPtr &type) override; private: TypeCheckCallExpr (HIR::CallExpr &c, Resolver::TypeCheckContext *context) @@ -101,6 +102,9 @@ public: void visit (ParamType &) override { gcc_unreachable (); } void visit (StrType &) override { gcc_unreachable (); } + // FIXME + void visit (FnPtr &type) override { gcc_unreachable (); } + // call fns void visit (FnType &type) override; diff --git a/gcc/rust/typecheck/rust-tyty-rules.h b/gcc/rust/typecheck/rust-tyty-rules.h index 0ea7769..4e48114 100644 --- a/gcc/rust/typecheck/rust-tyty-rules.h +++ b/gcc/rust/typecheck/rust-tyty-rules.h @@ -148,6 +148,14 @@ public: type.as_string ().c_str ()); } + virtual void visit (FnPtr &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + rust_error_at (ref_locus, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + virtual void visit (ArrayType &type) override { Location ref_locus = mappings->lookup_location (type.get_ref ()); @@ -526,7 +534,6 @@ public: return; } - // FIXME add an abstract method for is_equal on BaseType for (size_t i = 0; i < base->num_params (); i++) { auto a = base->param_at (i).second; @@ -558,6 +565,99 @@ private: FnType *base; }; +class FnptrRules : public BaseRules +{ +public: + FnptrRules (FnPtr *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnPtr &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i); + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnType &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i).second; + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + FnPtr *base; +}; + class ArrayRules : public BaseRules { public: diff --git a/gcc/rust/typecheck/rust-tyty-visitor.h b/gcc/rust/typecheck/rust-tyty-visitor.h index 61fd905..8ab7fff 100644 --- a/gcc/rust/typecheck/rust-tyty-visitor.h +++ b/gcc/rust/typecheck/rust-tyty-visitor.h @@ -32,6 +32,7 @@ public: virtual void visit (ADTType &type) = 0; virtual void visit (TupleType &type) = 0; virtual void visit (FnType &type) = 0; + virtual void visit (FnPtr &type) = 0; virtual void visit (ArrayType &type) = 0; virtual void visit (BoolType &type) = 0; virtual void visit (IntType &type) = 0; diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index c16a804..e7b2216 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -401,25 +401,23 @@ bool FnType::is_equal (const BaseType &other) const { if (get_kind () != other.get_kind ()) + return false; + + auto other2 = static_cast<const FnType &> (other); + if (!get_return_type ()->is_equal (*other2.get_return_type ())) + return false; + + if (num_params () != other2.num_params ()) + return false; + + for (size_t i = 0; i < num_params (); i++) { - return false; - } - else - { - auto other2 = static_cast<const FnType &> (other); - if (!get_return_type ()->is_equal (*other2.get_return_type ())) + auto lhs = param_at (i).second; + auto rhs = other2.param_at (i).second; + if (!lhs->is_equal (*rhs)) return false; - if (num_params () != other2.num_params ()) - return false; - for (size_t i = 0; i < num_params (); i++) - { - auto lhs = param_at (i).second; - auto rhs = other2.param_at (i).second; - if (!lhs->is_equal (*rhs)) - return false; - } - return true; } + return true; } BaseType * @@ -430,11 +428,69 @@ FnType::clone () cloned_params.push_back ( std::pair<HIR::Pattern *, BaseType *> (p.first, p.second->clone ())); - return new FnType (get_ref (), get_ty_ref (), cloned_params, + return new FnType (get_ref (), get_ty_ref (), std::move (cloned_params), get_return_type ()->clone (), get_combined_refs ()); } void +FnPtr::accept_vis (TyVisitor &vis) +{ + vis.visit (*this); +} + +std::string +FnPtr::as_string () const +{ + std::string params_str; + iterate_params ([&] (BaseType *p) mutable -> bool { + params_str += p->as_string () + " ,"; + return true; + }); + return "fnptr (" + params_str + ") -> " + get_return_type ()->as_string (); +} + +BaseType * +FnPtr::unify (BaseType *other) +{ + FnptrRules r (this); + return r.unify (other); +} + +bool +FnPtr::is_equal (const BaseType &other) const +{ + if (get_kind () != other.get_kind ()) + return false; + + auto other2 = static_cast<const FnPtr &> (other); + auto this_ret_type = get_return_type (); + auto other_ret_type = other2.get_return_type (); + if (this_ret_type->is_equal (*other_ret_type)) + return false; + + if (num_params () != other2.num_params ()) + return false; + + for (size_t i = 0; i < num_params (); i++) + { + if (!param_at (i)->is_equal (*other2.param_at (i))) + return false; + } + return true; +} + +BaseType * +FnPtr::clone () +{ + std::vector<TyCtx> cloned_params; + for (auto &p : params) + cloned_params.push_back (TyCtx (p.get_ref ())); + + return new FnPtr (get_ref (), get_ty_ref (), std::move (cloned_params), + result_type, get_combined_refs ()); +} + +void ArrayType::accept_vis (TyVisitor &vis) { vis.visit (*this); @@ -922,6 +978,53 @@ TypeCheckCallExpr::visit (FnType &type) resolved = type.get_return_type ()->clone (); } +void +TypeCheckCallExpr::visit (FnPtr &type) +{ + if (call.num_params () != type.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + call.num_params (), type.num_params ()); + return; + } + + size_t i = 0; + call.iterate_params ([&] (HIR::Expr *param) mutable -> bool { + auto fnparam = type.param_at (i); + auto argument_expr_tyty = Resolver::TypeCheckExpr::Resolve (param, false); + if (argument_expr_tyty == nullptr) + { + rust_error_at (param->get_locus_slow (), + "failed to resolve type for argument expr in CallExpr"); + return false; + } + + auto resolved_argument_type = fnparam->unify (argument_expr_tyty); + if (resolved_argument_type == nullptr) + { + rust_error_at (param->get_locus_slow (), + "Type Resolution failure on parameter"); + return false; + } + + context->insert_type (param->get_mappings (), resolved_argument_type); + + i++; + return true; + }); + + if (i != call.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", i, + call.num_params ()); + return; + } + + resolved = type.get_return_type ()->clone (); +} + // method call checker void diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index 372cb7d..3f22955 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -35,6 +35,7 @@ enum TypeKind PARAM, ARRAY, FNDEF, + FNPTR, TUPLE, BOOL, CHAR, @@ -261,7 +262,7 @@ public: BaseType *unify (BaseType *other) override; - virtual bool is_equal (const BaseType &other) const override; + bool is_equal (const BaseType &other) const override; size_t num_fields () const { return fields.size (); } @@ -433,7 +434,7 @@ public: BaseType *unify (BaseType *other) override; - virtual bool is_equal (const BaseType &other) const override; + bool is_equal (const BaseType &other) const override; size_t num_fields () const { return fields.size (); } @@ -524,11 +525,9 @@ public: std::string get_name () const override final { return as_string (); } - BaseType *return_type () { return type; } - BaseType *unify (BaseType *other) override; - virtual bool is_equal (const BaseType &other) const override; + bool is_equal (const BaseType &other) const override; size_t num_params () const { return params.size (); } @@ -561,6 +560,53 @@ private: BaseType *type; }; +class FnPtr : public BaseType +{ +public: + FnPtr (HirId ref, std::vector<TyCtx> params, TyCtx result_type, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ref, TypeKind::FNPTR, refs), params (std::move (params)), + result_type (result_type) + {} + + FnPtr (HirId ref, HirId ty_ref, std::vector<TyCtx> params, TyCtx result_type, + std::set<HirId> refs = std::set<HirId> ()) + : BaseType (ref, ty_ref, TypeKind::FNPTR, refs), params (params), + result_type (result_type) + {} + + std::string get_name () const override final { return as_string (); } + + BaseType *get_return_type () const { return result_type.get_tyty (); } + + size_t num_params () const { return params.size (); } + + BaseType *param_at (size_t idx) const { return params.at (idx).get_tyty (); } + + void accept_vis (TyVisitor &vis) override; + + std::string as_string () const override; + + BaseType *unify (BaseType *other) override; + + bool is_equal (const BaseType &other) const override; + + BaseType *clone () final override; + + void iterate_params (std::function<bool (BaseType *)> cb) const + { + for (auto &p : params) + { + if (!cb (p.get_tyty ())) + return; + } + } + +private: + std::vector<TyCtx> params; + TyCtx result_type; +}; + class ArrayType : public BaseType { public: @@ -584,7 +630,7 @@ public: BaseType *unify (BaseType *other) override; - virtual bool is_equal (const BaseType &other) const override; + bool is_equal (const BaseType &other) const override; size_t get_capacity () const { return capacity; } diff --git a/gcc/testsuite/rust.test/compilable/function_reference4.rs b/gcc/testsuite/rust.test/compilable/function_reference4.rs new file mode 100644 index 0000000..a27f0e4 --- /dev/null +++ b/gcc/testsuite/rust.test/compilable/function_reference4.rs @@ -0,0 +1,8 @@ +fn test(a: i32) -> i32 { + a + 1 +} + +fn main() { + let a: fn(_) -> _ = test; + let b = a(1); +} |