aboutsummaryrefslogtreecommitdiff
path: root/gcc/rust
diff options
context:
space:
mode:
authorPhilip Herron <philip.herron@embecosm.com>2021-07-19 18:55:25 +0100
committerPhilip Herron <philip.herron@embecosm.com>2021-07-19 19:38:29 +0100
commit0cbd3afc714a1d874fd829108f9b51a44205c050 (patch)
treea51428aa157c2d0c4d66989a6a0e1538ca462e9a /gcc/rust
parentf82bf003cebb8a296312d32882db03e52945dac3 (diff)
downloadgcc-0cbd3afc714a1d874fd829108f9b51a44205c050.zip
gcc-0cbd3afc714a1d874fd829108f9b51a44205c050.tar.gz
gcc-0cbd3afc714a1d874fd829108f9b51a44205c050.tar.bz2
Initial coercion rules
Lets keep the same unify pattern for coercion rules to keep the code as readable as possible. Addresses #434
Diffstat (limited to 'gcc/rust')
-rw-r--r--gcc/rust/typecheck/rust-tyty-coercion.h1198
-rw-r--r--gcc/rust/typecheck/rust-tyty.cc133
-rw-r--r--gcc/rust/typecheck/rust-tyty.h23
3 files changed, 1354 insertions, 0 deletions
diff --git a/gcc/rust/typecheck/rust-tyty-coercion.h b/gcc/rust/typecheck/rust-tyty-coercion.h
new file mode 100644
index 0000000..6695056
--- /dev/null
+++ b/gcc/rust/typecheck/rust-tyty-coercion.h
@@ -0,0 +1,1198 @@
+// Copyright (C) 2020 Free Software Foundation, Inc.
+
+// This file is part of GCC.
+
+// GCC is free software; you can redistribute it and/or modify it under
+// the terms of the GNU General Public License as published by the Free
+// Software Foundation; either version 3, or (at your option) any later
+// version.
+
+// GCC is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or
+// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+// for more details.
+
+// You should have received a copy of the GNU General Public License
+// along with GCC; see the file COPYING3. If not see
+// <http://www.gnu.org/licenses/>.
+
+#ifndef RUST_TYTY_COERCION_RULES
+#define RUST_TYTY_COERCION_RULES
+
+#include "rust-diagnostics.h"
+#include "rust-tyty.h"
+#include "rust-tyty-visitor.h"
+#include "rust-hir-map.h"
+#include "rust-hir-type-check.h"
+
+extern ::Backend *
+rust_get_backend ();
+
+namespace Rust {
+namespace TyTy {
+
+class BaseCoercionRules : public TyVisitor
+{
+public:
+ virtual ~BaseCoercionRules () {}
+
+ virtual BaseType *coerce (BaseType *other)
+ {
+ if (other->get_kind () == TypeKind::PARAM)
+ {
+ ParamType *p = static_cast<ParamType *> (other);
+ if (p->can_resolve ())
+ {
+ other = p->resolve ();
+ }
+ }
+
+ other->accept_vis (*this);
+ if (resolved->get_kind () == TyTy::TypeKind::ERROR)
+ return resolved;
+
+ resolved->append_reference (get_base ()->get_ref ());
+ resolved->append_reference (other->get_ref ());
+ for (auto ref : get_base ()->get_combined_refs ())
+ resolved->append_reference (ref);
+ for (auto ref : other->get_combined_refs ())
+ resolved->append_reference (ref);
+
+ bool result_resolved = resolved->get_kind () != TyTy::TypeKind::INFER;
+ bool result_is_infer_var = resolved->get_kind () == TyTy::TypeKind::INFER;
+ bool results_is_non_general_infer_var
+ = (result_is_infer_var
+ && (static_cast<InferType *> (resolved))->get_infer_kind ()
+ != TyTy::InferType::GENERAL);
+ if (result_resolved || results_is_non_general_infer_var)
+ {
+ for (auto &ref : resolved->get_combined_refs ())
+ {
+ TyTy::BaseType *ref_tyty = nullptr;
+ bool ok = context->lookup_type (ref, &ref_tyty);
+ if (!ok)
+ continue;
+
+ // if any of the types are inference variables lets fix them
+ if (ref_tyty->get_kind () == TyTy::TypeKind::INFER)
+ {
+ context->insert_type (
+ Analysis::NodeMapping (mappings->get_current_crate (),
+ UNKNOWN_NODEID, ref,
+ UNKNOWN_LOCAL_DEFID),
+ resolved->clone ());
+ }
+ }
+ }
+ return resolved;
+ }
+
+ virtual void visit (TupleType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ADTType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (InferType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (FnType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (FnPtr &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ArrayType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (BoolType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (IntType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (UintType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (USizeType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ISizeType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (FloatType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ErrorType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (CharType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ReferenceType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (ParamType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (StrType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (NeverType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+ virtual void visit (PlaceholderType &type) override
+ {
+ Location ref_locus = mappings->lookup_location (type.get_ref ());
+ Location base_locus = mappings->lookup_location (get_base ()->get_ref ());
+ RichLocation r (ref_locus);
+ r.add_range (base_locus);
+ rust_error_at (r, "cannot coerce [%s] with [%s]",
+ get_base ()->as_string ().c_str (),
+ type.as_string ().c_str ());
+ }
+
+protected:
+ BaseCoercionRules (BaseType *base)
+ : mappings (Analysis::Mappings::get ()),
+ context (Resolver::TypeCheckContext::get ()),
+ resolved (new ErrorType (base->get_ref (), base->get_ref ()))
+ {}
+
+ Analysis::Mappings *mappings;
+ Resolver::TypeCheckContext *context;
+
+ /* Temporary storage for the result of a unification.
+ We could return the result directly instead of storing it in the rule
+ object, but that involves modifying the visitor pattern to accommodate
+ the return value, which is too complex. */
+ BaseType *resolved;
+
+private:
+ /* Returns a pointer to the ty that created this rule. */
+ virtual BaseType *get_base () = 0;
+};
+
+class InferCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ InferCoercionRules (InferType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (BoolType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (IntType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
+ || (base->get_infer_kind ()
+ == TyTy::InferType::InferTypeKind::INTEGRAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (UintType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
+ || (base->get_infer_kind ()
+ == TyTy::InferType::InferTypeKind::INTEGRAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (USizeType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
+ || (base->get_infer_kind ()
+ == TyTy::InferType::InferTypeKind::INTEGRAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (ISizeType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
+ || (base->get_infer_kind ()
+ == TyTy::InferType::InferTypeKind::INTEGRAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (FloatType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
+ || (base->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (ArrayType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (ADTType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (TupleType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (InferType &type) override
+ {
+ switch (base->get_infer_kind ())
+ {
+ case InferType::InferTypeKind::GENERAL:
+ resolved = type.clone ();
+ return;
+
+ case InferType::InferTypeKind::INTEGRAL: {
+ if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL)
+ {
+ resolved = type.clone ();
+ return;
+ }
+ else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL)
+ {
+ resolved = base->clone ();
+ return;
+ }
+ }
+ break;
+
+ case InferType::InferTypeKind::FLOAT: {
+ if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
+ {
+ resolved = type.clone ();
+ return;
+ }
+ else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL)
+ {
+ resolved = base->clone ();
+ return;
+ }
+ }
+ break;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (CharType &type) override
+ {
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+ }
+
+ void visit (ReferenceType &type) override
+
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+ void visit (ParamType &type) override
+ {
+ bool is_valid
+ = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
+ if (is_valid)
+ {
+ resolved = type.clone ();
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ InferType *base;
+};
+
+class FnCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ FnCoercionRules (FnType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (InferType &type) override
+ {
+ if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (FnType &type) override
+ {
+ if (base->num_params () != type.num_params ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ for (size_t i = 0; i < base->num_params (); i++)
+ {
+ auto a = base->param_at (i).second;
+ auto b = type.param_at (i).second;
+
+ auto unified_param = a->unify (b);
+ if (unified_param == nullptr)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+ }
+
+ auto unified_return
+ = base->get_return_type ()->unify (type.get_return_type ());
+ if (unified_return == nullptr)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ FnType *base;
+};
+
+class FnptrCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ FnptrCoercionRules (FnPtr *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (InferType &type) override
+ {
+ if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (FnPtr &type) override
+ {
+ auto this_ret_type = base->get_return_type ();
+ auto other_ret_type = type.get_return_type ();
+ auto unified_result = this_ret_type->unify (other_ret_type);
+ if (unified_result == nullptr
+ || unified_result->get_kind () == TypeKind::ERROR)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ if (base->num_params () != type.num_params ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ for (size_t i = 0; i < base->num_params (); i++)
+ {
+ auto this_param = base->param_at (i);
+ auto other_param = type.param_at (i);
+ auto unified_param = this_param->unify (other_param);
+ if (unified_param == nullptr
+ || unified_param->get_kind () == TypeKind::ERROR)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (FnType &type) override
+ {
+ auto this_ret_type = base->get_return_type ();
+ auto other_ret_type = type.get_return_type ();
+ auto unified_result = this_ret_type->unify (other_ret_type);
+ if (unified_result == nullptr
+ || unified_result->get_kind () == TypeKind::ERROR)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ if (base->num_params () != type.num_params ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ for (size_t i = 0; i < base->num_params (); i++)
+ {
+ auto this_param = base->param_at (i);
+ auto other_param = type.param_at (i).second;
+ auto unified_param = this_param->unify (other_param);
+ if (unified_param == nullptr
+ || unified_param->get_kind () == TypeKind::ERROR)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ FnPtr *base;
+};
+
+class ArrayCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ ArrayCoercionRules (ArrayType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (ArrayType &type) override
+ {
+ // check base type
+ auto base_resolved
+ = base->get_element_type ()->unify (type.get_element_type ());
+ if (base_resolved == nullptr)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ auto backend = rust_get_backend ();
+
+ // need to check the base types and capacity
+ if (!backend->const_values_equal (type.get_capacity (),
+ base->get_capacity ()))
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved
+ = new ArrayType (type.get_ref (), type.get_ty_ref (),
+ type.get_capacity (), TyVar (base_resolved->get_ref ()));
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ ArrayType *base;
+};
+
+class BoolCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ BoolCoercionRules (BoolType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (BoolType &type) override
+ {
+ resolved = new BoolType (type.get_ref (), type.get_ty_ref ());
+ }
+
+ void visit (InferType &type) override
+ {
+ switch (type.get_infer_kind ())
+ {
+ case InferType::InferTypeKind::GENERAL:
+ resolved = base->clone ();
+ break;
+
+ default:
+ BaseCoercionRules::visit (type);
+ break;
+ }
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ BoolType *base;
+};
+
+class IntCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ IntCoercionRules (IntType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (InferType &type) override
+ {
+ // cant assign a float inference variable
+ if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (IntType &type) override
+ {
+ if (type.get_int_kind () != base->get_int_kind ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved
+ = new IntType (type.get_ref (), type.get_ty_ref (), type.get_int_kind ());
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ IntType *base;
+};
+
+class UintCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ UintCoercionRules (UintType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (InferType &type) override
+ {
+ // cant assign a float inference variable
+ if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (UintType &type) override
+ {
+ if (type.get_uint_kind () != base->get_uint_kind ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = new UintType (type.get_ref (), type.get_ty_ref (),
+ type.get_uint_kind ());
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ UintType *base;
+};
+
+class FloatCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ FloatCoercionRules (FloatType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (InferType &type) override
+ {
+ if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (FloatType &type) override
+ {
+ if (type.get_float_kind () != base->get_float_kind ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = new FloatType (type.get_ref (), type.get_ty_ref (),
+ type.get_float_kind ());
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ FloatType *base;
+};
+
+class ADTCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ ADTCoercionRules (ADTType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (ADTType &type) override
+ {
+ if (base->get_identifier ().compare (type.get_identifier ()) != 0)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ if (base->num_fields () != type.num_fields ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ for (size_t i = 0; i < type.num_fields (); ++i)
+ {
+ TyTy::StructFieldType *base_field = base->get_field (i);
+ TyTy::StructFieldType *other_field = type.get_field (i);
+
+ TyTy::BaseType *this_field_ty = base_field->get_field_type ();
+ TyTy::BaseType *other_field_ty = other_field->get_field_type ();
+
+ BaseType *unified_ty = this_field_ty->unify (other_field_ty);
+ if (unified_ty->get_kind () == TyTy::TypeKind::ERROR)
+ return;
+ }
+
+ resolved = type.clone ();
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ ADTType *base;
+};
+
+class TupleCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ TupleCoercionRules (TupleType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (TupleType &type) override
+ {
+ if (base->num_fields () != type.num_fields ())
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ std::vector<TyVar> fields;
+ for (size_t i = 0; i < base->num_fields (); i++)
+ {
+ BaseType *bo = base->get_field (i);
+ BaseType *fo = type.get_field (i);
+
+ BaseType *unified_ty = bo->unify (fo);
+ if (unified_ty->get_kind () == TyTy::TypeKind::ERROR)
+ return;
+
+ fields.push_back (TyVar (unified_ty->get_ref ()));
+ }
+
+ resolved
+ = new TyTy::TupleType (type.get_ref (), type.get_ty_ref (), fields);
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ TupleType *base;
+};
+
+class USizeCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ USizeCoercionRules (USizeType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (InferType &type) override
+ {
+ // cant assign a float inference variable
+ if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (USizeType &type) override { resolved = type.clone (); }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ USizeType *base;
+};
+
+class ISizeCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ ISizeCoercionRules (ISizeType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (InferType &type) override
+ {
+ // cant assign a float inference variable
+ if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (ISizeType &type) override { resolved = type.clone (); }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ ISizeType *base;
+};
+
+class CharCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ CharCoercionRules (CharType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (InferType &type) override
+ {
+ if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ resolved->set_ref (type.get_ref ());
+ }
+
+ void visit (CharType &type) override { resolved = type.clone (); }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ CharType *base;
+};
+
+class ReferenceCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ ReferenceCoercionRules (ReferenceType *base)
+ : BaseCoercionRules (base), base (base)
+ {}
+
+ void visit (ReferenceType &type) override
+ {
+ auto base_type = base->get_base ();
+ auto other_base_type = type.get_base ();
+
+ TyTy::BaseType *base_resolved = base_type->unify (other_base_type);
+ if (base_resolved == nullptr
+ || base_resolved->get_kind () == TypeKind::ERROR)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ // we can allow for mutability changes here by casting down from mutability
+ // eg: mut vs const, we cant take a mutable reference from a const
+ // eg: const vs mut we can take a const reference from a mutable one
+ if (!base->is_mutable () || (base->is_mutable () == type.is_mutable ()))
+ {
+ resolved = new ReferenceType (base->get_ref (), base->get_ty_ref (),
+ TyVar (base_resolved->get_ref ()),
+ base->is_mutable ());
+ return;
+ }
+
+ BaseCoercionRules::visit (type);
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ ReferenceType *base;
+};
+
+class ParamCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ ParamCoercionRules (ParamType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ // param types are a placeholder we shouldn't have cases where we unify
+ // against it. eg: struct foo<T> { a: T }; When we invoke it we can do either:
+ //
+ // foo<i32>{ a: 123 }.
+ // Then this enforces the i32 type to be referenced on the
+ // field via an hirid.
+ //
+ // rust also allows for a = foo{a:123}; Where we can use an Inference Variable
+ // to handle the typing of the struct
+ BaseType *coerce (BaseType *other) override final
+ {
+ if (base->get_ref () == base->get_ty_ref ())
+ return BaseCoercionRules::coerce (other);
+
+ auto context = Resolver::TypeCheckContext::get ();
+ BaseType *lookup = nullptr;
+ bool ok = context->lookup_type (base->get_ty_ref (), &lookup);
+ rust_assert (ok);
+
+ return lookup->unify (other);
+ }
+
+ void visit (ParamType &type) override
+ {
+ if (base->get_symbol ().compare (type.get_symbol ()) != 0)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = type.clone ();
+ }
+
+ void visit (InferType &type) override
+ {
+ if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
+ {
+ BaseCoercionRules::visit (type);
+ return;
+ }
+
+ resolved = base->clone ();
+ }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ ParamType *base;
+};
+
+class StrCoercionRules : public BaseCoercionRules
+{
+ // FIXME we will need a enum for the StrType like ByteBuf etc..
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ StrCoercionRules (StrType *base) : BaseCoercionRules (base), base (base) {}
+
+ void visit (StrType &type) override { resolved = type.clone (); }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ StrType *base;
+};
+
+class NeverCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ NeverCoercionRules (NeverType *base) : BaseCoercionRules (base), base (base)
+ {}
+
+ virtual void visit (NeverType &type) override { resolved = type.clone (); }
+
+private:
+ BaseType *get_base () override { return base; }
+
+ NeverType *base;
+};
+
+class PlaceholderCoercionRules : public BaseCoercionRules
+{
+ using Rust::TyTy::BaseCoercionRules::visit;
+
+public:
+ PlaceholderCoercionRules (PlaceholderType *base)
+ : BaseCoercionRules (base), base (base)
+ {}
+
+private:
+ BaseType *get_base () override { return base; }
+
+ PlaceholderType *base;
+};
+
+} // namespace TyTy
+} // namespace Rust
+
+#endif // RUST_TYTY_COERCION_RULES
diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc
index 0eceaef..16bb01b 100644
--- a/gcc/rust/typecheck/rust-tyty.cc
+++ b/gcc/rust/typecheck/rust-tyty.cc
@@ -23,6 +23,7 @@
#include "rust-hir-type-check-type.h"
#include "rust-tyty-rules.h"
#include "rust-tyty-cmp.h"
+#include "rust-tyty-coercion.h"
#include "rust-hir-map.h"
#include "rust-substitution-mapper.h"
@@ -111,6 +112,13 @@ InferType::can_eq (const BaseType *other, bool emit_errors) const
}
BaseType *
+InferType::coerce (BaseType *other)
+{
+ InferCoercionRules r (this);
+ return r.coerce (other);
+}
+
+BaseType *
InferType::clone ()
{
return new InferType (get_ref (), get_ty_ref (), get_infer_kind (),
@@ -173,6 +181,12 @@ ErrorType::can_eq (const BaseType *other, bool emit_errors) const
}
BaseType *
+ErrorType::coerce (BaseType *other)
+{
+ return this;
+}
+
+BaseType *
ErrorType::clone ()
{
return new ErrorType (get_ref (), get_ty_ref (), get_combined_refs ());
@@ -438,6 +452,13 @@ ADTType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+ADTType::coerce (BaseType *other)
+{
+ ADTCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
ADTType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -605,6 +626,13 @@ TupleType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+TupleType::coerce (BaseType *other)
+{
+ TupleCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
TupleType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -695,6 +723,13 @@ FnType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+FnType::coerce (BaseType *other)
+{
+ FnCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
FnType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -896,6 +931,13 @@ FnPtr::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+FnPtr::coerce (BaseType *other)
+{
+ FnptrCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
FnPtr::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -969,6 +1011,13 @@ ArrayType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+ArrayType::coerce (BaseType *other)
+{
+ ArrayCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
ArrayType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1030,6 +1079,13 @@ BoolType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+BoolType::coerce (BaseType *other)
+{
+ BoolCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
BoolType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1082,6 +1138,13 @@ IntType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+IntType::coerce (BaseType *other)
+{
+ IntCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
IntType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1145,6 +1208,13 @@ UintType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+UintType::coerce (BaseType *other)
+{
+ UintCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
UintType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1202,6 +1272,13 @@ FloatType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+FloatType::coerce (BaseType *other)
+{
+ FloatCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
FloatType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1251,6 +1328,13 @@ USizeType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+USizeType::coerce (BaseType *other)
+{
+ USizeCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
USizeType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1289,6 +1373,13 @@ ISizeType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+ISizeType::coerce (BaseType *other)
+{
+ ISizeCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
ISizeType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1327,6 +1418,13 @@ CharType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+CharType::coerce (BaseType *other)
+{
+ CharCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
CharType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1366,6 +1464,13 @@ ReferenceType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+ReferenceType::coerce (BaseType *other)
+{
+ ReferenceCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
ReferenceType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1447,6 +1552,13 @@ ParamType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+ParamType::coerce (BaseType *other)
+{
+ ParamCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
ParamType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1553,6 +1665,13 @@ StrType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+StrType::coerce (BaseType *other)
+{
+ StrCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
StrType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1591,6 +1710,13 @@ NeverType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+NeverType::coerce (BaseType *other)
+{
+ NeverCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
NeverType::can_eq (const BaseType *other, bool emit_errors) const
{
@@ -1629,6 +1755,13 @@ PlaceholderType::unify (BaseType *other)
return r.unify (other);
}
+BaseType *
+PlaceholderType::coerce (BaseType *other)
+{
+ PlaceholderCoercionRules r (this);
+ return r.coerce (other);
+}
+
bool
PlaceholderType::can_eq (const BaseType *other, bool emit_errors) const
{
diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h
index 680e43f..c0af9f6 100644
--- a/gcc/rust/typecheck/rust-tyty.h
+++ b/gcc/rust/typecheck/rust-tyty.h
@@ -167,6 +167,9 @@ public:
// checks
virtual bool can_eq (const BaseType *other, bool emit_errors) const = 0;
+ // this is the base coercion interface for types
+ virtual BaseType *coerce (BaseType *other) = 0;
+
// Check value equality between two ty. Type inference rules are ignored. Two
// ty are considered equal if they're of the same kind, and
// 1. (For ADTs, arrays, tuples, refs) have the same underlying ty
@@ -287,6 +290,8 @@ public:
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
+
BaseType *clone () final override;
InferTypeKind get_infer_kind () const { return infer_kind; }
@@ -321,6 +326,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
@@ -350,6 +356,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
@@ -436,6 +443,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -869,6 +877,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -989,6 +998,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -1089,6 +1099,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -1132,6 +1143,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -1172,6 +1184,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
};
@@ -1206,6 +1219,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
IntKind get_int_kind () const { return int_kind; }
@@ -1247,6 +1261,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
UintKind get_uint_kind () const { return uint_kind; }
@@ -1286,6 +1301,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
FloatKind get_float_kind () const { return float_kind; }
@@ -1317,6 +1333,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
};
@@ -1341,6 +1358,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
};
@@ -1365,6 +1383,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
};
@@ -1393,6 +1412,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -1432,6 +1452,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
bool is_equal (const BaseType &other) const override;
@@ -1466,6 +1487,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;
@@ -1495,6 +1517,7 @@ public:
BaseType *unify (BaseType *other) override;
bool can_eq (const BaseType *other, bool emit_errors) const override final;
+ BaseType *coerce (BaseType *other) override;
BaseType *clone () final override;