diff options
author | Philip Herron <philip.herron@embecosm.com> | 2021-07-19 18:55:25 +0100 |
---|---|---|
committer | Philip Herron <philip.herron@embecosm.com> | 2021-07-19 19:38:29 +0100 |
commit | 0cbd3afc714a1d874fd829108f9b51a44205c050 (patch) | |
tree | a51428aa157c2d0c4d66989a6a0e1538ca462e9a /gcc/rust | |
parent | f82bf003cebb8a296312d32882db03e52945dac3 (diff) | |
download | gcc-0cbd3afc714a1d874fd829108f9b51a44205c050.zip gcc-0cbd3afc714a1d874fd829108f9b51a44205c050.tar.gz gcc-0cbd3afc714a1d874fd829108f9b51a44205c050.tar.bz2 |
Initial coercion rules
Lets keep the same unify pattern for coercion rules to keep the code as
readable as possible.
Addresses #434
Diffstat (limited to 'gcc/rust')
-rw-r--r-- | gcc/rust/typecheck/rust-tyty-coercion.h | 1198 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.cc | 133 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-tyty.h | 23 |
3 files changed, 1354 insertions, 0 deletions
diff --git a/gcc/rust/typecheck/rust-tyty-coercion.h b/gcc/rust/typecheck/rust-tyty-coercion.h new file mode 100644 index 0000000..6695056 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-coercion.h @@ -0,0 +1,1198 @@ +// Copyright (C) 2020 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// <http://www.gnu.org/licenses/>. + +#ifndef RUST_TYTY_COERCION_RULES +#define RUST_TYTY_COERCION_RULES + +#include "rust-diagnostics.h" +#include "rust-tyty.h" +#include "rust-tyty-visitor.h" +#include "rust-hir-map.h" +#include "rust-hir-type-check.h" + +extern ::Backend * +rust_get_backend (); + +namespace Rust { +namespace TyTy { + +class BaseCoercionRules : public TyVisitor +{ +public: + virtual ~BaseCoercionRules () {} + + virtual BaseType *coerce (BaseType *other) + { + if (other->get_kind () == TypeKind::PARAM) + { + ParamType *p = static_cast<ParamType *> (other); + if (p->can_resolve ()) + { + other = p->resolve (); + } + } + + other->accept_vis (*this); + if (resolved->get_kind () == TyTy::TypeKind::ERROR) + return resolved; + + resolved->append_reference (get_base ()->get_ref ()); + resolved->append_reference (other->get_ref ()); + for (auto ref : get_base ()->get_combined_refs ()) + resolved->append_reference (ref); + for (auto ref : other->get_combined_refs ()) + resolved->append_reference (ref); + + bool result_resolved = resolved->get_kind () != TyTy::TypeKind::INFER; + bool result_is_infer_var = resolved->get_kind () == TyTy::TypeKind::INFER; + bool results_is_non_general_infer_var + = (result_is_infer_var + && (static_cast<InferType *> (resolved))->get_infer_kind () + != TyTy::InferType::GENERAL); + if (result_resolved || results_is_non_general_infer_var) + { + for (auto &ref : resolved->get_combined_refs ()) + { + TyTy::BaseType *ref_tyty = nullptr; + bool ok = context->lookup_type (ref, &ref_tyty); + if (!ok) + continue; + + // if any of the types are inference variables lets fix them + if (ref_tyty->get_kind () == TyTy::TypeKind::INFER) + { + context->insert_type ( + Analysis::NodeMapping (mappings->get_current_crate (), + UNKNOWN_NODEID, ref, + UNKNOWN_LOCAL_DEFID), + resolved->clone ()); + } + } + } + return resolved; + } + + virtual void visit (TupleType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ADTType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (InferType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FnType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FnPtr &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ArrayType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (BoolType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (IntType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (UintType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (USizeType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ISizeType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FloatType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ErrorType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (CharType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ReferenceType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ParamType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (StrType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (NeverType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (PlaceholderType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "cannot coerce [%s] with [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + +protected: + BaseCoercionRules (BaseType *base) + : mappings (Analysis::Mappings::get ()), + context (Resolver::TypeCheckContext::get ()), + resolved (new ErrorType (base->get_ref (), base->get_ref ())) + {} + + Analysis::Mappings *mappings; + Resolver::TypeCheckContext *context; + + /* Temporary storage for the result of a unification. + We could return the result directly instead of storing it in the rule + object, but that involves modifying the visitor pattern to accommodate + the return value, which is too complex. */ + BaseType *resolved; + +private: + /* Returns a pointer to the ty that created this rule. */ + virtual BaseType *get_base () = 0; +}; + +class InferCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + InferCoercionRules (InferType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (BoolType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (IntType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (UintType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (USizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (ISizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (FloatType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (ArrayType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (ADTType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (TupleType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (InferType &type) override + { + switch (base->get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + resolved = type.clone (); + return; + + case InferType::InferTypeKind::INTEGRAL: { + if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL) + { + resolved = type.clone (); + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + resolved = base->clone (); + return; + } + } + break; + + case InferType::InferTypeKind::FLOAT: { + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + resolved = type.clone (); + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + resolved = base->clone (); + return; + } + } + break; + } + + BaseCoercionRules::visit (type); + } + + void visit (CharType &type) override + { + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + } + + void visit (ReferenceType &type) override + + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + + void visit (ParamType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseCoercionRules::visit (type); + } + +private: + BaseType *get_base () override { return base; } + + InferType *base; +}; + +class FnCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + FnCoercionRules (FnType *base) : BaseCoercionRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnType &type) override + { + if (base->num_params () != type.num_params ()) + { + BaseCoercionRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto a = base->param_at (i).second; + auto b = type.param_at (i).second; + + auto unified_param = a->unify (b); + if (unified_param == nullptr) + { + BaseCoercionRules::visit (type); + return; + } + } + + auto unified_return + = base->get_return_type ()->unify (type.get_return_type ()); + if (unified_return == nullptr) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + FnType *base; +}; + +class FnptrCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + FnptrCoercionRules (FnPtr *base) : BaseCoercionRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnPtr &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseCoercionRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseCoercionRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i); + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseCoercionRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnType &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseCoercionRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseCoercionRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i).second; + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseCoercionRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + FnPtr *base; +}; + +class ArrayCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + ArrayCoercionRules (ArrayType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (ArrayType &type) override + { + // check base type + auto base_resolved + = base->get_element_type ()->unify (type.get_element_type ()); + if (base_resolved == nullptr) + { + BaseCoercionRules::visit (type); + return; + } + + auto backend = rust_get_backend (); + + // need to check the base types and capacity + if (!backend->const_values_equal (type.get_capacity (), + base->get_capacity ())) + { + BaseCoercionRules::visit (type); + return; + } + + resolved + = new ArrayType (type.get_ref (), type.get_ty_ref (), + type.get_capacity (), TyVar (base_resolved->get_ref ())); + } + +private: + BaseType *get_base () override { return base; } + + ArrayType *base; +}; + +class BoolCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + BoolCoercionRules (BoolType *base) : BaseCoercionRules (base), base (base) {} + + void visit (BoolType &type) override + { + resolved = new BoolType (type.get_ref (), type.get_ty_ref ()); + } + + void visit (InferType &type) override + { + switch (type.get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + resolved = base->clone (); + break; + + default: + BaseCoercionRules::visit (type); + break; + } + } + +private: + BaseType *get_base () override { return base; } + + BoolType *base; +}; + +class IntCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + IntCoercionRules (IntType *base) : BaseCoercionRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (IntType &type) override + { + if (type.get_int_kind () != base->get_int_kind ()) + { + BaseCoercionRules::visit (type); + return; + } + + resolved + = new IntType (type.get_ref (), type.get_ty_ref (), type.get_int_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + IntType *base; +}; + +class UintCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + UintCoercionRules (UintType *base) : BaseCoercionRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (UintType &type) override + { + if (type.get_uint_kind () != base->get_uint_kind ()) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = new UintType (type.get_ref (), type.get_ty_ref (), + type.get_uint_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + UintType *base; +}; + +class FloatCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + FloatCoercionRules (FloatType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FloatType &type) override + { + if (type.get_float_kind () != base->get_float_kind ()) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = new FloatType (type.get_ref (), type.get_ty_ref (), + type.get_float_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + FloatType *base; +}; + +class ADTCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + ADTCoercionRules (ADTType *base) : BaseCoercionRules (base), base (base) {} + + void visit (ADTType &type) override + { + if (base->get_identifier ().compare (type.get_identifier ()) != 0) + { + BaseCoercionRules::visit (type); + return; + } + + if (base->num_fields () != type.num_fields ()) + { + BaseCoercionRules::visit (type); + return; + } + + for (size_t i = 0; i < type.num_fields (); ++i) + { + TyTy::StructFieldType *base_field = base->get_field (i); + TyTy::StructFieldType *other_field = type.get_field (i); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + BaseType *unified_ty = this_field_ty->unify (other_field_ty); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + } + + resolved = type.clone (); + } + +private: + BaseType *get_base () override { return base; } + + ADTType *base; +}; + +class TupleCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + TupleCoercionRules (TupleType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (TupleType &type) override + { + if (base->num_fields () != type.num_fields ()) + { + BaseCoercionRules::visit (type); + return; + } + + std::vector<TyVar> fields; + for (size_t i = 0; i < base->num_fields (); i++) + { + BaseType *bo = base->get_field (i); + BaseType *fo = type.get_field (i); + + BaseType *unified_ty = bo->unify (fo); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + + fields.push_back (TyVar (unified_ty->get_ref ())); + } + + resolved + = new TyTy::TupleType (type.get_ref (), type.get_ty_ref (), fields); + } + +private: + BaseType *get_base () override { return base; } + + TupleType *base; +}; + +class USizeCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + USizeCoercionRules (USizeType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (USizeType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + USizeType *base; +}; + +class ISizeCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + ISizeCoercionRules (ISizeType *base) : BaseCoercionRules (base), base (base) + {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (ISizeType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + ISizeType *base; +}; + +class CharCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + CharCoercionRules (CharType *base) : BaseCoercionRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (CharType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + CharType *base; +}; + +class ReferenceCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + ReferenceCoercionRules (ReferenceType *base) + : BaseCoercionRules (base), base (base) + {} + + void visit (ReferenceType &type) override + { + auto base_type = base->get_base (); + auto other_base_type = type.get_base (); + + TyTy::BaseType *base_resolved = base_type->unify (other_base_type); + if (base_resolved == nullptr + || base_resolved->get_kind () == TypeKind::ERROR) + { + BaseCoercionRules::visit (type); + return; + } + + // we can allow for mutability changes here by casting down from mutability + // eg: mut vs const, we cant take a mutable reference from a const + // eg: const vs mut we can take a const reference from a mutable one + if (!base->is_mutable () || (base->is_mutable () == type.is_mutable ())) + { + resolved = new ReferenceType (base->get_ref (), base->get_ty_ref (), + TyVar (base_resolved->get_ref ()), + base->is_mutable ()); + return; + } + + BaseCoercionRules::visit (type); + } + +private: + BaseType *get_base () override { return base; } + + ReferenceType *base; +}; + +class ParamCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + ParamCoercionRules (ParamType *base) : BaseCoercionRules (base), base (base) + {} + + // param types are a placeholder we shouldn't have cases where we unify + // against it. eg: struct foo<T> { a: T }; When we invoke it we can do either: + // + // foo<i32>{ a: 123 }. + // Then this enforces the i32 type to be referenced on the + // field via an hirid. + // + // rust also allows for a = foo{a:123}; Where we can use an Inference Variable + // to handle the typing of the struct + BaseType *coerce (BaseType *other) override final + { + if (base->get_ref () == base->get_ty_ref ()) + return BaseCoercionRules::coerce (other); + + auto context = Resolver::TypeCheckContext::get (); + BaseType *lookup = nullptr; + bool ok = context->lookup_type (base->get_ty_ref (), &lookup); + rust_assert (ok); + + return lookup->unify (other); + } + + void visit (ParamType &type) override + { + if (base->get_symbol ().compare (type.get_symbol ()) != 0) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = type.clone (); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCoercionRules::visit (type); + return; + } + + resolved = base->clone (); + } + +private: + BaseType *get_base () override { return base; } + + ParamType *base; +}; + +class StrCoercionRules : public BaseCoercionRules +{ + // FIXME we will need a enum for the StrType like ByteBuf etc.. + using Rust::TyTy::BaseCoercionRules::visit; + +public: + StrCoercionRules (StrType *base) : BaseCoercionRules (base), base (base) {} + + void visit (StrType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + StrType *base; +}; + +class NeverCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + NeverCoercionRules (NeverType *base) : BaseCoercionRules (base), base (base) + {} + + virtual void visit (NeverType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + NeverType *base; +}; + +class PlaceholderCoercionRules : public BaseCoercionRules +{ + using Rust::TyTy::BaseCoercionRules::visit; + +public: + PlaceholderCoercionRules (PlaceholderType *base) + : BaseCoercionRules (base), base (base) + {} + +private: + BaseType *get_base () override { return base; } + + PlaceholderType *base; +}; + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_COERCION_RULES diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index 0eceaef..16bb01b 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -23,6 +23,7 @@ #include "rust-hir-type-check-type.h" #include "rust-tyty-rules.h" #include "rust-tyty-cmp.h" +#include "rust-tyty-coercion.h" #include "rust-hir-map.h" #include "rust-substitution-mapper.h" @@ -111,6 +112,13 @@ InferType::can_eq (const BaseType *other, bool emit_errors) const } BaseType * +InferType::coerce (BaseType *other) +{ + InferCoercionRules r (this); + return r.coerce (other); +} + +BaseType * InferType::clone () { return new InferType (get_ref (), get_ty_ref (), get_infer_kind (), @@ -173,6 +181,12 @@ ErrorType::can_eq (const BaseType *other, bool emit_errors) const } BaseType * +ErrorType::coerce (BaseType *other) +{ + return this; +} + +BaseType * ErrorType::clone () { return new ErrorType (get_ref (), get_ty_ref (), get_combined_refs ()); @@ -438,6 +452,13 @@ ADTType::unify (BaseType *other) return r.unify (other); } +BaseType * +ADTType::coerce (BaseType *other) +{ + ADTCoercionRules r (this); + return r.coerce (other); +} + bool ADTType::can_eq (const BaseType *other, bool emit_errors) const { @@ -605,6 +626,13 @@ TupleType::unify (BaseType *other) return r.unify (other); } +BaseType * +TupleType::coerce (BaseType *other) +{ + TupleCoercionRules r (this); + return r.coerce (other); +} + bool TupleType::can_eq (const BaseType *other, bool emit_errors) const { @@ -695,6 +723,13 @@ FnType::unify (BaseType *other) return r.unify (other); } +BaseType * +FnType::coerce (BaseType *other) +{ + FnCoercionRules r (this); + return r.coerce (other); +} + bool FnType::can_eq (const BaseType *other, bool emit_errors) const { @@ -896,6 +931,13 @@ FnPtr::unify (BaseType *other) return r.unify (other); } +BaseType * +FnPtr::coerce (BaseType *other) +{ + FnptrCoercionRules r (this); + return r.coerce (other); +} + bool FnPtr::can_eq (const BaseType *other, bool emit_errors) const { @@ -969,6 +1011,13 @@ ArrayType::unify (BaseType *other) return r.unify (other); } +BaseType * +ArrayType::coerce (BaseType *other) +{ + ArrayCoercionRules r (this); + return r.coerce (other); +} + bool ArrayType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1030,6 +1079,13 @@ BoolType::unify (BaseType *other) return r.unify (other); } +BaseType * +BoolType::coerce (BaseType *other) +{ + BoolCoercionRules r (this); + return r.coerce (other); +} + bool BoolType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1082,6 +1138,13 @@ IntType::unify (BaseType *other) return r.unify (other); } +BaseType * +IntType::coerce (BaseType *other) +{ + IntCoercionRules r (this); + return r.coerce (other); +} + bool IntType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1145,6 +1208,13 @@ UintType::unify (BaseType *other) return r.unify (other); } +BaseType * +UintType::coerce (BaseType *other) +{ + UintCoercionRules r (this); + return r.coerce (other); +} + bool UintType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1202,6 +1272,13 @@ FloatType::unify (BaseType *other) return r.unify (other); } +BaseType * +FloatType::coerce (BaseType *other) +{ + FloatCoercionRules r (this); + return r.coerce (other); +} + bool FloatType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1251,6 +1328,13 @@ USizeType::unify (BaseType *other) return r.unify (other); } +BaseType * +USizeType::coerce (BaseType *other) +{ + USizeCoercionRules r (this); + return r.coerce (other); +} + bool USizeType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1289,6 +1373,13 @@ ISizeType::unify (BaseType *other) return r.unify (other); } +BaseType * +ISizeType::coerce (BaseType *other) +{ + ISizeCoercionRules r (this); + return r.coerce (other); +} + bool ISizeType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1327,6 +1418,13 @@ CharType::unify (BaseType *other) return r.unify (other); } +BaseType * +CharType::coerce (BaseType *other) +{ + CharCoercionRules r (this); + return r.coerce (other); +} + bool CharType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1366,6 +1464,13 @@ ReferenceType::unify (BaseType *other) return r.unify (other); } +BaseType * +ReferenceType::coerce (BaseType *other) +{ + ReferenceCoercionRules r (this); + return r.coerce (other); +} + bool ReferenceType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1447,6 +1552,13 @@ ParamType::unify (BaseType *other) return r.unify (other); } +BaseType * +ParamType::coerce (BaseType *other) +{ + ParamCoercionRules r (this); + return r.coerce (other); +} + bool ParamType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1553,6 +1665,13 @@ StrType::unify (BaseType *other) return r.unify (other); } +BaseType * +StrType::coerce (BaseType *other) +{ + StrCoercionRules r (this); + return r.coerce (other); +} + bool StrType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1591,6 +1710,13 @@ NeverType::unify (BaseType *other) return r.unify (other); } +BaseType * +NeverType::coerce (BaseType *other) +{ + NeverCoercionRules r (this); + return r.coerce (other); +} + bool NeverType::can_eq (const BaseType *other, bool emit_errors) const { @@ -1629,6 +1755,13 @@ PlaceholderType::unify (BaseType *other) return r.unify (other); } +BaseType * +PlaceholderType::coerce (BaseType *other) +{ + PlaceholderCoercionRules r (this); + return r.coerce (other); +} + bool PlaceholderType::can_eq (const BaseType *other, bool emit_errors) const { diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index 680e43f..c0af9f6 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -167,6 +167,9 @@ public: // checks virtual bool can_eq (const BaseType *other, bool emit_errors) const = 0; + // this is the base coercion interface for types + virtual BaseType *coerce (BaseType *other) = 0; + // Check value equality between two ty. Type inference rules are ignored. Two // ty are considered equal if they're of the same kind, and // 1. (For ADTs, arrays, tuples, refs) have the same underlying ty @@ -287,6 +290,8 @@ public: bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; + BaseType *clone () final override; InferTypeKind get_infer_kind () const { return infer_kind; } @@ -321,6 +326,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; @@ -350,6 +356,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; @@ -436,6 +443,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -869,6 +877,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -989,6 +998,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -1089,6 +1099,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -1132,6 +1143,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -1172,6 +1184,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; }; @@ -1206,6 +1219,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; IntKind get_int_kind () const { return int_kind; } @@ -1247,6 +1261,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; UintKind get_uint_kind () const { return uint_kind; } @@ -1286,6 +1301,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; FloatKind get_float_kind () const { return float_kind; } @@ -1317,6 +1333,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; }; @@ -1341,6 +1358,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; }; @@ -1365,6 +1383,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; }; @@ -1393,6 +1412,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -1432,6 +1452,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; bool is_equal (const BaseType &other) const override; @@ -1466,6 +1487,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; @@ -1495,6 +1517,7 @@ public: BaseType *unify (BaseType *other) override; bool can_eq (const BaseType *other, bool emit_errors) const override final; + BaseType *coerce (BaseType *other) override; BaseType *clone () final override; |