aboutsummaryrefslogtreecommitdiff
path: root/gcc
diff options
context:
space:
mode:
authorPhilip Herron <philip.herron@embecosm.com>2021-06-23 17:22:12 +0100
committerPhilip Herron <philip.herron@embecosm.com>2021-06-24 12:36:08 +0100
commit92a434a33904608f5659cf7b5d4df3d2a99bd5bd (patch)
tree02daa3d23d04cb9c950b15d5a7a94f171e1241db /gcc
parent23e748d7a6855ce132299cfef9692ee9c681de59 (diff)
downloadgcc-92a434a33904608f5659cf7b5d4df3d2a99bd5bd.zip
gcc-92a434a33904608f5659cf7b5d4df3d2a99bd5bd.tar.gz
gcc-92a434a33904608f5659cf7b5d4df3d2a99bd5bd.tar.bz2
Add support for nested functions
We missed that stmts in rust can be items like functions. This adds support for resolution and compilation of nested functions. Rust allows nested functions which are distinct to closures. Nested functions are not allowed to encapsulate the enclosing scope so they can be extracted as normal functions.
Diffstat (limited to 'gcc')
-rw-r--r--gcc/rust/backend/rust-compile-base.h3
-rw-r--r--gcc/rust/backend/rust-compile-implitem.h44
-rw-r--r--gcc/rust/backend/rust-compile-item.h22
-rw-r--r--gcc/rust/backend/rust-compile.cc51
-rw-r--r--gcc/rust/hir/rust-ast-lower-stmt.h85
-rw-r--r--gcc/rust/resolve/rust-ast-resolve-stmt.h55
-rw-r--r--gcc/rust/resolve/rust-ast-resolve.cc5
-rw-r--r--gcc/rust/typecheck/rust-hir-type-check-stmt.h80
-rw-r--r--gcc/testsuite/rust/compile/lookup_err1.rs7
-rw-r--r--gcc/testsuite/rust/compile/torture/nested_fn1.rs10
-rw-r--r--gcc/testsuite/rust/compile/torture/nested_fn2.rs11
11 files changed, 299 insertions, 74 deletions
diff --git a/gcc/rust/backend/rust-compile-base.h b/gcc/rust/backend/rust-compile-base.h
index ed33515..c346af5 100644
--- a/gcc/rust/backend/rust-compile-base.h
+++ b/gcc/rust/backend/rust-compile-base.h
@@ -210,6 +210,9 @@ protected:
void compile_function_body (Bfunction *fndecl,
std::unique_ptr<HIR::BlockExpr> &function_body,
bool has_return_type);
+
+ bool compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl,
+ std::vector<Bvariable *> &locals);
};
} // namespace Compile
diff --git a/gcc/rust/backend/rust-compile-implitem.h b/gcc/rust/backend/rust-compile-implitem.h
index d6698d1..70f76b7 100644
--- a/gcc/rust/backend/rust-compile-implitem.h
+++ b/gcc/rust/backend/rust-compile-implitem.h
@@ -183,26 +183,10 @@ public:
}
std::vector<Bvariable *> locals;
- rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
- Resolver::Definition d;
- bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
- rust_assert (ok);
-
- HIR::Stmt *decl = nullptr;
- ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
- rust_assert (ok);
-
- Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
- locals.push_back (compiled);
-
- return true;
- });
-
- bool toplevel_item
- = function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
- Bblock *enclosing_scope
- = toplevel_item ? NULL : ctx->peek_enclosing_scope ();
+ bool ok = compile_locals_for_block (*rib, fndecl, locals);
+ rust_assert (ok);
+ Bblock *enclosing_scope = NULL;
HIR::BlockExpr *function_body = function.get_definition ().get ();
Location start_location = function_body->get_locus ();
Location end_location = function_body->get_closing_locus ();
@@ -409,26 +393,10 @@ public:
}
std::vector<Bvariable *> locals;
- rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
- Resolver::Definition d;
- bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
- rust_assert (ok);
-
- HIR::Stmt *decl = nullptr;
- ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
- rust_assert (ok);
-
- Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
- locals.push_back (compiled);
-
- return true;
- });
-
- bool toplevel_item
- = method.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
- Bblock *enclosing_scope
- = toplevel_item ? NULL : ctx->peek_enclosing_scope ();
+ bool ok = compile_locals_for_block (*rib, fndecl, locals);
+ rust_assert (ok);
+ Bblock *enclosing_scope = NULL;
HIR::BlockExpr *function_body = method.get_function_body ().get ();
Location start_location = function_body->get_locus ();
Location end_location = function_body->get_closing_locus ();
diff --git a/gcc/rust/backend/rust-compile-item.h b/gcc/rust/backend/rust-compile-item.h
index 8a521e7..eacfda9 100644
--- a/gcc/rust/backend/rust-compile-item.h
+++ b/gcc/rust/backend/rust-compile-item.h
@@ -213,26 +213,10 @@ public:
}
std::vector<Bvariable *> locals;
- rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
- Resolver::Definition d;
- bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
- rust_assert (ok);
-
- HIR::Stmt *decl = nullptr;
- ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
- rust_assert (ok);
-
- Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
- locals.push_back (compiled);
-
- return true;
- });
-
- bool toplevel_item
- = function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID;
- Bblock *enclosing_scope
- = toplevel_item ? NULL : ctx->peek_enclosing_scope ();
+ bool ok = compile_locals_for_block (*rib, fndecl, locals);
+ rust_assert (ok);
+ Bblock *enclosing_scope = NULL;
HIR::BlockExpr *function_body = function.get_definition ().get ();
Location start_location = function_body->get_locus ();
Location end_location = function_body->get_closing_locus ();
diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc
index 351271c..5ffd11a 100644
--- a/gcc/rust/backend/rust-compile.cc
+++ b/gcc/rust/backend/rust-compile.cc
@@ -212,20 +212,8 @@ CompileBlock::visit (HIR::BlockExpr &expr)
}
std::vector<Bvariable *> locals;
- rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool {
- Resolver::Definition d;
- bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
- rust_assert (ok);
-
- HIR::Stmt *decl = nullptr;
- ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
- rust_assert (ok);
-
- Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
- locals.push_back (compiled);
-
- return true;
- });
+ bool ok = compile_locals_for_block (*rib, fndecl, locals);
+ rust_assert (ok);
Bblock *enclosing_scope = ctx->peek_enclosing_scope ();
Bblock *new_block
@@ -415,6 +403,41 @@ HIRCompileBase::compile_function_body (
}
}
+bool
+HIRCompileBase::compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl,
+ std::vector<Bvariable *> &locals)
+{
+ rib.iterate_decls ([&] (NodeId n, Location) mutable -> bool {
+ Resolver::Definition d;
+ bool ok = ctx->get_resolver ()->lookup_definition (n, &d);
+ rust_assert (ok);
+
+ HIR::Stmt *decl = nullptr;
+ ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl);
+ rust_assert (ok);
+
+ // if its a function we extract this out side of this fn context
+ // and it is not a local to this function
+ bool is_item = ctx->get_mappings ()->lookup_hir_item (
+ decl->get_mappings ().get_crate_num (),
+ decl->get_mappings ().get_hirid ())
+ != nullptr;
+ if (is_item)
+ {
+ HIR::Item *item = static_cast<HIR::Item *> (decl);
+ CompileItem::compile (item, ctx, true);
+ return true;
+ }
+
+ Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx);
+ locals.push_back (compiled);
+
+ return true;
+ });
+
+ return true;
+}
+
// Mr Mangle time
static const std::string kMangledSymbolPrefix = "_ZN";
diff --git a/gcc/rust/hir/rust-ast-lower-stmt.h b/gcc/rust/hir/rust-ast-lower-stmt.h
index c495932..1dd8a10 100644
--- a/gcc/rust/hir/rust-ast-lower-stmt.h
+++ b/gcc/rust/hir/rust-ast-lower-stmt.h
@@ -230,6 +230,91 @@ public:
empty.get_locus ());
}
+ void visit (AST::Function &function) override
+ {
+ // ignore for now and leave empty
+ std::vector<std::unique_ptr<HIR::WhereClauseItem> > where_clause_items;
+ HIR::WhereClause where_clause (std::move (where_clause_items));
+ HIR::FunctionQualifiers qualifiers (
+ HIR::FunctionQualifiers::AsyncConstStatus::NONE, false);
+ HIR::Visibility vis = HIR::Visibility::create_public ();
+
+ // need
+ std::vector<std::unique_ptr<HIR::GenericParam> > generic_params;
+ if (function.has_generics ())
+ {
+ generic_params = lower_generic_params (function.get_generic_params ());
+ }
+
+ Identifier function_name = function.get_function_name ();
+ Location locus = function.get_locus ();
+
+ std::unique_ptr<HIR::Type> return_type
+ = function.has_return_type () ? std::unique_ptr<HIR::Type> (
+ ASTLoweringType::translate (function.get_return_type ().get ()))
+ : nullptr;
+
+ std::vector<HIR::FunctionParam> function_params;
+ for (auto &param : function.get_function_params ())
+ {
+ auto translated_pattern = std::unique_ptr<HIR::Pattern> (
+ ASTLoweringPattern::translate (param.get_pattern ().get ()));
+ auto translated_type = std::unique_ptr<HIR::Type> (
+ ASTLoweringType::translate (param.get_type ().get ()));
+
+ auto crate_num = mappings->get_current_crate ();
+ Analysis::NodeMapping mapping (crate_num, param.get_node_id (),
+ mappings->get_next_hir_id (crate_num),
+ UNKNOWN_LOCAL_DEFID);
+
+ auto hir_param
+ = HIR::FunctionParam (mapping, std::move (translated_pattern),
+ std::move (translated_type),
+ param.get_locus ());
+ function_params.push_back (hir_param);
+ }
+
+ bool terminated = false;
+ std::unique_ptr<HIR::BlockExpr> function_body
+ = std::unique_ptr<HIR::BlockExpr> (
+ ASTLoweringBlock::translate (function.get_definition ().get (),
+ &terminated));
+
+ auto crate_num = mappings->get_current_crate ();
+ Analysis::NodeMapping mapping (crate_num, function.get_node_id (),
+ mappings->get_next_hir_id (crate_num),
+ UNKNOWN_LOCAL_DEFID);
+
+ mappings->insert_location (crate_num,
+ function_body->get_mappings ().get_hirid (),
+ function.get_locus ());
+
+ auto fn
+ = new HIR::Function (mapping, std::move (function_name),
+ std::move (qualifiers), std::move (generic_params),
+ std::move (function_params), std::move (return_type),
+ std::move (where_clause), std::move (function_body),
+ std::move (vis), function.get_outer_attrs (), locus);
+
+ mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (),
+ fn);
+ mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (),
+ fn);
+ mappings->insert_location (crate_num, mapping.get_hirid (),
+ function.get_locus ());
+
+ // add the mappings for the function params at the end
+ for (auto &param : fn->get_function_params ())
+ {
+ mappings->insert_hir_param (mapping.get_crate_num (),
+ param.get_mappings ().get_hirid (), &param);
+ mappings->insert_location (crate_num, mapping.get_hirid (),
+ param.get_locus ());
+ }
+
+ translated = fn;
+ }
+
private:
ASTLoweringStmt () : translated (nullptr), terminated (false) {}
diff --git a/gcc/rust/resolve/rust-ast-resolve-stmt.h b/gcc/rust/resolve/rust-ast-resolve-stmt.h
index 3fd1cfa..e68e7b9 100644
--- a/gcc/rust/resolve/rust-ast-resolve-stmt.h
+++ b/gcc/rust/resolve/rust-ast-resolve-stmt.h
@@ -129,6 +129,61 @@ public:
resolver->get_type_scope ().pop ();
}
+ void visit (AST::Function &function) override
+ {
+ auto path = ResolveFunctionItemToCanonicalPath::resolve (function);
+ resolver->get_name_scope ().insert (
+ path, function.get_node_id (), function.get_locus (), false,
+ [&] (const CanonicalPath &, NodeId, Location locus) -> void {
+ RichLocation r (function.get_locus ());
+ r.add_range (locus);
+ rust_error_at (r, "redefined multiple times");
+ });
+ resolver->insert_new_definition (function.get_node_id (),
+ Definition{function.get_node_id (),
+ function.get_node_id ()});
+
+ NodeId scope_node_id = function.get_node_id ();
+ resolver->get_name_scope ().push (scope_node_id);
+ resolver->get_type_scope ().push (scope_node_id);
+ resolver->get_label_scope ().push (scope_node_id);
+ resolver->push_new_name_rib (resolver->get_name_scope ().peek ());
+ resolver->push_new_type_rib (resolver->get_type_scope ().peek ());
+ resolver->push_new_label_rib (resolver->get_type_scope ().peek ());
+
+ if (function.has_generics ())
+ {
+ for (auto &generic : function.get_generic_params ())
+ ResolveGenericParam::go (generic.get (), function.get_node_id ());
+ }
+
+ if (function.has_return_type ())
+ ResolveType::go (function.get_return_type ().get (),
+ function.get_node_id ());
+
+ // we make a new scope so the names of parameters are resolved and shadowed
+ // correctly
+ for (auto &param : function.get_function_params ())
+ {
+ ResolveType::go (param.get_type ().get (), param.get_node_id ());
+ PatternDeclaration::go (param.get_pattern ().get (),
+ param.get_node_id ());
+
+ // the mutability checker needs to verify for immutable decls the number
+ // of assignments are <1. This marks an implicit assignment
+ resolver->mark_assignment_to_decl (param.get_pattern ()->get_node_id (),
+ param.get_node_id ());
+ }
+
+ // resolve the function body
+ ResolveExpr::go (function.get_definition ().get (),
+ function.get_node_id ());
+
+ resolver->get_name_scope ().pop ();
+ resolver->get_type_scope ().pop ();
+ resolver->get_label_scope ().pop ();
+ }
+
private:
ResolveStmt (NodeId parent) : ResolverBase (parent) {}
};
diff --git a/gcc/rust/resolve/rust-ast-resolve.cc b/gcc/rust/resolve/rust-ast-resolve.cc
index e03a745..fae3f77 100644
--- a/gcc/rust/resolve/rust-ast-resolve.cc
+++ b/gcc/rust/resolve/rust-ast-resolve.cc
@@ -499,9 +499,8 @@ ResolvePath::resolve_path (AST::PathInExpression *expr)
else
{
rust_error_at (expr->get_locus (),
- "unknown root segment in path %s lookup %s",
- expr->as_string ().c_str (),
- root_ident_seg.as_string ().c_str ());
+ "Cannot find path %<%s%> in this scope",
+ expr->as_string ().c_str ());
return;
}
diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h
index 0e55df8..3655d96 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h
+++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h
@@ -216,6 +216,86 @@ public:
infered = type;
}
+ void visit (HIR::Function &function) override
+ {
+ std::vector<TyTy::SubstitutionParamMapping> substitutions;
+ if (function.has_generics ())
+ {
+ for (auto &generic_param : function.get_generic_params ())
+ {
+ switch (generic_param.get ()->get_kind ())
+ {
+ case HIR::GenericParam::GenericKind::LIFETIME:
+ // Skipping Lifetime completely until better handling.
+ break;
+
+ case HIR::GenericParam::GenericKind::TYPE: {
+ auto param_type
+ = TypeResolveGenericParam::Resolve (generic_param.get ());
+ context->insert_type (generic_param->get_mappings (),
+ param_type);
+
+ substitutions.push_back (TyTy::SubstitutionParamMapping (
+ static_cast<HIR::TypeParam &> (*generic_param),
+ param_type));
+ }
+ break;
+ }
+ }
+ }
+
+ TyTy::BaseType *ret_type = nullptr;
+ if (!function.has_function_return_type ())
+ ret_type = new TyTy::TupleType (function.get_mappings ().get_hirid ());
+ else
+ {
+ auto resolved
+ = TypeCheckType::Resolve (function.get_return_type ().get ());
+ if (resolved == nullptr)
+ {
+ rust_error_at (function.get_locus (),
+ "failed to resolve return type");
+ return;
+ }
+
+ ret_type = resolved->clone ();
+ ret_type->set_ref (
+ function.get_return_type ()->get_mappings ().get_hirid ());
+ }
+
+ std::vector<std::pair<HIR::Pattern *, TyTy::BaseType *> > params;
+ for (auto &param : function.get_function_params ())
+ {
+ // get the name as well required for later on
+ auto param_tyty = TypeCheckType::Resolve (param.get_type ());
+ params.push_back (
+ std::pair<HIR::Pattern *, TyTy::BaseType *> (param.get_param_name (),
+ param_tyty));
+
+ context->insert_type (param.get_mappings (), param_tyty);
+ }
+
+ auto fnType = new TyTy::FnType (function.get_mappings ().get_hirid (),
+ function.get_function_name (), false,
+ std::move (params), ret_type,
+ std::move (substitutions));
+ context->insert_type (function.get_mappings (), fnType);
+
+ TyTy::FnType *resolved_fn_type = fnType;
+ auto expected_ret_tyty = resolved_fn_type->get_return_type ();
+ context->push_return_type (expected_ret_tyty);
+
+ auto block_expr_ty
+ = TypeCheckExpr::Resolve (function.get_definition ().get (), false);
+
+ context->pop_return_type ();
+
+ if (block_expr_ty->get_kind () != TyTy::NEVER)
+ expected_ret_tyty->unify (block_expr_ty);
+
+ infered = fnType;
+ }
+
private:
TypeCheckStmt (bool inside_loop)
: TypeCheckBase (), infered (nullptr), inside_loop (inside_loop)
diff --git a/gcc/testsuite/rust/compile/lookup_err1.rs b/gcc/testsuite/rust/compile/lookup_err1.rs
new file mode 100644
index 0000000..4a96f9f
--- /dev/null
+++ b/gcc/testsuite/rust/compile/lookup_err1.rs
@@ -0,0 +1,7 @@
+fn test() {
+ fn nested() {}
+}
+
+fn main() {
+ nested(); // { dg-error "Cannot find path .nested. in this scope" }
+}
diff --git a/gcc/testsuite/rust/compile/torture/nested_fn1.rs b/gcc/testsuite/rust/compile/torture/nested_fn1.rs
new file mode 100644
index 0000000..075b5db
--- /dev/null
+++ b/gcc/testsuite/rust/compile/torture/nested_fn1.rs
@@ -0,0 +1,10 @@
+pub fn main() {
+ let a = 123;
+
+ fn test(x: i32) -> i32 {
+ x + 456
+ }
+
+ let b;
+ b = test(a);
+}
diff --git a/gcc/testsuite/rust/compile/torture/nested_fn2.rs b/gcc/testsuite/rust/compile/torture/nested_fn2.rs
new file mode 100644
index 0000000..7040c86
--- /dev/null
+++ b/gcc/testsuite/rust/compile/torture/nested_fn2.rs
@@ -0,0 +1,11 @@
+pub fn main() {
+ fn test<T>(x: T) -> T {
+ x
+ }
+
+ let mut a = 123;
+ a = test(a);
+
+ let mut b = 456f32;
+ b = test(b);
+}