diff options
Diffstat (limited to 'gcc')
-rw-r--r-- | gcc/rust/backend/rust-compile-expr.cc | 302 | ||||
-rw-r--r-- | gcc/rust/backend/rust-compile-expr.h | 6 | ||||
-rw-r--r-- | gcc/rust/typecheck/rust-hir-type-check-expr.h | 593 |
3 files changed, 336 insertions, 565 deletions
diff --git a/gcc/rust/backend/rust-compile-expr.cc b/gcc/rust/backend/rust-compile-expr.cc index 4fc69d3..8902574 100644 --- a/gcc/rust/backend/rust-compile-expr.cc +++ b/gcc/rust/backend/rust-compile-expr.cc @@ -39,100 +39,18 @@ CompileExpr::visit (HIR::ArithmeticOrLogicalExpr &expr) TyTy::FnType *fntype; bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload ( expr.get_mappings ().get_hirid (), &fntype); - if (!is_op_overload) + if (is_op_overload) { - translated = ctx->get_backend ()->arithmetic_or_logical_expression ( - op, lhs, rhs, expr.get_locus ()); + auto lang_item_type + = Analysis::RustLangItem::OperatorToLangItem (expr.get_expr_type ()); + translated = resolve_operator_overload (lang_item_type, expr, lhs, rhs, + expr.get_lhs (), expr.get_rhs ()); return; } - // lookup the resolved name - NodeId resolved_node_id = UNKNOWN_NODEID; - if (!ctx->get_resolver ()->lookup_resolved_name ( - expr.get_mappings ().get_nodeid (), &resolved_node_id)) - { - rust_error_at (expr.get_locus (), "failed to lookup resolved MethodCall"); - return; - } - - // reverse lookup - HirId ref; - if (!ctx->get_mappings ()->lookup_node_to_hir ( - expr.get_mappings ().get_crate_num (), resolved_node_id, &ref)) - { - rust_fatal_error (expr.get_locus (), "reverse lookup failure"); - return; - } - - TyTy::BaseType *receiver = nullptr; - bool ok - = ctx->get_tyctx ()->lookup_receiver (expr.get_mappings ().get_hirid (), - &receiver); - rust_assert (ok); - - bool is_dyn_dispatch - = receiver->get_root ()->get_kind () == TyTy::TypeKind::DYNAMIC; - bool is_generic_receiver = receiver->get_kind () == TyTy::TypeKind::PARAM; - if (is_generic_receiver) - { - TyTy::ParamType *p = static_cast<TyTy::ParamType *> (receiver); - receiver = p->resolve (); - } - - if (is_dyn_dispatch) - { - const TyTy::DynamicObjectType *dyn - = static_cast<const TyTy::DynamicObjectType *> (receiver->get_root ()); - - std::vector<HIR::Expr *> arguments; - arguments.push_back (expr.get_rhs ()); - - translated = compile_dyn_dispatch_call (dyn, receiver, fntype, lhs, - arguments, expr.get_locus ()); - return; - } - - // lookup compiled functions since it may have already been compiled - HIR::PathIdentSegment segment_name ("add"); - Bexpression *fn_expr - = resolve_method_address (fntype, ref, receiver, segment_name, - expr.get_mappings (), expr.get_locus ()); - - // lookup the autoderef mappings - std::vector<Resolver::Adjustment> *adjustments = nullptr; - ok = ctx->get_tyctx ()->lookup_autoderef_mappings ( - expr.get_mappings ().get_hirid (), &adjustments); - rust_assert (ok); - - Bexpression *self = lhs; - for (auto &adjustment : *adjustments) - { - switch (adjustment.get_type ()) - { - case Resolver::Adjustment::AdjustmentType::IMM_REF: - case Resolver::Adjustment::AdjustmentType::MUT_REF: - self = ctx->get_backend ()->address_expression ( - self, expr.get_lhs ()->get_locus ()); - break; - - case Resolver::Adjustment::AdjustmentType::DEREF_REF: - Btype *expected_type - = TyTyResolveCompile::compile (ctx, adjustment.get_expected ()); - self = ctx->get_backend ()->indirect_expression ( - expected_type, self, true, /* known_valid*/ - expr.get_lhs ()->get_locus ()); - break; - } - } - - std::vector<Bexpression *> args; - args.push_back (self); // adjusted self - args.push_back (rhs); - - auto fncontext = ctx->peek_fn (); translated - = ctx->get_backend ()->call_expression (fncontext.fndecl, fn_expr, args, - nullptr, expr.get_locus ()); + = ctx->get_backend ()->arithmetic_or_logical_expression (op, lhs, rhs, + expr.get_locus ()); } void @@ -148,106 +66,30 @@ CompileExpr::visit (HIR::CompoundAssignmentExpr &expr) TyTy::FnType *fntype; bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload ( expr.get_mappings ().get_hirid (), &fntype); - if (!is_op_overload) + if (is_op_overload) { - auto operator_expr - = ctx->get_backend ()->arithmetic_or_logical_expression ( - op, lhs, rhs, expr.get_locus ()); - Bstatement *assignment - = ctx->get_backend ()->assignment_statement (fn.fndecl, lhs, - operator_expr, - expr.get_locus ()); + auto lang_item_type + = Analysis::RustLangItem::CompoundAssignmentOperatorToLangItem ( + expr.get_expr_type ()); + auto compound_assignment + = resolve_operator_overload (lang_item_type, expr, lhs, rhs, + expr.get_left_expr ().get (), + expr.get_right_expr ().get ()); + auto assignment + = ctx->get_backend ()->expression_statement (fn.fndecl, + compound_assignment); ctx->add_statement (assignment); - return; - } - - // lookup the resolved name - NodeId resolved_node_id = UNKNOWN_NODEID; - if (!ctx->get_resolver ()->lookup_resolved_name ( - expr.get_mappings ().get_nodeid (), &resolved_node_id)) - { - rust_error_at (expr.get_locus (), "failed to lookup resolved MethodCall"); - return; - } - - // reverse lookup - HirId ref; - if (!ctx->get_mappings ()->lookup_node_to_hir ( - expr.get_mappings ().get_crate_num (), resolved_node_id, &ref)) - { - rust_fatal_error (expr.get_locus (), "reverse lookup failure"); - return; - } - - TyTy::BaseType *receiver = nullptr; - bool ok - = ctx->get_tyctx ()->lookup_receiver (expr.get_mappings ().get_hirid (), - &receiver); - rust_assert (ok); - bool is_dyn_dispatch - = receiver->get_root ()->get_kind () == TyTy::TypeKind::DYNAMIC; - bool is_generic_receiver = receiver->get_kind () == TyTy::TypeKind::PARAM; - if (is_generic_receiver) - { - TyTy::ParamType *p = static_cast<TyTy::ParamType *> (receiver); - receiver = p->resolve (); - } - - if (is_dyn_dispatch) - { - const TyTy::DynamicObjectType *dyn - = static_cast<const TyTy::DynamicObjectType *> (receiver->get_root ()); - - std::vector<HIR::Expr *> arguments; - arguments.push_back (expr.get_right_expr ().get ()); - - translated = compile_dyn_dispatch_call (dyn, receiver, fntype, lhs, - arguments, expr.get_locus ()); return; } - // lookup compiled functions since it may have already been compiled - HIR::PathIdentSegment segment_name ("add_assign"); - Bexpression *fn_expr - = resolve_method_address (fntype, ref, receiver, segment_name, - expr.get_mappings (), expr.get_locus ()); - - // lookup the autoderef mappings - std::vector<Resolver::Adjustment> *adjustments = nullptr; - ok = ctx->get_tyctx ()->lookup_autoderef_mappings ( - expr.get_mappings ().get_hirid (), &adjustments); - rust_assert (ok); - - Bexpression *self = lhs; - for (auto &adjustment : *adjustments) - { - switch (adjustment.get_type ()) - { - case Resolver::Adjustment::AdjustmentType::IMM_REF: - case Resolver::Adjustment::AdjustmentType::MUT_REF: - self = ctx->get_backend ()->address_expression ( - self, expr.get_left_expr ()->get_locus ()); - break; - - case Resolver::Adjustment::AdjustmentType::DEREF_REF: - Btype *expected_type - = TyTyResolveCompile::compile (ctx, adjustment.get_expected ()); - self = ctx->get_backend ()->indirect_expression ( - expected_type, self, true, /* known_valid*/ - expr.get_left_expr ()->get_locus ()); - break; - } - } - - std::vector<Bexpression *> args; - args.push_back (self); // adjusted self - args.push_back (rhs); - - auto fncontext = ctx->peek_fn (); - translated - = ctx->get_backend ()->call_expression (fncontext.fndecl, fn_expr, args, - nullptr, expr.get_locus ()); + auto operator_expr + = ctx->get_backend ()->arithmetic_or_logical_expression (op, lhs, rhs, + expr.get_locus ()); + Bstatement *assignment + = ctx->get_backend ()->assignment_statement (fn.fndecl, lhs, operator_expr, + expr.get_locus ()); + ctx->add_statement (assignment); } Bexpression * @@ -427,5 +269,99 @@ CompileExpr::resolve_method_address (TyTy::FnType *fntype, HirId ref, } } +Bexpression * +CompileExpr::resolve_operator_overload ( + Analysis::RustLangItem::ItemType lang_item_type, HIR::OperatorExpr &expr, + Bexpression *lhs, Bexpression *rhs, HIR::Expr *lhs_expr, HIR::Expr *rhs_expr) +{ + TyTy::FnType *fntype; + bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload ( + expr.get_mappings ().get_hirid (), &fntype); + rust_assert (is_op_overload); + + // lookup the resolved name + NodeId resolved_node_id = UNKNOWN_NODEID; + bool ok = ctx->get_resolver ()->lookup_resolved_name ( + expr.get_mappings ().get_nodeid (), &resolved_node_id); + rust_assert (ok); + + // reverse lookup + HirId ref; + ok = ctx->get_mappings ()->lookup_node_to_hir ( + expr.get_mappings ().get_crate_num (), resolved_node_id, &ref); + rust_assert (ok); + + TyTy::BaseType *receiver = nullptr; + ok = ctx->get_tyctx ()->lookup_receiver (expr.get_mappings ().get_hirid (), + &receiver); + rust_assert (ok); + + bool is_dyn_dispatch + = receiver->get_root ()->get_kind () == TyTy::TypeKind::DYNAMIC; + bool is_generic_receiver = receiver->get_kind () == TyTy::TypeKind::PARAM; + if (is_generic_receiver) + { + TyTy::ParamType *p = static_cast<TyTy::ParamType *> (receiver); + receiver = p->resolve (); + } + + if (is_dyn_dispatch) + { + const TyTy::DynamicObjectType *dyn + = static_cast<const TyTy::DynamicObjectType *> (receiver->get_root ()); + + std::vector<HIR::Expr *> arguments; + arguments.push_back (rhs_expr); + + return compile_dyn_dispatch_call (dyn, receiver, fntype, lhs, arguments, + expr.get_locus ()); + } + + // lookup compiled functions since it may have already been compiled + HIR::PathIdentSegment segment_name ( + Analysis::RustLangItem::ToString (lang_item_type)); + Bexpression *fn_expr + = resolve_method_address (fntype, ref, receiver, segment_name, + expr.get_mappings (), expr.get_locus ()); + + // lookup the autoderef mappings + std::vector<Resolver::Adjustment> *adjustments = nullptr; + ok = ctx->get_tyctx ()->lookup_autoderef_mappings ( + expr.get_mappings ().get_hirid (), &adjustments); + rust_assert (ok); + + // FIXME refactor this out + Bexpression *self = lhs; + for (auto &adjustment : *adjustments) + { + switch (adjustment.get_type ()) + { + case Resolver::Adjustment::AdjustmentType::IMM_REF: + case Resolver::Adjustment::AdjustmentType::MUT_REF: + self + = ctx->get_backend ()->address_expression (self, + lhs_expr->get_locus ()); + break; + + case Resolver::Adjustment::AdjustmentType::DEREF_REF: + Btype *expected_type + = TyTyResolveCompile::compile (ctx, adjustment.get_expected ()); + self + = ctx->get_backend ()->indirect_expression (expected_type, self, + true, /* known_valid*/ + lhs_expr->get_locus ()); + break; + } + } + + std::vector<Bexpression *> args; + args.push_back (self); // adjusted self + args.push_back (rhs); + + auto fncontext = ctx->peek_fn (); + return ctx->get_backend ()->call_expression (fncontext.fndecl, fn_expr, args, + nullptr, expr.get_locus ()); +} + } // namespace Compile } // namespace Rust diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h index bdcc5b6..7238d45 100644 --- a/gcc/rust/backend/rust-compile-expr.h +++ b/gcc/rust/backend/rust-compile-expr.h @@ -1005,6 +1005,12 @@ protected: Analysis::NodeMapping expr_mappings, Location expr_locus); + Bexpression * + resolve_operator_overload (Analysis::RustLangItem::ItemType lang_item_type, + HIR::OperatorExpr &expr, Bexpression *lhs, + Bexpression *rhs, HIR::Expr *lhs_expr, + HIR::Expr *rhs_expr); + private: CompileExpr (Context *ctx) : HIRCompileBase (ctx), translated (nullptr), capacity_expr (nullptr) diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index 977b817..37cd6b3 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -483,207 +483,24 @@ public: if (result->get_kind () == TyTy::TypeKind::ERROR) return; - // in order to probe of the correct type paths we need the root type, which - // strips any references - const TyTy::BaseType *root = lhs->get_root (); - - // look up lang item for arithmetic type - std::vector<PathProbeCandidate> candidates; auto lang_item_type = Analysis::RustLangItem::CompoundAssignmentOperatorToLangItem ( expr.get_expr_type ()); - std::string associated_item_name - = Analysis::RustLangItem::ToString (lang_item_type); - DefId respective_lang_item_id = UNKNOWN_DEFID; - bool lang_item_defined - = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id); - - // probe for the lang-item - if (lang_item_defined) - { - bool receiver_is_type_param - = root->get_kind () == TyTy::TypeKind::PARAM; - bool receiver_is_dyn = root->get_kind () == TyTy::TypeKind::DYNAMIC; - - bool receiver_is_generic = receiver_is_type_param || receiver_is_dyn; - bool probe_bounds = true; - bool probe_impls = !receiver_is_generic; - bool ignore_mandatory_trait_items = !receiver_is_generic; - - candidates = PathProbeType::Probe ( - root, HIR::PathIdentSegment (associated_item_name), probe_impls, - probe_bounds, ignore_mandatory_trait_items, respective_lang_item_id); - } - - // autoderef - std::vector<Adjustment> adjustments; - PathProbeCandidate *resolved_candidate - = MethodResolution::Select (candidates, lhs, adjustments); - - // is this the case we are recursive - // handle the case where we are within the impl block for this lang_item - // otherwise we end up with a recursive operator overload such as the i32 - // operator overload trait - if (lang_item_defined && resolved_candidate != nullptr) - { - TypeCheckContextItem &fn_context = context->peek_context (); - if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM) - { - auto &impl_item = fn_context.get_impl_item (); - HIR::ImplBlock *parent = impl_item.first; - HIR::Function *fn = impl_item.second; - - if (parent->has_trait_ref () - && fn->get_function_name ().compare (associated_item_name) == 0) - { - TraitReference *trait_reference - = TraitResolver::Lookup (*parent->get_trait_ref ().get ()); - if (!trait_reference->is_error ()) - { - TyTy::BaseType *lookup = nullptr; - bool ok - = context->lookup_type (fn->get_mappings ().get_hirid (), - &lookup); - rust_assert (ok); - rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); - - TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup); - rust_assert (fntype->is_method ()); - - Adjuster adj (lhs); - TyTy::BaseType *adjusted = adj.adjust_type (adjustments); - - bool is_lang_item_impl - = trait_reference->get_mappings ().get_defid () - == respective_lang_item_id; - bool self_is_lang_item_self - = fntype->get_self_type ()->is_equal (*adjusted); - bool recursive_operator_overload - = is_lang_item_impl && self_is_lang_item_self; - - lang_item_defined = !recursive_operator_overload; - } - } - } - } - - bool have_implementation_for_lang_item = resolved_candidate != nullptr; - if (!lang_item_defined || !have_implementation_for_lang_item) - { - bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ()); - bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ()); - bool valid = valid_lhs && valid_rhs; - if (!valid) - { - rust_error_at (expr.get_locus (), - "cannot apply this operator to types %s and %s", - lhs->as_string ().c_str (), - rhs->as_string ().c_str ()); - return; - } - - // nothing left to do - return; - } - - // now its just like a method-call-expr - context->insert_receiver (expr.get_mappings ().get_hirid (), lhs); - - // store the adjustments for code-generation to know what to do - context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (), - std::move (adjustments)); - - TyTy::BaseType *lookup_tyty = resolved_candidate->ty; - NodeId resolved_node_id - = resolved_candidate->is_impl_candidate () - ? resolved_candidate->item.impl.impl_item->get_impl_mappings () - .get_nodeid () - : resolved_candidate->item.trait.item_ref->get_mappings () - .get_nodeid (); - - rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF); - TyTy::BaseType *lookup = lookup_tyty; - TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup); - rust_assert (fn->is_method ()); - - if (root->get_kind () == TyTy::TypeKind::ADT) - { - const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root); - if (adt->has_substitutions () && fn->needs_substitution ()) - { - // consider the case where we have: - // - // struct Foo<X,Y>(X,Y); - // - // impl<T> Foo<T, i32> { - // fn test<X>(self, a:X) -> (T,X) { (self.0, a) } - // } - // - // In this case we end up with an fn type of: - // - // fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X) - // - // This means the instance or self we are calling this method for - // will be substituted such that we can get the inherited type - // arguments but then need to use the turbo fish if available or - // infer the remaining arguments. Luckily rust does not allow for - // default types GenericParams on impl blocks since these must - // always be at the end of the list - - auto s = fn->get_self_type ()->get_root (); - rust_assert (s->can_eq (adt, false, false)); - rust_assert (s->get_kind () == TyTy::TypeKind::ADT); - const TyTy::ADTType *self_adt - = static_cast<const TyTy::ADTType *> (s); - - // we need to grab the Self substitutions as the inherit type - // parameters for this - if (self_adt->needs_substitution ()) - { - rust_assert (adt->was_substituted ()); - - TyTy::SubstitutionArgumentMappings used_args_in_prev_segment - = GetUsedSubstArgs::From (adt); - - TyTy::SubstitutionArgumentMappings inherit_type_args - = self_adt->solve_mappings_from_receiver_for_self ( - used_args_in_prev_segment); - - // there may or may not be inherited type arguments - if (!inherit_type_args.is_error ()) - { - // need to apply the inherited type arguments to the - // function - lookup = fn->handle_substitions (inherit_type_args); - } - } - } - } + bool operator_overloaded + = resolve_operator_overload (lang_item_type, expr, lhs, rhs); + if (operator_overloaded) + return; - // type check the arguments - TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup); - rust_assert (type->num_params () == 2); - auto fnparam = type->param_at (1); - auto resolved_argument_type = fnparam.second->unify (rhs); - if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR) + bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ()); + bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ()); + bool valid = valid_lhs && valid_rhs; + if (!valid) { rust_error_at (expr.get_locus (), - "Type Resolution failure on parameter"); + "cannot apply this operator to types %s and %s", + lhs->as_string ().c_str (), rhs->as_string ().c_str ()); return; } - - // get the return type - TyTy::BaseType *function_ret_tyty = fn->get_return_type ()->clone (); - - // store the expected fntype - context->insert_operator_overload (expr.get_mappings ().get_hirid (), type); - - // set up the resolved name on the path - resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (), - resolved_node_id); - - // return the result of the function back - infered = function_ret_tyty; } void visit (HIR::IdentifierExpr &expr) override @@ -915,203 +732,25 @@ public: auto lhs = TypeCheckExpr::Resolve (expr.get_lhs (), false); auto rhs = TypeCheckExpr::Resolve (expr.get_rhs (), false); - // in order to probe of the correct type paths we need the root type, which - // strips any references - const TyTy::BaseType *root = lhs->get_root (); - - // look up lang item for arithmetic type - std::vector<PathProbeCandidate> candidates; auto lang_item_type = Analysis::RustLangItem::OperatorToLangItem (expr.get_expr_type ()); - std::string associated_item_name - = Analysis::RustLangItem::ToString (lang_item_type); - DefId respective_lang_item_id = UNKNOWN_DEFID; - bool lang_item_defined - = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id); - - // handle the case where we are within the impl block for this lang_item - // otherwise we end up with a recursive operator overload such as the i32 - // operator overload trait - if (lang_item_defined) - { - TypeCheckContextItem &fn_context = context->peek_context (); - if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM) - { - auto &impl_item = fn_context.get_impl_item (); - HIR::ImplBlock *parent = impl_item.first; - HIR::Function *fn = impl_item.second; - - if (parent->has_trait_ref () - && fn->get_function_name ().compare (associated_item_name) == 0) - { - TraitReference *trait_reference - = TraitResolver::Lookup (*parent->get_trait_ref ().get ()); - if (!trait_reference->is_error ()) - { - TyTy::BaseType *lookup = nullptr; - bool ok - = context->lookup_type (fn->get_mappings ().get_hirid (), - &lookup); - rust_assert (ok); - rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); - - TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup); - rust_assert (fntype->is_method ()); - - bool is_lang_item_impl - = trait_reference->get_mappings ().get_defid () - == respective_lang_item_id; - bool self_is_lang_item_self - = fntype->get_self_type ()->is_equal (*lhs); - - bool recursive_operator_overload - = is_lang_item_impl && self_is_lang_item_self; - lang_item_defined = !recursive_operator_overload; - } - } - } - } - - // probe for the lang-item - if (lang_item_defined) - { - bool receiver_is_type_param - = root->get_kind () == TyTy::TypeKind::PARAM; - bool receiver_is_dyn = root->get_kind () == TyTy::TypeKind::DYNAMIC; - - bool receiver_is_generic = receiver_is_type_param || receiver_is_dyn; - bool probe_bounds = true; - bool probe_impls = !receiver_is_generic; - bool ignore_mandatory_trait_items = !receiver_is_generic; - - candidates = PathProbeType::Probe ( - root, HIR::PathIdentSegment (associated_item_name), probe_impls, - probe_bounds, ignore_mandatory_trait_items, respective_lang_item_id); - } - - bool have_implementation_for_lang_item = candidates.size () > 0; - if (!lang_item_defined || !have_implementation_for_lang_item) - { - bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ()); - bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ()); - bool valid = valid_lhs && valid_rhs; - if (!valid) - { - rust_error_at (expr.get_locus (), - "cannot apply this operator to types %s and %s", - lhs->as_string ().c_str (), - rhs->as_string ().c_str ()); - return; - } - - infered = lhs->unify (rhs); - return; - } - - // now its just like a method-call-expr - context->insert_receiver (expr.get_mappings ().get_hirid (), lhs); - - // autoderef - std::vector<Adjustment> adjustments; - PathProbeCandidate *resolved_candidate - = MethodResolution::Select (candidates, lhs, adjustments); - rust_assert (resolved_candidate != nullptr); - - // store the adjustments for code-generation to know what to do - context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (), - std::move (adjustments)); - - TyTy::BaseType *lookup_tyty = resolved_candidate->ty; - NodeId resolved_node_id - = resolved_candidate->is_impl_candidate () - ? resolved_candidate->item.impl.impl_item->get_impl_mappings () - .get_nodeid () - : resolved_candidate->item.trait.item_ref->get_mappings () - .get_nodeid (); - - rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF); - TyTy::BaseType *lookup = lookup_tyty; - TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup); - rust_assert (fn->is_method ()); - - if (root->get_kind () == TyTy::TypeKind::ADT) - { - const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root); - if (adt->has_substitutions () && fn->needs_substitution ()) - { - // consider the case where we have: - // - // struct Foo<X,Y>(X,Y); - // - // impl<T> Foo<T, i32> { - // fn test<X>(self, a:X) -> (T,X) { (self.0, a) } - // } - // - // In this case we end up with an fn type of: - // - // fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X) - // - // This means the instance or self we are calling this method for - // will be substituted such that we can get the inherited type - // arguments but then need to use the turbo fish if available or - // infer the remaining arguments. Luckily rust does not allow for - // default types GenericParams on impl blocks since these must - // always be at the end of the list - - auto s = fn->get_self_type ()->get_root (); - rust_assert (s->can_eq (adt, false, false)); - rust_assert (s->get_kind () == TyTy::TypeKind::ADT); - const TyTy::ADTType *self_adt - = static_cast<const TyTy::ADTType *> (s); - - // we need to grab the Self substitutions as the inherit type - // parameters for this - if (self_adt->needs_substitution ()) - { - rust_assert (adt->was_substituted ()); - - TyTy::SubstitutionArgumentMappings used_args_in_prev_segment - = GetUsedSubstArgs::From (adt); - - TyTy::SubstitutionArgumentMappings inherit_type_args - = self_adt->solve_mappings_from_receiver_for_self ( - used_args_in_prev_segment); - - // there may or may not be inherited type arguments - if (!inherit_type_args.is_error ()) - { - // need to apply the inherited type arguments to the - // function - lookup = fn->handle_substitions (inherit_type_args); - } - } - } - } + bool operator_overloaded + = resolve_operator_overload (lang_item_type, expr, lhs, rhs); + if (operator_overloaded) + return; - // type check the arguments - TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup); - rust_assert (type->num_params () == 2); - auto fnparam = type->param_at (1); - auto resolved_argument_type = fnparam.second->unify (rhs); - if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR) + bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ()); + bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ()); + bool valid = valid_lhs && valid_rhs; + if (!valid) { rust_error_at (expr.get_locus (), - "Type Resolution failure on parameter"); + "cannot apply this operator to types %s and %s", + lhs->as_string ().c_str (), rhs->as_string ().c_str ()); return; } - // get the return type - TyTy::BaseType *function_ret_tyty = fn->get_return_type ()->clone (); - - // store the expected fntype - context->insert_operator_overload (expr.get_mappings ().get_hirid (), type); - - // set up the resolved name on the path - resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (), - resolved_node_id); - - // return the result of the function back - infered = function_ret_tyty; + infered = lhs->unify (rhs); } void visit (HIR::ComparisonExpr &expr) override @@ -1573,6 +1212,196 @@ public: infered = expr_to_convert->cast (tyty_to_convert_to); } +protected: + bool + resolve_operator_overload (Analysis::RustLangItem::ItemType lang_item_type, + HIR::OperatorExpr &expr, TyTy::BaseType *lhs, + TyTy::BaseType *rhs) + { + // in order to probe of the correct type paths we need the root type, which + // strips any references + const TyTy::BaseType *root = lhs->get_root (); + + // look up lang item for arithmetic type + std::vector<PathProbeCandidate> candidates; + std::string associated_item_name + = Analysis::RustLangItem::ToString (lang_item_type); + DefId respective_lang_item_id = UNKNOWN_DEFID; + bool lang_item_defined + = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id); + + // probe for the lang-item + if (lang_item_defined) + { + bool receiver_is_type_param + = root->get_kind () == TyTy::TypeKind::PARAM; + bool receiver_is_dyn = root->get_kind () == TyTy::TypeKind::DYNAMIC; + + bool receiver_is_generic = receiver_is_type_param || receiver_is_dyn; + bool probe_bounds = true; + bool probe_impls = !receiver_is_generic; + bool ignore_mandatory_trait_items = !receiver_is_generic; + + candidates = PathProbeType::Probe ( + root, HIR::PathIdentSegment (associated_item_name), probe_impls, + probe_bounds, ignore_mandatory_trait_items, respective_lang_item_id); + } + + // autoderef + std::vector<Adjustment> adjustments; + PathProbeCandidate *resolved_candidate + = MethodResolution::Select (candidates, lhs, adjustments); + + // is this the case we are recursive + // handle the case where we are within the impl block for this lang_item + // otherwise we end up with a recursive operator overload such as the i32 + // operator overload trait + if (lang_item_defined && resolved_candidate != nullptr) + { + TypeCheckContextItem &fn_context = context->peek_context (); + if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM) + { + auto &impl_item = fn_context.get_impl_item (); + HIR::ImplBlock *parent = impl_item.first; + HIR::Function *fn = impl_item.second; + + if (parent->has_trait_ref () + && fn->get_function_name ().compare (associated_item_name) == 0) + { + TraitReference *trait_reference + = TraitResolver::Lookup (*parent->get_trait_ref ().get ()); + if (!trait_reference->is_error ()) + { + TyTy::BaseType *lookup = nullptr; + bool ok + = context->lookup_type (fn->get_mappings ().get_hirid (), + &lookup); + rust_assert (ok); + rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); + + TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup); + rust_assert (fntype->is_method ()); + + Adjuster adj (lhs); + TyTy::BaseType *adjusted = adj.adjust_type (adjustments); + + bool is_lang_item_impl + = trait_reference->get_mappings ().get_defid () + == respective_lang_item_id; + bool self_is_lang_item_self + = fntype->get_self_type ()->is_equal (*adjusted); + bool recursive_operator_overload + = is_lang_item_impl && self_is_lang_item_self; + + lang_item_defined = !recursive_operator_overload; + } + } + } + } + + bool have_implementation_for_lang_item = resolved_candidate != nullptr; + if (!lang_item_defined || !have_implementation_for_lang_item) + { + // no operator overload exists for this + return false; + } + + // now its just like a method-call-expr + context->insert_receiver (expr.get_mappings ().get_hirid (), lhs); + + // store the adjustments for code-generation to know what to do + context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (), + std::move (adjustments)); + + TyTy::BaseType *lookup_tyty = resolved_candidate->ty; + NodeId resolved_node_id + = resolved_candidate->is_impl_candidate () + ? resolved_candidate->item.impl.impl_item->get_impl_mappings () + .get_nodeid () + : resolved_candidate->item.trait.item_ref->get_mappings () + .get_nodeid (); + + rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::BaseType *lookup = lookup_tyty; + TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup); + rust_assert (fn->is_method ()); + + if (root->get_kind () == TyTy::TypeKind::ADT) + { + const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root); + if (adt->has_substitutions () && fn->needs_substitution ()) + { + // consider the case where we have: + // + // struct Foo<X,Y>(X,Y); + // + // impl<T> Foo<T, i32> { + // fn test<X>(self, a:X) -> (T,X) { (self.0, a) } + // } + // + // In this case we end up with an fn type of: + // + // fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X) + // + // This means the instance or self we are calling this method for + // will be substituted such that we can get the inherited type + // arguments but then need to use the turbo fish if available or + // infer the remaining arguments. Luckily rust does not allow for + // default types GenericParams on impl blocks since these must + // always be at the end of the list + + auto s = fn->get_self_type ()->get_root (); + rust_assert (s->can_eq (adt, false, false)); + rust_assert (s->get_kind () == TyTy::TypeKind::ADT); + const TyTy::ADTType *self_adt + = static_cast<const TyTy::ADTType *> (s); + + // we need to grab the Self substitutions as the inherit type + // parameters for this + if (self_adt->needs_substitution ()) + { + rust_assert (adt->was_substituted ()); + + TyTy::SubstitutionArgumentMappings used_args_in_prev_segment + = GetUsedSubstArgs::From (adt); + + TyTy::SubstitutionArgumentMappings inherit_type_args + = self_adt->solve_mappings_from_receiver_for_self ( + used_args_in_prev_segment); + + // there may or may not be inherited type arguments + if (!inherit_type_args.is_error ()) + { + // need to apply the inherited type arguments to the + // function + lookup = fn->handle_substitions (inherit_type_args); + } + } + } + } + + // type check the arguments + TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup); + rust_assert (type->num_params () == 2); + auto fnparam = type->param_at (1); + fnparam.second->unify (rhs); // typecheck the rhs + + // get the return type + TyTy::BaseType *function_ret_tyty = fn->get_return_type ()->clone (); + + // store the expected fntype + context->insert_operator_overload (expr.get_mappings ().get_hirid (), type); + + // set up the resolved name on the path + resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (), + resolved_node_id); + + // return the result of the function back + infered = function_ret_tyty; + + return true; + } + private: TypeCheckExpr (bool inside_loop) : TypeCheckBase (), infered (nullptr), infered_array_elems (nullptr), |