diff options
Diffstat (limited to 'clang/lib')
97 files changed, 2999 insertions, 949 deletions
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp index 6af7ef3..922d679 100644 --- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp +++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp @@ -678,30 +678,6 @@ static bool interp__builtin_popcount(InterpState &S, CodePtr OpPC, return true; } -static bool interp__builtin_parity(InterpState &S, CodePtr OpPC, - const InterpFrame *Frame, - const CallExpr *Call) { - APSInt Val = popToAPSInt(S, Call->getArg(0)); - pushInteger(S, Val.popcount() % 2, Call->getType()); - return true; -} - -static bool interp__builtin_clrsb(InterpState &S, CodePtr OpPC, - const InterpFrame *Frame, - const CallExpr *Call) { - APSInt Val = popToAPSInt(S, Call->getArg(0)); - pushInteger(S, Val.getBitWidth() - Val.getSignificantBits(), Call->getType()); - return true; -} - -static bool interp__builtin_bitreverse(InterpState &S, CodePtr OpPC, - const InterpFrame *Frame, - const CallExpr *Call) { - APSInt Val = popToAPSInt(S, Call->getArg(0)); - pushInteger(S, Val.reverseBits(), Call->getType()); - return true; -} - static bool interp__builtin_classify_type(InterpState &S, CodePtr OpPC, const InterpFrame *Frame, const CallExpr *Call) { @@ -736,16 +712,6 @@ static bool interp__builtin_expect(InterpState &S, CodePtr OpPC, return true; } -static bool interp__builtin_ffs(InterpState &S, CodePtr OpPC, - const InterpFrame *Frame, - const CallExpr *Call) { - APSInt Value = popToAPSInt(S, Call->getArg(0)); - - uint64_t N = Value.countr_zero(); - pushInteger(S, N == Value.getBitWidth() ? 0 : N + 1, Call->getType()); - return true; -} - static bool interp__builtin_addressof(InterpState &S, CodePtr OpPC, const InterpFrame *Frame, const CallExpr *Call) { @@ -2314,10 +2280,14 @@ static bool interp__builtin_object_size(InterpState &S, CodePtr OpPC, if (Ptr.isBaseClass()) ByteOffset = computePointerOffset(ASTCtx, Ptr.getBase()) - computePointerOffset(ASTCtx, Ptr); - else - ByteOffset = - computePointerOffset(ASTCtx, Ptr) - - computePointerOffset(ASTCtx, Ptr.expand().atIndex(0).narrow()); + else { + if (Ptr.inArray()) + ByteOffset = + computePointerOffset(ASTCtx, Ptr) - + computePointerOffset(ASTCtx, Ptr.expand().atIndex(0).narrow()); + else + ByteOffset = 0; + } } else ByteOffset = computePointerOffset(ASTCtx, Ptr); @@ -2579,9 +2549,11 @@ static bool interp__builtin_elementwise_maxmin(InterpState &S, CodePtr OpPC, return true; } -static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC, - const CallExpr *Call, - unsigned BuiltinID) { +static bool interp__builtin_ia32_pmul( + InterpState &S, CodePtr OpPC, const CallExpr *Call, + llvm::function_ref<APInt(const APSInt &, const APSInt &, const APSInt &, + const APSInt &)> + Fn) { assert(Call->getArg(0)->getType()->isVectorType() && Call->getArg(1)->getType()->isVectorType()); const Pointer &RHS = S.Stk.pop<Pointer>(); @@ -2590,35 +2562,23 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC, const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>(); PrimType ElemT = *S.getContext().classify(VT->getElementType()); - unsigned SourceLen = VT->getNumElements(); + unsigned NumElems = VT->getNumElements(); + const auto *DestVT = Call->getType()->castAs<VectorType>(); + PrimType DestElemT = *S.getContext().classify(DestVT->getElementType()); + bool DestUnsigned = Call->getType()->isUnsignedIntegerOrEnumerationType(); - PrimType DstElemT = *S.getContext().classify( - Call->getType()->castAs<VectorType>()->getElementType()); unsigned DstElem = 0; - for (unsigned I = 0; I != SourceLen; I += 2) { - APSInt Elem1; - APSInt Elem2; + for (unsigned I = 0; I != NumElems; I += 2) { + APSInt Result; INT_TYPE_SWITCH_NO_BOOL(ElemT, { - Elem1 = LHS.elem<T>(I).toAPSInt(); - Elem2 = RHS.elem<T>(I).toAPSInt(); + APSInt LoLHS = LHS.elem<T>(I).toAPSInt(); + APSInt HiLHS = LHS.elem<T>(I + 1).toAPSInt(); + APSInt LoRHS = RHS.elem<T>(I).toAPSInt(); + APSInt HiRHS = RHS.elem<T>(I + 1).toAPSInt(); + Result = APSInt(Fn(LoLHS, HiLHS, LoRHS, HiRHS), DestUnsigned); }); - APSInt Result; - switch (BuiltinID) { - case clang::X86::BI__builtin_ia32_pmuludq128: - case clang::X86::BI__builtin_ia32_pmuludq256: - case clang::X86::BI__builtin_ia32_pmuludq512: - Result = APSInt(llvm::APIntOps::muluExtended(Elem1, Elem2), - /*IsUnsigned=*/true); - break; - case clang::X86::BI__builtin_ia32_pmuldq128: - case clang::X86::BI__builtin_ia32_pmuldq256: - case clang::X86::BI__builtin_ia32_pmuldq512: - Result = APSInt(llvm::APIntOps::mulsExtended(Elem1, Elem2), - /*IsUnsigned=*/false); - break; - } - INT_TYPE_SWITCH_NO_BOOL(DstElemT, + INT_TYPE_SWITCH_NO_BOOL(DestElemT, { Dst.elem<T>(DstElem) = static_cast<T>(Result); }); ++DstElem; } @@ -3154,18 +3114,25 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case Builtin::BI__builtin_parity: case Builtin::BI__builtin_parityl: case Builtin::BI__builtin_parityll: - return interp__builtin_parity(S, OpPC, Frame, Call); - + return interp__builtin_elementwise_int_unaryop( + S, OpPC, Call, [](const APSInt &Val) -> APInt { + return APInt(Val.getBitWidth(), Val.popcount() % 2); + }); case Builtin::BI__builtin_clrsb: case Builtin::BI__builtin_clrsbl: case Builtin::BI__builtin_clrsbll: - return interp__builtin_clrsb(S, OpPC, Frame, Call); - + return interp__builtin_elementwise_int_unaryop( + S, OpPC, Call, [](const APSInt &Val) -> APInt { + return APInt(Val.getBitWidth(), + Val.getBitWidth() - Val.getSignificantBits()); + }); case Builtin::BI__builtin_bitreverse8: case Builtin::BI__builtin_bitreverse16: case Builtin::BI__builtin_bitreverse32: case Builtin::BI__builtin_bitreverse64: - return interp__builtin_bitreverse(S, OpPC, Frame, Call); + return interp__builtin_elementwise_int_unaryop( + S, OpPC, Call, + [](const APSInt &Val) -> APInt { return Val.reverseBits(); }); case Builtin::BI__builtin_classify_type: return interp__builtin_classify_type(S, OpPC, Frame, Call); @@ -3205,7 +3172,11 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case Builtin::BI__builtin_ffs: case Builtin::BI__builtin_ffsl: case Builtin::BI__builtin_ffsll: - return interp__builtin_ffs(S, OpPC, Frame, Call); + return interp__builtin_elementwise_int_unaryop( + S, OpPC, Call, [](const APSInt &Val) { + return APInt(Val.getBitWidth(), + Val.isZero() ? 0u : Val.countTrailingZeros() + 1u); + }); case Builtin::BIaddressof: case Builtin::BI__addressof: @@ -3490,6 +3461,30 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, return interp__builtin_elementwise_int_binop(S, OpPC, Call, llvm::APIntOps::avgCeilU); + case clang::X86::BI__builtin_ia32_pmaddubsw128: + case clang::X86::BI__builtin_ia32_pmaddubsw256: + case clang::X86::BI__builtin_ia32_pmaddubsw512: + return interp__builtin_ia32_pmul( + S, OpPC, Call, + [](const APSInt &LoLHS, const APSInt &HiLHS, const APSInt &LoRHS, + const APSInt &HiRHS) { + unsigned BitWidth = 2 * LoLHS.getBitWidth(); + return (LoLHS.zext(BitWidth) * LoRHS.sext(BitWidth)) + .sadd_sat((HiLHS.zext(BitWidth) * HiRHS.sext(BitWidth))); + }); + + case clang::X86::BI__builtin_ia32_pmaddwd128: + case clang::X86::BI__builtin_ia32_pmaddwd256: + case clang::X86::BI__builtin_ia32_pmaddwd512: + return interp__builtin_ia32_pmul( + S, OpPC, Call, + [](const APSInt &LoLHS, const APSInt &HiLHS, const APSInt &LoRHS, + const APSInt &HiRHS) { + unsigned BitWidth = 2 * LoLHS.getBitWidth(); + return (LoLHS.sext(BitWidth) * LoRHS.sext(BitWidth)) + + (HiLHS.sext(BitWidth) * HiRHS.sext(BitWidth)); + }); + case clang::X86::BI__builtin_ia32_pmulhuw128: case clang::X86::BI__builtin_ia32_pmulhuw256: case clang::X86::BI__builtin_ia32_pmulhuw512: @@ -3634,10 +3629,22 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case clang::X86::BI__builtin_ia32_pmuldq128: case clang::X86::BI__builtin_ia32_pmuldq256: case clang::X86::BI__builtin_ia32_pmuldq512: + return interp__builtin_ia32_pmul( + S, OpPC, Call, + [](const APSInt &LoLHS, const APSInt &HiLHS, const APSInt &LoRHS, + const APSInt &HiRHS) { + return llvm::APIntOps::mulsExtended(LoLHS, LoRHS); + }); + case clang::X86::BI__builtin_ia32_pmuludq128: case clang::X86::BI__builtin_ia32_pmuludq256: case clang::X86::BI__builtin_ia32_pmuludq512: - return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID); + return interp__builtin_ia32_pmul( + S, OpPC, Call, + [](const APSInt &LoLHS, const APSInt &HiLHS, const APSInt &LoRHS, + const APSInt &HiRHS) { + return llvm::APIntOps::muluExtended(LoLHS, LoRHS); + }); case Builtin::BI__builtin_elementwise_fma: return interp__builtin_elementwise_triop_fp( diff --git a/clang/lib/AST/DeclPrinter.cpp b/clang/lib/AST/DeclPrinter.cpp index 7001ade..7f3dcca 100644 --- a/clang/lib/AST/DeclPrinter.cpp +++ b/clang/lib/AST/DeclPrinter.cpp @@ -111,6 +111,7 @@ namespace { void VisitOMPCapturedExprDecl(OMPCapturedExprDecl *D); void VisitTemplateTypeParmDecl(const TemplateTypeParmDecl *TTP); void VisitNonTypeTemplateParmDecl(const NonTypeTemplateParmDecl *NTTP); + void VisitTemplateTemplateParmDecl(const TemplateTemplateParmDecl *); void VisitHLSLBufferDecl(HLSLBufferDecl *D); void VisitOpenACCDeclareDecl(OpenACCDeclareDecl *D); @@ -1189,8 +1190,7 @@ void DeclPrinter::printTemplateParameters(const TemplateParameterList *Params, } else if (auto NTTP = dyn_cast<NonTypeTemplateParmDecl>(Param)) { VisitNonTypeTemplateParmDecl(NTTP); } else if (auto TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) { - VisitTemplateDecl(TTPD); - // FIXME: print the default argument, if present. + VisitTemplateTemplateParmDecl(TTPD); } } @@ -1916,6 +1916,16 @@ void DeclPrinter::VisitNonTypeTemplateParmDecl( } } +void DeclPrinter::VisitTemplateTemplateParmDecl( + const TemplateTemplateParmDecl *TTPD) { + VisitTemplateDecl(TTPD); + if (TTPD->hasDefaultArgument() && !TTPD->defaultArgumentWasInherited()) { + Out << " = "; + TTPD->getDefaultArgument().getArgument().print(Policy, Out, + /*IncludeType=*/false); + } +} + void DeclPrinter::VisitOpenACCDeclareDecl(OpenACCDeclareDecl *D) { if (!D->isInvalidDecl()) { Out << "#pragma acc declare"; diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp index e5fba1b..c0be986 100644 --- a/clang/lib/AST/DeclTemplate.cpp +++ b/clang/lib/AST/DeclTemplate.cpp @@ -1653,57 +1653,65 @@ void TemplateParamObjectDecl::printAsInit(llvm::raw_ostream &OS, getValue().printPretty(OS, Policy, getType(), &getASTContext()); } -TemplateParameterList *clang::getReplacedTemplateParameterList(const Decl *D) { +std::tuple<NamedDecl *, TemplateArgument> +clang::getReplacedTemplateParameter(Decl *D, unsigned Index) { switch (D->getKind()) { - case Decl::Kind::CXXRecord: - return cast<CXXRecordDecl>(D) - ->getDescribedTemplate() - ->getTemplateParameters(); + case Decl::Kind::BuiltinTemplate: case Decl::Kind::ClassTemplate: - return cast<ClassTemplateDecl>(D)->getTemplateParameters(); + case Decl::Kind::Concept: + case Decl::Kind::FunctionTemplate: + case Decl::Kind::TemplateTemplateParm: + case Decl::Kind::TypeAliasTemplate: + case Decl::Kind::VarTemplate: + return {cast<TemplateDecl>(D)->getTemplateParameters()->getParam(Index), + {}}; case Decl::Kind::ClassTemplateSpecialization: { const auto *CTSD = cast<ClassTemplateSpecializationDecl>(D); auto P = CTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; if (const auto *CTPSD = dyn_cast<ClassTemplatePartialSpecializationDecl *>(P)) - return CTPSD->getTemplateParameters(); - return cast<ClassTemplateDecl *>(P)->getTemplateParameters(); + TPL = CTPSD->getTemplateParameters(); + else + TPL = cast<ClassTemplateDecl *>(P)->getTemplateParameters(); + return {TPL->getParam(Index), CTSD->getTemplateArgs()[Index]}; + } + case Decl::Kind::VarTemplateSpecialization: { + const auto *VTSD = cast<VarTemplateSpecializationDecl>(D); + auto P = VTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; + if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P)) + TPL = VTPSD->getTemplateParameters(); + else + TPL = cast<VarTemplateDecl *>(P)->getTemplateParameters(); + return {TPL->getParam(Index), VTSD->getTemplateArgs()[Index]}; } case Decl::Kind::ClassTemplatePartialSpecialization: - return cast<ClassTemplatePartialSpecializationDecl>(D) - ->getTemplateParameters(); - case Decl::Kind::TypeAliasTemplate: - return cast<TypeAliasTemplateDecl>(D)->getTemplateParameters(); - case Decl::Kind::BuiltinTemplate: - return cast<BuiltinTemplateDecl>(D)->getTemplateParameters(); + return {cast<ClassTemplatePartialSpecializationDecl>(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + case Decl::Kind::VarTemplatePartialSpecialization: + return {cast<VarTemplatePartialSpecializationDecl>(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + // This is used as the AssociatedDecl for placeholder type deduction. + case Decl::TemplateTypeParm: + return {cast<NamedDecl>(D), {}}; + // FIXME: Always use the template decl as the AssociatedDecl. + case Decl::Kind::CXXRecord: + return getReplacedTemplateParameter( + cast<CXXRecordDecl>(D)->getDescribedClassTemplate(), Index); case Decl::Kind::CXXDeductionGuide: case Decl::Kind::CXXConversion: case Decl::Kind::CXXConstructor: case Decl::Kind::CXXDestructor: case Decl::Kind::CXXMethod: case Decl::Kind::Function: - return cast<FunctionDecl>(D) - ->getTemplateSpecializationInfo() - ->getTemplate() - ->getTemplateParameters(); - case Decl::Kind::FunctionTemplate: - return cast<FunctionTemplateDecl>(D)->getTemplateParameters(); - case Decl::Kind::VarTemplate: - return cast<VarTemplateDecl>(D)->getTemplateParameters(); - case Decl::Kind::VarTemplateSpecialization: { - const auto *VTSD = cast<VarTemplateSpecializationDecl>(D); - auto P = VTSD->getSpecializedTemplateOrPartial(); - if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P)) - return VTPSD->getTemplateParameters(); - return cast<VarTemplateDecl *>(P)->getTemplateParameters(); - } - case Decl::Kind::VarTemplatePartialSpecialization: - return cast<VarTemplatePartialSpecializationDecl>(D) - ->getTemplateParameters(); - case Decl::Kind::TemplateTemplateParm: - return cast<TemplateTemplateParmDecl>(D)->getTemplateParameters(); - case Decl::Kind::Concept: - return cast<ConceptDecl>(D)->getTemplateParameters(); + return getReplacedTemplateParameter( + cast<FunctionDecl>(D)->getTemplateSpecializationInfo()->getTemplate(), + Index); default: llvm_unreachable("Unhandled templated declaration kind"); } diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp index 95de6a8..c7f0ff0 100644 --- a/clang/lib/AST/ExprCXX.cpp +++ b/clang/lib/AST/ExprCXX.cpp @@ -1727,7 +1727,7 @@ SizeOfPackExpr *SizeOfPackExpr::CreateDeserialized(ASTContext &Context, NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const { return cast<NonTypeTemplateParmDecl>( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } PackIndexingExpr *PackIndexingExpr::Create( @@ -1793,7 +1793,7 @@ SubstNonTypeTemplateParmPackExpr::SubstNonTypeTemplateParmPackExpr( NonTypeTemplateParmDecl * SubstNonTypeTemplateParmPackExpr::getParameterPack() const { return cast<NonTypeTemplateParmDecl>( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } TemplateArgument SubstNonTypeTemplateParmPackExpr::getArgumentPack() const { diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 618e163..35a866e 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -11778,6 +11778,54 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) { case clang::X86::BI__builtin_ia32_pavgw512: return EvaluateBinOpExpr(llvm::APIntOps::avgCeilU); + case clang::X86::BI__builtin_ia32_pmaddubsw128: + case clang::X86::BI__builtin_ia32_pmaddubsw256: + case clang::X86::BI__builtin_ia32_pmaddubsw512: + case clang::X86::BI__builtin_ia32_pmaddwd128: + case clang::X86::BI__builtin_ia32_pmaddwd256: + case clang::X86::BI__builtin_ia32_pmaddwd512: { + APValue SourceLHS, SourceRHS; + if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) || + !EvaluateAsRValue(Info, E->getArg(1), SourceRHS)) + return false; + + auto *DestTy = E->getType()->castAs<VectorType>(); + QualType DestEltTy = DestTy->getElementType(); + unsigned SourceLen = SourceLHS.getVectorLength(); + bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType(); + SmallVector<APValue, 4> ResultElements; + ResultElements.reserve(SourceLen / 2); + + for (unsigned EltNum = 0; EltNum < SourceLen; EltNum += 2) { + const APSInt &LoLHS = SourceLHS.getVectorElt(EltNum).getInt(); + const APSInt &HiLHS = SourceLHS.getVectorElt(EltNum + 1).getInt(); + const APSInt &LoRHS = SourceRHS.getVectorElt(EltNum).getInt(); + const APSInt &HiRHS = SourceRHS.getVectorElt(EltNum + 1).getInt(); + unsigned BitWidth = 2 * LoLHS.getBitWidth(); + + switch (E->getBuiltinCallee()) { + case clang::X86::BI__builtin_ia32_pmaddubsw128: + case clang::X86::BI__builtin_ia32_pmaddubsw256: + case clang::X86::BI__builtin_ia32_pmaddubsw512: + ResultElements.push_back(APValue( + APSInt((LoLHS.zext(BitWidth) * LoRHS.sext(BitWidth)) + .sadd_sat((HiLHS.zext(BitWidth) * HiRHS.sext(BitWidth))), + DestUnsigned))); + break; + case clang::X86::BI__builtin_ia32_pmaddwd128: + case clang::X86::BI__builtin_ia32_pmaddwd256: + case clang::X86::BI__builtin_ia32_pmaddwd512: + ResultElements.push_back( + APValue(APSInt((LoLHS.sext(BitWidth) * LoRHS.sext(BitWidth)) + + (HiLHS.sext(BitWidth) * HiRHS.sext(BitWidth)), + DestUnsigned))); + break; + } + } + + return Success(APValue(ResultElements.data(), ResultElements.size()), E); + } + case clang::X86::BI__builtin_ia32_pmulhuw128: case clang::X86::BI__builtin_ia32_pmulhuw256: case clang::X86::BI__builtin_ia32_pmulhuw512: diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index f3b5478..3cd033e 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -2769,10 +2769,19 @@ void OpenACCClauseProfiler::VisitReductionClause( for (auto &Recipe : Clause.getRecipes()) { Profiler.VisitDecl(Recipe.AllocaDecl); + // TODO: OpenACC: Make sure we remember to update this when we figure out // what we're adding for the operation recipe, in the meantime, a static // assert will make sure we don't add something. - static_assert(sizeof(OpenACCReductionRecipe) == sizeof(int *)); + static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) == + 3 * sizeof(int *)); + for (auto &CombinerRecipe : Recipe.CombinerRecipes) { + if (CombinerRecipe.Op) { + Profiler.VisitDecl(CombinerRecipe.LHS); + Profiler.VisitDecl(CombinerRecipe.RHS); + Profiler.VisitStmt(CombinerRecipe.Op); + } + } } } diff --git a/clang/lib/AST/TemplateName.cpp b/clang/lib/AST/TemplateName.cpp index 2b8044e..797a354 100644 --- a/clang/lib/AST/TemplateName.cpp +++ b/clang/lib/AST/TemplateName.cpp @@ -64,16 +64,14 @@ SubstTemplateTemplateParmPackStorage::getArgumentPack() const { TemplateTemplateParmDecl * SubstTemplateTemplateParmPackStorage::getParameterPack() const { - return cast<TemplateTemplateParmDecl>( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast<TemplateTemplateParmDecl>(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } TemplateTemplateParmDecl * SubstTemplateTemplateParmStorage::getParameter() const { - return cast<TemplateTemplateParmDecl>( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast<TemplateTemplateParmDecl>(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } void SubstTemplateTemplateParmStorage::Profile(llvm::FoldingSetNodeID &ID) { diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 9794314..ee7a68e 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4436,14 +4436,6 @@ IdentifierInfo *TemplateTypeParmType::getIdentifier() const { return isCanonicalUnqualified() ? nullptr : getDecl()->getIdentifier(); } -static const TemplateTypeParmDecl *getReplacedParameter(Decl *D, - unsigned Index) { - if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(D)) - return TTP; - return cast<TemplateTypeParmDecl>( - getReplacedTemplateParameterList(D)->getParam(Index)); -} - SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, @@ -4466,7 +4458,8 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, const TemplateTypeParmDecl * SubstTemplateTypeParmType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast<TemplateTypeParmDecl>(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } void SubstTemplateTypeParmType::Profile(llvm::FoldingSetNodeID &ID, @@ -4532,7 +4525,8 @@ bool SubstTemplateTypeParmPackType::getFinal() const { const TemplateTypeParmDecl * SubstTemplateTypeParmPackType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast<TemplateTypeParmDecl>(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } IdentifierInfo *SubstTemplateTypeParmPackType::getIdentifier() const { diff --git a/clang/lib/ASTMatchers/CMakeLists.txt b/clang/lib/ASTMatchers/CMakeLists.txt index 7769fd6..29ad27df 100644 --- a/clang/lib/ASTMatchers/CMakeLists.txt +++ b/clang/lib/ASTMatchers/CMakeLists.txt @@ -8,7 +8,6 @@ set(LLVM_LINK_COMPONENTS add_clang_library(clangASTMatchers ASTMatchFinder.cpp ASTMatchersInternal.cpp - GtestMatchers.cpp LowLevelHelpers.cpp LINK_LIBS diff --git a/clang/lib/ASTMatchers/GtestMatchers.cpp b/clang/lib/ASTMatchers/GtestMatchers.cpp deleted file mode 100644 index 7c135bb..0000000 --- a/clang/lib/ASTMatchers/GtestMatchers.cpp +++ /dev/null @@ -1,228 +0,0 @@ -//===- GtestMatchers.cpp - AST Matchers for Gtest ---------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements several matchers for popular gtest macros. In general, -// AST matchers cannot match calls to macros. However, we can simulate such -// matches if the macro definition has identifiable elements that themselves can -// be matched. In that case, we can match on those elements and then check that -// the match occurs within an expansion of the desired macro. The more uncommon -// the identified elements, the more efficient this process will be. -// -//===----------------------------------------------------------------------===// - -#include "clang/ASTMatchers/GtestMatchers.h" -#include "llvm/ADT/StringRef.h" - -namespace clang { -namespace ast_matchers { -namespace { - -enum class MacroType { - Expect, - Assert, - On, -}; - -} // namespace - -static DeclarationMatcher getComparisonDecl(GtestCmp Cmp) { - switch (Cmp) { - case GtestCmp::Eq: - return cxxMethodDecl(hasName("Compare"), - ofClass(cxxRecordDecl(isSameOrDerivedFrom( - hasName("::testing::internal::EqHelper"))))); - case GtestCmp::Ne: - return functionDecl(hasName("::testing::internal::CmpHelperNE")); - case GtestCmp::Ge: - return functionDecl(hasName("::testing::internal::CmpHelperGE")); - case GtestCmp::Gt: - return functionDecl(hasName("::testing::internal::CmpHelperGT")); - case GtestCmp::Le: - return functionDecl(hasName("::testing::internal::CmpHelperLE")); - case GtestCmp::Lt: - return functionDecl(hasName("::testing::internal::CmpHelperLT")); - } - llvm_unreachable("Unhandled GtestCmp enum"); -} - -static llvm::StringRef getMacroTypeName(MacroType Macro) { - switch (Macro) { - case MacroType::Expect: - return "EXPECT"; - case MacroType::Assert: - return "ASSERT"; - case MacroType::On: - return "ON"; - } - llvm_unreachable("Unhandled MacroType enum"); -} - -static llvm::StringRef getComparisonTypeName(GtestCmp Cmp) { - switch (Cmp) { - case GtestCmp::Eq: - return "EQ"; - case GtestCmp::Ne: - return "NE"; - case GtestCmp::Ge: - return "GE"; - case GtestCmp::Gt: - return "GT"; - case GtestCmp::Le: - return "LE"; - case GtestCmp::Lt: - return "LT"; - } - llvm_unreachable("Unhandled GtestCmp enum"); -} - -static std::string getMacroName(MacroType Macro, GtestCmp Cmp) { - return (getMacroTypeName(Macro) + "_" + getComparisonTypeName(Cmp)).str(); -} - -static std::string getMacroName(MacroType Macro, llvm::StringRef Operation) { - return (getMacroTypeName(Macro) + "_" + Operation).str(); -} - -// Under the hood, ON_CALL is expanded to a call to `InternalDefaultActionSetAt` -// to set a default action spec to the underlying function mocker, while -// EXPECT_CALL is expanded to a call to `InternalExpectedAt` to set a new -// expectation spec. -static llvm::StringRef getSpecSetterName(MacroType Macro) { - switch (Macro) { - case MacroType::On: - return "InternalDefaultActionSetAt"; - case MacroType::Expect: - return "InternalExpectedAt"; - default: - llvm_unreachable("Unhandled MacroType enum"); - } - llvm_unreachable("Unhandled MacroType enum"); -} - -// In general, AST matchers cannot match calls to macros. However, we can -// simulate such matches if the macro definition has identifiable elements that -// themselves can be matched. In that case, we can match on those elements and -// then check that the match occurs within an expansion of the desired -// macro. The more uncommon the identified elements, the more efficient this -// process will be. -// -// We use this approach to implement the derived matchers gtestAssert and -// gtestExpect. -static internal::BindableMatcher<Stmt> -gtestComparisonInternal(MacroType Macro, GtestCmp Cmp, StatementMatcher Left, - StatementMatcher Right) { - return callExpr(isExpandedFromMacro(getMacroName(Macro, Cmp)), - callee(getComparisonDecl(Cmp)), hasArgument(2, Left), - hasArgument(3, Right)); -} - -static internal::BindableMatcher<Stmt> -gtestThatInternal(MacroType Macro, StatementMatcher Actual, - StatementMatcher Matcher) { - return cxxOperatorCallExpr( - isExpandedFromMacro(getMacroName(Macro, "THAT")), - hasOverloadedOperatorName("()"), hasArgument(2, Actual), - hasArgument( - 0, expr(hasType(classTemplateSpecializationDecl(hasName( - "::testing::internal::PredicateFormatterFromMatcher"))), - ignoringImplicit( - callExpr(callee(functionDecl(hasName( - "::testing::internal::" - "MakePredicateFormatterFromMatcher"))), - hasArgument(0, ignoringImplicit(Matcher))))))); -} - -static internal::BindableMatcher<Stmt> -gtestCallInternal(MacroType Macro, StatementMatcher MockCall, MockArgs Args) { - // A ON_CALL or EXPECT_CALL macro expands to different AST structures - // depending on whether the mock method has arguments or not. - switch (Args) { - // For example, - // `ON_CALL(mock, TwoParamMethod)` is expanded to - // `mock.gmock_TwoArgsMethod(WithoutMatchers(), - // nullptr).InternalDefaultActionSetAt(...)`. - // EXPECT_CALL is the same except - // that it calls `InternalExpectedAt` instead of `InternalDefaultActionSetAt` - // in the end. - case MockArgs::None: - return cxxMemberCallExpr( - isExpandedFromMacro(getMacroName(Macro, "CALL")), - callee(functionDecl(hasName(getSpecSetterName(Macro)))), - onImplicitObjectArgument(ignoringImplicit(MockCall))); - // For example, - // `ON_CALL(mock, TwoParamMethod(m1, m2))` is expanded to - // `mock.gmock_TwoParamMethod(m1,m2)(WithoutMatchers(), - // nullptr).InternalDefaultActionSetAt(...)`. - // EXPECT_CALL is the same except that it calls `InternalExpectedAt` instead - // of `InternalDefaultActionSetAt` in the end. - case MockArgs::Some: - return cxxMemberCallExpr( - isExpandedFromMacro(getMacroName(Macro, "CALL")), - callee(functionDecl(hasName(getSpecSetterName(Macro)))), - onImplicitObjectArgument(ignoringImplicit(cxxOperatorCallExpr( - hasOverloadedOperatorName("()"), argumentCountIs(3), - hasArgument(0, ignoringImplicit(MockCall)))))); - } - llvm_unreachable("Unhandled MockArgs enum"); -} - -static internal::BindableMatcher<Stmt> -gtestCallInternal(MacroType Macro, StatementMatcher MockObject, - llvm::StringRef MockMethodName, MockArgs Args) { - return gtestCallInternal( - Macro, - cxxMemberCallExpr( - onImplicitObjectArgument(MockObject), - callee(functionDecl(hasName(("gmock_" + MockMethodName).str())))), - Args); -} - -internal::BindableMatcher<Stmt> gtestAssert(GtestCmp Cmp, StatementMatcher Left, - StatementMatcher Right) { - return gtestComparisonInternal(MacroType::Assert, Cmp, Left, Right); -} - -internal::BindableMatcher<Stmt> gtestExpect(GtestCmp Cmp, StatementMatcher Left, - StatementMatcher Right) { - return gtestComparisonInternal(MacroType::Expect, Cmp, Left, Right); -} - -internal::BindableMatcher<Stmt> gtestAssertThat(StatementMatcher Actual, - StatementMatcher Matcher) { - return gtestThatInternal(MacroType::Assert, Actual, Matcher); -} - -internal::BindableMatcher<Stmt> gtestExpectThat(StatementMatcher Actual, - StatementMatcher Matcher) { - return gtestThatInternal(MacroType::Expect, Actual, Matcher); -} - -internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockObject, - llvm::StringRef MockMethodName, - MockArgs Args) { - return gtestCallInternal(MacroType::On, MockObject, MockMethodName, Args); -} - -internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockCall, - MockArgs Args) { - return gtestCallInternal(MacroType::On, MockCall, Args); -} - -internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockObject, - llvm::StringRef MockMethodName, - MockArgs Args) { - return gtestCallInternal(MacroType::Expect, MockObject, MockMethodName, Args); -} - -internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockCall, - MockArgs Args) { - return gtestCallInternal(MacroType::Expect, MockCall, Args); -} - -} // end namespace ast_matchers -} // end namespace clang diff --git a/clang/lib/Analysis/ExprMutationAnalyzer.cpp b/clang/lib/Analysis/ExprMutationAnalyzer.cpp index 3fcd348..1e376da 100644 --- a/clang/lib/Analysis/ExprMutationAnalyzer.cpp +++ b/clang/lib/Analysis/ExprMutationAnalyzer.cpp @@ -755,22 +755,23 @@ ExprMutationAnalyzer::Analyzer::findPointeeMemberMutation(const Expr *Exp) { const Stmt * ExprMutationAnalyzer::Analyzer::findPointeeToNonConst(const Expr *Exp) { - const auto NonConstPointerOrDependentType = - type(anyOf(nonConstPointerType(), isDependentType())); + const auto NonConstPointerOrNonConstRefOrDependentType = type( + anyOf(nonConstPointerType(), nonConstReferenceType(), isDependentType())); // assign const auto InitToNonConst = - varDecl(hasType(NonConstPointerOrDependentType), + varDecl(hasType(NonConstPointerOrNonConstRefOrDependentType), hasInitializer(expr(canResolveToExprPointee(Exp)).bind("stmt"))); - const auto AssignToNonConst = - binaryOperation(hasOperatorName("="), - hasLHS(expr(hasType(NonConstPointerOrDependentType))), - hasRHS(canResolveToExprPointee(Exp))); + const auto AssignToNonConst = binaryOperation( + hasOperatorName("="), + hasLHS(expr(hasType(NonConstPointerOrNonConstRefOrDependentType))), + hasRHS(canResolveToExprPointee(Exp))); // arguments like const auto ArgOfInstantiationDependent = allOf( hasAnyArgument(canResolveToExprPointee(Exp)), isInstantiationDependent()); - const auto ArgOfNonConstParameter = forEachArgumentWithParamType( - canResolveToExprPointee(Exp), NonConstPointerOrDependentType); + const auto ArgOfNonConstParameter = + forEachArgumentWithParamType(canResolveToExprPointee(Exp), + NonConstPointerOrNonConstRefOrDependentType); const auto CallLikeMatcher = anyOf(ArgOfNonConstParameter, ArgOfInstantiationDependent); const auto PassAsNonConstArg = @@ -779,9 +780,9 @@ ExprMutationAnalyzer::Analyzer::findPointeeToNonConst(const Expr *Exp) { parenListExpr(has(canResolveToExprPointee(Exp))), initListExpr(hasAnyInit(canResolveToExprPointee(Exp))))); // cast - const auto CastToNonConst = - explicitCastExpr(hasSourceExpression(canResolveToExprPointee(Exp)), - hasDestinationType(NonConstPointerOrDependentType)); + const auto CastToNonConst = explicitCastExpr( + hasSourceExpression(canResolveToExprPointee(Exp)), + hasDestinationType(NonConstPointerOrNonConstRefOrDependentType)); // capture // FIXME: false positive if the pointee does not change in lambda diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp index c18b8fb..6196ec3 100644 --- a/clang/lib/Analysis/LifetimeSafety.cpp +++ b/clang/lib/Analysis/LifetimeSafety.cpp @@ -19,12 +19,13 @@ #include "llvm/ADT/ImmutableMap.h" #include "llvm/ADT/ImmutableSet.h" #include "llvm/ADT/PointerUnion.h" -#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/TimeProfiler.h" #include <cstdint> #include <memory> +#include <optional> namespace clang::lifetimes { namespace internal { @@ -872,22 +873,19 @@ public: InStates[Start] = D.getInitialState(); W.enqueueBlock(Start); - llvm::SmallBitVector Visited(Cfg.getNumBlockIDs() + 1); - while (const CFGBlock *B = W.dequeue()) { - Lattice StateIn = getInState(B); + Lattice StateIn = *getInState(B); Lattice StateOut = transferBlock(B, StateIn); OutStates[B] = StateOut; - Visited.set(B->getBlockID()); for (const CFGBlock *AdjacentB : isForward() ? B->succs() : B->preds()) { if (!AdjacentB) continue; - Lattice OldInState = getInState(AdjacentB); - Lattice NewInState = D.join(OldInState, StateOut); + std::optional<Lattice> OldInState = getInState(AdjacentB); + Lattice NewInState = + !OldInState ? StateOut : D.join(*OldInState, StateOut); // Enqueue the adjacent block if its in-state has changed or if we have - // never visited it. - if (!Visited.test(AdjacentB->getBlockID()) || - NewInState != OldInState) { + // never seen it. + if (!OldInState || NewInState != *OldInState) { InStates[AdjacentB] = NewInState; W.enqueueBlock(AdjacentB); } @@ -898,7 +896,12 @@ public: protected: Lattice getState(ProgramPoint P) const { return PerPointStates.lookup(P); } - Lattice getInState(const CFGBlock *B) const { return InStates.lookup(B); } + std::optional<Lattice> getInState(const CFGBlock *B) const { + auto It = InStates.find(B); + if (It == InStates.end()) + return std::nullopt; + return It->second; + } Lattice getOutState(const CFGBlock *B) const { return OutStates.lookup(B); } @@ -974,19 +977,21 @@ static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A, return A; } -/// Checks if set A is a subset of set B. -template <typename T> -static bool isSubsetOf(const llvm::ImmutableSet<T> &A, - const llvm::ImmutableSet<T> &B) { - // Empty set is a subset of all sets. - if (A.isEmpty()) - return true; - - for (const T &Elem : A) - if (!B.contains(Elem)) - return false; - return true; -} +/// Describes the strategy for joining two `ImmutableMap` instances, primarily +/// differing in how they handle keys that are unique to one of the maps. +/// +/// A `Symmetric` join is universally correct, while an `Asymmetric` join +/// serves as a performance optimization. The latter is applicable only when the +/// join operation possesses a left identity element, allowing for a more +/// efficient, one-sided merge. +enum class JoinKind { + /// A symmetric join applies the `JoinValues` operation to keys unique to + /// either map, ensuring that values from both maps contribute to the result. + Symmetric, + /// An asymmetric join preserves keys unique to the first map as-is, while + /// applying the `JoinValues` operation only to keys unique to the second map. + Asymmetric, +}; /// Computes the key-wise union of two ImmutableMaps. // TODO(opt): This key-wise join is a performance bottleneck. A more @@ -994,22 +999,29 @@ static bool isSubsetOf(const llvm::ImmutableSet<T> &A, // instead of the current AVL-tree-based ImmutableMap. template <typename K, typename V, typename Joiner> static llvm::ImmutableMap<K, V> -join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B, - typename llvm::ImmutableMap<K, V>::Factory &F, Joiner JoinValues) { +join(const llvm::ImmutableMap<K, V> &A, const llvm::ImmutableMap<K, V> &B, + typename llvm::ImmutableMap<K, V>::Factory &F, Joiner JoinValues, + JoinKind Kind) { if (A.getHeight() < B.getHeight()) - std::swap(A, B); + return join(B, A, F, JoinValues, Kind); // For each element in B, join it with the corresponding element in A // (or with an empty value if it doesn't exist in A). + llvm::ImmutableMap<K, V> Res = A; for (const auto &Entry : B) { const K &Key = Entry.first; const V &ValB = Entry.second; - if (const V *ValA = A.lookup(Key)) - A = F.add(A, Key, JoinValues(*ValA, ValB)); - else - A = F.add(A, Key, ValB); + Res = F.add(Res, Key, JoinValues(A.lookup(Key), &ValB)); + } + if (Kind == JoinKind::Symmetric) { + for (const auto &Entry : A) { + const K &Key = Entry.first; + const V &ValA = Entry.second; + if (!B.contains(Key)) + Res = F.add(Res, Key, JoinValues(&ValA, nullptr)); + } } - return A; + return Res; } } // namespace utils @@ -1017,19 +1029,6 @@ join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B, // Loan Propagation Analysis // ========================================================================= // -using OriginLoanMap = llvm::ImmutableMap<OriginID, LoanSet>; -using ExpiredLoanMap = llvm::ImmutableMap<LoanID, const ExpireFact *>; - -/// An object to hold the factories for immutable collections, ensuring -/// that all created states share the same underlying memory management. -struct LifetimeFactory { - llvm::BumpPtrAllocator Allocator; - OriginLoanMap::Factory OriginMapFactory{Allocator, /*canonicalize=*/false}; - LoanSet::Factory LoanSetFactory{Allocator, /*canonicalize=*/false}; - ExpiredLoanMap::Factory ExpiredLoanMapFactory{Allocator, - /*canonicalize=*/false}; -}; - /// Represents the dataflow lattice for loan propagation. /// /// This lattice tracks which loans each origin may hold at a given program @@ -1073,10 +1072,10 @@ class LoanPropagationAnalysis public: LoanPropagationAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, - LifetimeFactory &LFactory) - : DataflowAnalysis(C, AC, F), - OriginLoanMapFactory(LFactory.OriginMapFactory), - LoanSetFactory(LFactory.LoanSetFactory) {} + OriginLoanMap::Factory &OriginLoanMapFactory, + LoanSet::Factory &LoanSetFactory) + : DataflowAnalysis(C, AC, F), OriginLoanMapFactory(OriginLoanMapFactory), + LoanSetFactory(LoanSetFactory) {} using Base::transfer; @@ -1087,11 +1086,19 @@ public: /// Merges two lattices by taking the union of loans for each origin. // TODO(opt): Keep the state small by removing origins which become dead. Lattice join(Lattice A, Lattice B) { - OriginLoanMap JoinedOrigins = - utils::join(A.Origins, B.Origins, OriginLoanMapFactory, - [&](LoanSet S1, LoanSet S2) { - return utils::join(S1, S2, LoanSetFactory); - }); + OriginLoanMap JoinedOrigins = utils::join( + A.Origins, B.Origins, OriginLoanMapFactory, + [&](const LoanSet *S1, const LoanSet *S2) { + assert((S1 || S2) && "unexpectedly merging 2 empty sets"); + if (!S1) + return *S2; + if (!S2) + return *S1; + return utils::join(*S1, *S2, LoanSetFactory); + }, + // Asymmetric join is a performance win. For origins present only on one + // branch, the loan set can be carried over as-is. + utils::JoinKind::Asymmetric); return Lattice(JoinedOrigins); } @@ -1120,12 +1127,12 @@ public: OriginLoanMapFactory.add(In.Origins, DestOID, MergedLoans)); } - LoanSet getLoans(OriginID OID, ProgramPoint P) { + LoanSet getLoans(OriginID OID, ProgramPoint P) const { return getLoans(getState(P), OID); } private: - LoanSet getLoans(Lattice L, OriginID OID) { + LoanSet getLoans(Lattice L, OriginID OID) const { if (auto *Loans = L.Origins.lookup(OID)) return *Loans; return LoanSetFactory.getEmptySet(); @@ -1133,96 +1140,195 @@ private: }; // ========================================================================= // -// Expired Loans Analysis +// Live Origins Analysis +// ========================================================================= // +// +// A backward dataflow analysis that determines which origins are "live" at each +// program point. An origin is "live" at a program point if there's a potential +// future use of the pointer it represents. Liveness is "generated" by a read of +// origin's loan set (e.g., a `UseFact`) and is "killed" (i.e., it stops being +// live) when its loan set is overwritten (e.g. a OriginFlow killing the +// destination origin). +// +// This information is used for detecting use-after-free errors, as it allows us +// to check if a live origin holds a loan to an object that has already expired. // ========================================================================= // -/// The dataflow lattice for tracking the set of expired loans. -struct ExpiredLattice { - /// Map from an expired `LoanID` to the `ExpireFact` that made it expire. - ExpiredLoanMap Expired; +/// Information about why an origin is live at a program point. +struct LivenessInfo { + /// The use that makes the origin live. If liveness is propagated from + /// multiple uses along different paths, this will point to the use appearing + /// earlier in the translation unit. + /// This is 'null' when the origin is not live. + const UseFact *CausingUseFact; + /// The kind of liveness of the origin. + /// `Must`: The origin is live on all control-flow paths from the current + /// point to the function's exit (i.e. the current point is dominated by a set + /// of uses). + /// `Maybe`: indicates it is live on some but not all paths. + /// + /// This determines the diagnostic's confidence level. + /// `Must`-be-alive at expiration implies a definite use-after-free, + /// while `Maybe`-be-alive suggests a potential one on some paths. + LivenessKind Kind; + + LivenessInfo() : CausingUseFact(nullptr), Kind(LivenessKind::Dead) {} + LivenessInfo(const UseFact *UF, LivenessKind K) + : CausingUseFact(UF), Kind(K) {} + + bool operator==(const LivenessInfo &Other) const { + return CausingUseFact == Other.CausingUseFact && Kind == Other.Kind; + } + bool operator!=(const LivenessInfo &Other) const { return !(*this == Other); } + + void Profile(llvm::FoldingSetNodeID &IDBuilder) const { + IDBuilder.AddPointer(CausingUseFact); + IDBuilder.Add(Kind); + } +}; + +using LivenessMap = llvm::ImmutableMap<OriginID, LivenessInfo>; - ExpiredLattice() : Expired(nullptr) {}; - explicit ExpiredLattice(ExpiredLoanMap M) : Expired(M) {} +/// The dataflow lattice for origin liveness analysis. +/// It tracks which origins are live, why they're live (which UseFact), +/// and the confidence level of that liveness. +struct LivenessLattice { + LivenessMap LiveOrigins; - bool operator==(const ExpiredLattice &Other) const { - return Expired == Other.Expired; + LivenessLattice() : LiveOrigins(nullptr) {}; + + explicit LivenessLattice(LivenessMap L) : LiveOrigins(L) {} + + bool operator==(const LivenessLattice &Other) const { + return LiveOrigins == Other.LiveOrigins; } - bool operator!=(const ExpiredLattice &Other) const { + + bool operator!=(const LivenessLattice &Other) const { return !(*this == Other); } - void dump(llvm::raw_ostream &OS) const { - OS << "ExpiredLattice State:\n"; - if (Expired.isEmpty()) + void dump(llvm::raw_ostream &OS, const OriginManager &OM) const { + if (LiveOrigins.isEmpty()) OS << " <empty>\n"; - for (const auto &[ID, _] : Expired) - OS << " Loan " << ID << " is expired\n"; + for (const auto &Entry : LiveOrigins) { + OriginID OID = Entry.first; + const LivenessInfo &Info = Entry.second; + OS << " "; + OM.dump(OID, OS); + OS << " is "; + switch (Info.Kind) { + case LivenessKind::Must: + OS << "definitely"; + break; + case LivenessKind::Maybe: + OS << "maybe"; + break; + case LivenessKind::Dead: + llvm_unreachable("liveness kind of live origins should not be dead."); + } + OS << " live at this point\n"; + } } }; -/// The analysis that tracks which loans have expired. -class ExpiredLoansAnalysis - : public DataflowAnalysis<ExpiredLoansAnalysis, ExpiredLattice, - Direction::Forward> { - - ExpiredLoanMap::Factory &Factory; +/// The analysis that tracks which origins are live, with granular information +/// about the causing use fact and confidence level. This is a backward +/// analysis. +class LiveOriginAnalysis + : public DataflowAnalysis<LiveOriginAnalysis, LivenessLattice, + Direction::Backward> { + FactManager &FactMgr; + LivenessMap::Factory &Factory; public: - ExpiredLoansAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, - LifetimeFactory &Factory) - : DataflowAnalysis(C, AC, F), Factory(Factory.ExpiredLoanMapFactory) {} - - using Base::transfer; + LiveOriginAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, + LivenessMap::Factory &SF) + : DataflowAnalysis(C, AC, F), FactMgr(F), Factory(SF) {} + using DataflowAnalysis<LiveOriginAnalysis, Lattice, + Direction::Backward>::transfer; - StringRef getAnalysisName() const { return "ExpiredLoans"; } + StringRef getAnalysisName() const { return "LiveOrigins"; } Lattice getInitialState() { return Lattice(Factory.getEmptyMap()); } - /// Merges two lattices by taking the union of the two expired loans. - Lattice join(Lattice L1, Lattice L2) { - return Lattice( - utils::join(L1.Expired, L2.Expired, Factory, - // Take the last expiry fact to make this hermetic. - [](const ExpireFact *F1, const ExpireFact *F2) { - return F1->getExpiryLoc() > F2->getExpiryLoc() ? F1 : F2; - })); - } - - Lattice transfer(Lattice In, const ExpireFact &F) { - return Lattice(Factory.add(In.Expired, F.getLoanID(), &F)); - } - - // Removes the loan from the set of expired loans. - // - // When a loan is re-issued (e.g., in a loop), it is no longer considered - // expired. A loan can be in the expired set at the point of issue due to - // the dataflow state from a previous loop iteration being propagated along - // a backedge in the CFG. - // - // Note: This has a subtle false-negative though where a loan from previous - // iteration is not overwritten by a reissue. This needs careful tracking - // of loans "across iterations" which can be considered for future - // enhancements. - // - // void foo(int safe) { - // int* p = &safe; - // int* q = &safe; - // while (condition()) { - // int x = 1; - // p = &x; // A loan to 'x' is issued to 'p' in every iteration. - // if (condition()) { - // q = p; - // } - // (void)*p; // OK — 'p' points to 'x' from new iteration. - // (void)*q; // UaF - 'q' still points to 'x' from previous iteration - // // which is now destroyed. - // } - // } - Lattice transfer(Lattice In, const IssueFact &F) { - return Lattice(Factory.remove(In.Expired, F.getLoanID())); + /// Merges two lattices by combining liveness information. + /// When the same origin has different confidence levels, we take the lower + /// one. + Lattice join(Lattice L1, Lattice L2) const { + LivenessMap Merged = L1.LiveOrigins; + // Take the earliest UseFact to make the join hermetic and commutative. + auto CombineUseFact = [](const UseFact &A, + const UseFact &B) -> const UseFact * { + return A.getUseExpr()->getExprLoc() < B.getUseExpr()->getExprLoc() ? &A + : &B; + }; + auto CombineLivenessKind = [](LivenessKind K1, + LivenessKind K2) -> LivenessKind { + assert(K1 != LivenessKind::Dead && "LivenessKind should not be dead."); + assert(K2 != LivenessKind::Dead && "LivenessKind should not be dead."); + // Only return "Must" if both paths are "Must", otherwise Maybe. + if (K1 == LivenessKind::Must && K2 == LivenessKind::Must) + return LivenessKind::Must; + return LivenessKind::Maybe; + }; + auto CombineLivenessInfo = [&](const LivenessInfo *L1, + const LivenessInfo *L2) -> LivenessInfo { + assert((L1 || L2) && "unexpectedly merging 2 empty sets"); + if (!L1) + return LivenessInfo(L2->CausingUseFact, LivenessKind::Maybe); + if (!L2) + return LivenessInfo(L1->CausingUseFact, LivenessKind::Maybe); + return LivenessInfo( + CombineUseFact(*L1->CausingUseFact, *L2->CausingUseFact), + CombineLivenessKind(L1->Kind, L2->Kind)); + }; + return Lattice(utils::join( + L1.LiveOrigins, L2.LiveOrigins, Factory, CombineLivenessInfo, + // A symmetric join is required here. If an origin is live on one + // branch but not the other, its confidence must be demoted to `Maybe`. + utils::JoinKind::Symmetric)); + } + + /// A read operation makes the origin live with definite confidence, as it + /// dominates this program point. A write operation kills the liveness of + /// the origin since it overwrites the value. + Lattice transfer(Lattice In, const UseFact &UF) { + OriginID OID = UF.getUsedOrigin(FactMgr.getOriginMgr()); + // Write kills liveness. + if (UF.isWritten()) + return Lattice(Factory.remove(In.LiveOrigins, OID)); + // Read makes origin live with definite confidence (dominates this point). + return Lattice(Factory.add(In.LiveOrigins, OID, + LivenessInfo(&UF, LivenessKind::Must))); + } + + /// Issuing a new loan to an origin kills its liveness. + Lattice transfer(Lattice In, const IssueFact &IF) { + return Lattice(Factory.remove(In.LiveOrigins, IF.getOriginID())); } - ExpiredLoanMap getExpiredLoans(ProgramPoint P) { return getState(P).Expired; } + /// An OriginFlow kills the liveness of the destination origin if `KillDest` + /// is true. Otherwise, it propagates liveness from destination to source. + Lattice transfer(Lattice In, const OriginFlowFact &OF) { + if (!OF.getKillDest()) + return In; + return Lattice(Factory.remove(In.LiveOrigins, OF.getDestOriginID())); + } + + LivenessMap getLiveOrigins(ProgramPoint P) const { + return getState(P).LiveOrigins; + } + + // Dump liveness values on all test points in the program. + void dump(llvm::raw_ostream &OS, const LifetimeSafetyAnalysis &LSA) const { + llvm::dbgs() << "==========================================\n"; + llvm::dbgs() << getAnalysisName() << " results:\n"; + llvm::dbgs() << "==========================================\n"; + for (const auto &Entry : LSA.getTestPoints()) { + OS << "TestPoint: " << Entry.getKey() << "\n"; + getState(Entry.getValue()).dump(OS, FactMgr.getOriginMgr()); + } + } }; // ========================================================================= // @@ -1240,84 +1346,75 @@ class LifetimeChecker { private: llvm::DenseMap<LoanID, PendingWarning> FinalWarningsMap; LoanPropagationAnalysis &LoanPropagation; - ExpiredLoansAnalysis &ExpiredLoans; + LiveOriginAnalysis &LiveOrigins; FactManager &FactMgr; AnalysisDeclContext &ADC; LifetimeSafetyReporter *Reporter; public: - LifetimeChecker(LoanPropagationAnalysis &LPA, ExpiredLoansAnalysis &ELA, + LifetimeChecker(LoanPropagationAnalysis &LPA, LiveOriginAnalysis &LOA, FactManager &FM, AnalysisDeclContext &ADC, LifetimeSafetyReporter *Reporter) - : LoanPropagation(LPA), ExpiredLoans(ELA), FactMgr(FM), ADC(ADC), + : LoanPropagation(LPA), LiveOrigins(LOA), FactMgr(FM), ADC(ADC), Reporter(Reporter) {} void run() { llvm::TimeTraceScope TimeProfile("LifetimeChecker"); for (const CFGBlock *B : *ADC.getAnalysis<PostOrderCFGView>()) for (const Fact *F : FactMgr.getFacts(B)) - if (const auto *UF = F->getAs<UseFact>()) - checkUse(UF); + if (const auto *EF = F->getAs<ExpireFact>()) + checkExpiry(EF); issuePendingWarnings(); } - /// Checks for use-after-free errors for a given use of an Origin. + /// Checks for use-after-free errors when a loan expires. /// - /// This method is called for each 'UseFact' identified in the control flow - /// graph. It determines if the loans held by the used origin have expired - /// at the point of use. - void checkUse(const UseFact *UF) { - if (UF->isWritten()) - return; - OriginID O = UF->getUsedOrigin(FactMgr.getOriginMgr()); - - // Get the set of loans that the origin might hold at this program point. - LoanSet HeldLoans = LoanPropagation.getLoans(O, UF); - - // Get the set of all loans that have expired at this program point. - ExpiredLoanMap AllExpiredLoans = ExpiredLoans.getExpiredLoans(UF); - - // If the pointer holds no loans or no loans have expired, there's nothing - // to check. - if (HeldLoans.isEmpty() || AllExpiredLoans.isEmpty()) - return; - - // Identify loans that which have expired but are held by the pointer. Using - // them is a use-after-free. - llvm::SmallVector<LoanID> DefaultedLoans; - // A definite UaF error occurs if all loans the origin might hold have - // expired. - bool IsDefiniteError = true; - for (LoanID L : HeldLoans) { - if (AllExpiredLoans.contains(L)) - DefaultedLoans.push_back(L); - else - // If at least one loan is not expired, this use is not a definite UaF. - IsDefiniteError = false; + /// This method examines all live origins at the expiry point and determines + /// if any of them hold the expiring loan. If so, it creates a pending + /// warning with the appropriate confidence level based on the liveness + /// information. The confidence reflects whether the origin is definitely + /// or maybe live at this point. + /// + /// Note: This implementation considers only the confidence of origin + /// liveness. Future enhancements could also consider the confidence of loan + /// propagation (e.g., a loan may only be held on some execution paths). + void checkExpiry(const ExpireFact *EF) { + LoanID ExpiredLoan = EF->getLoanID(); + LivenessMap Origins = LiveOrigins.getLiveOrigins(EF); + Confidence CurConfidence = Confidence::None; + const UseFact *BadUse = nullptr; + for (auto &[OID, LiveInfo] : Origins) { + LoanSet HeldLoans = LoanPropagation.getLoans(OID, EF); + if (!HeldLoans.contains(ExpiredLoan)) + continue; + // Loan is defaulted. + Confidence NewConfidence = livenessKindToConfidence(LiveInfo.Kind); + if (CurConfidence < NewConfidence) { + CurConfidence = NewConfidence; + BadUse = LiveInfo.CausingUseFact; + } } - // If there are no defaulted loans, the use is safe. - if (DefaultedLoans.empty()) + if (!BadUse) return; - - // Determine the confidence level of the error (definite or maybe). - Confidence CurrentConfidence = - IsDefiniteError ? Confidence::Definite : Confidence::Maybe; - - // For each expired loan, create a pending warning. - for (LoanID DefaultedLoan : DefaultedLoans) { - // If we already have a warning for this loan with a higher or equal - // confidence, skip this one. - if (FinalWarningsMap.count(DefaultedLoan) && - CurrentConfidence <= FinalWarningsMap[DefaultedLoan].ConfidenceLevel) - continue; - - auto *EF = AllExpiredLoans.lookup(DefaultedLoan); - assert(EF && "Could not find ExpireFact for an expired loan."); - - FinalWarningsMap[DefaultedLoan] = {/*ExpiryLoc=*/(*EF)->getExpiryLoc(), - /*UseExpr=*/UF->getUseExpr(), - /*ConfidenceLevel=*/CurrentConfidence}; + // We have a use-after-free. + Confidence LastConf = FinalWarningsMap.lookup(ExpiredLoan).ConfidenceLevel; + if (LastConf >= CurConfidence) + return; + FinalWarningsMap[ExpiredLoan] = {/*ExpiryLoc=*/EF->getExpiryLoc(), + /*UseExpr=*/BadUse->getUseExpr(), + /*ConfidenceLevel=*/CurConfidence}; + } + + static Confidence livenessKindToConfidence(LivenessKind K) { + switch (K) { + case LivenessKind::Must: + return Confidence::Definite; + case LivenessKind::Maybe: + return Confidence::Maybe; + case LivenessKind::Dead: + return Confidence::None; } + llvm_unreachable("unknown liveness kind"); } void issuePendingWarnings() { @@ -1336,6 +1433,15 @@ public: // LifetimeSafetyAnalysis Class Implementation // ========================================================================= // +/// An object to hold the factories for immutable collections, ensuring +/// that all created states share the same underlying memory management. +struct LifetimeFactory { + llvm::BumpPtrAllocator Allocator; + OriginLoanMap::Factory OriginMapFactory{Allocator, /*canonicalize=*/false}; + LoanSet::Factory LoanSetFactory{Allocator, /*canonicalize=*/false}; + LivenessMap::Factory LivenessMapFactory{Allocator, /*canonicalize=*/false}; +}; + // We need this here for unique_ptr with forward declared class. LifetimeSafetyAnalysis::~LifetimeSafetyAnalysis() = default; @@ -1366,15 +1472,16 @@ void LifetimeSafetyAnalysis::run() { /// the analysis. /// 3. Collapse ExpireFacts belonging to same source location into a single /// Fact. - LoanPropagation = - std::make_unique<LoanPropagationAnalysis>(Cfg, AC, *FactMgr, *Factory); + LoanPropagation = std::make_unique<LoanPropagationAnalysis>( + Cfg, AC, *FactMgr, Factory->OriginMapFactory, Factory->LoanSetFactory); LoanPropagation->run(); - ExpiredLoans = - std::make_unique<ExpiredLoansAnalysis>(Cfg, AC, *FactMgr, *Factory); - ExpiredLoans->run(); + LiveOrigins = std::make_unique<LiveOriginAnalysis>( + Cfg, AC, *FactMgr, Factory->LivenessMapFactory); + LiveOrigins->run(); + DEBUG_WITH_TYPE("LiveOrigins", LiveOrigins->dump(llvm::dbgs(), *this)); - LifetimeChecker Checker(*LoanPropagation, *ExpiredLoans, *FactMgr, AC, + LifetimeChecker Checker(*LoanPropagation, *LiveOrigins, *FactMgr, AC, Reporter); Checker.run(); } @@ -1385,15 +1492,6 @@ LoanSet LifetimeSafetyAnalysis::getLoansAtPoint(OriginID OID, return LoanPropagation->getLoans(OID, PP); } -std::vector<LoanID> -LifetimeSafetyAnalysis::getExpiredLoansAtPoint(ProgramPoint PP) const { - assert(ExpiredLoans && "ExpiredLoansAnalysis has not been run."); - std::vector<LoanID> Result; - for (const auto &pair : ExpiredLoans->getExpiredLoans(PP)) - Result.push_back(pair.first); - return Result; -} - std::optional<OriginID> LifetimeSafetyAnalysis::getOriginIDForDecl(const ValueDecl *D) const { assert(FactMgr && "FactManager not initialized"); @@ -1413,6 +1511,15 @@ LifetimeSafetyAnalysis::getLoanIDForVar(const VarDecl *VD) const { return Result; } +std::vector<std::pair<OriginID, LivenessKind>> +LifetimeSafetyAnalysis::getLiveOriginsAtPoint(ProgramPoint PP) const { + assert(LiveOrigins && "LiveOriginAnalysis has not been run."); + std::vector<std::pair<OriginID, LivenessKind>> Result; + for (auto &[OID, Info] : LiveOrigins->getLiveOrigins(PP)) + Result.push_back({OID, Info.Kind}); + return Result; +} + llvm::StringMap<ProgramPoint> LifetimeSafetyAnalysis::getTestPoints() const { assert(FactMgr && "FactManager not initialized"); llvm::StringMap<ProgramPoint> AnnotationToPointMap; diff --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp b/clang/lib/Analysis/UnsafeBufferUsage.cpp index ad3d234..f5a3686 100644 --- a/clang/lib/Analysis/UnsafeBufferUsage.cpp +++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp @@ -13,6 +13,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/DeclTemplate.h" #include "clang/AST/DynamicRecursiveASTVisitor.h" #include "clang/AST/Expr.h" #include "clang/AST/FormatString.h" @@ -1318,6 +1319,97 @@ static bool isSupportedVariable(const DeclRefExpr &Node) { return D != nullptr && isa<VarDecl>(D); } +// Returns true for RecordDecl of type std::unique_ptr<T[]> +static bool isUniquePtrArray(const CXXRecordDecl *RecordDecl) { + if (!RecordDecl || !RecordDecl->isInStdNamespace() || + RecordDecl->getNameAsString() != "unique_ptr") + return false; + + const ClassTemplateSpecializationDecl *class_template_specialization_decl = + dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl); + if (!class_template_specialization_decl) + return false; + + const TemplateArgumentList &template_args = + class_template_specialization_decl->getTemplateArgs(); + if (template_args.size() == 0) + return false; + + const TemplateArgument &first_arg = template_args[0]; + if (first_arg.getKind() != TemplateArgument::Type) + return false; + + QualType referred_type = first_arg.getAsType(); + return referred_type->isArrayType(); +} + +class UniquePtrArrayAccessGadget : public WarningGadget { +private: + static constexpr const char *const AccessorTag = "unique_ptr_array_access"; + const CXXOperatorCallExpr *AccessorExpr; + +public: + UniquePtrArrayAccessGadget(const MatchResult &Result) + : WarningGadget(Kind::UniquePtrArrayAccess), + AccessorExpr(Result.getNodeAs<CXXOperatorCallExpr>(AccessorTag)) { + assert(AccessorExpr && + "UniquePtrArrayAccessGadget requires a matched CXXOperatorCallExpr"); + } + + static bool classof(const Gadget *G) { + return G->getKind() == Kind::UniquePtrArrayAccess; + } + + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + + const CXXOperatorCallExpr *OpCall = dyn_cast<CXXOperatorCallExpr>(S); + if (!OpCall || OpCall->getOperator() != OO_Subscript) + return false; + + const Expr *Callee = OpCall->getCallee()->IgnoreParenImpCasts(); + if (!Callee) + return false; + + const CXXMethodDecl *Method = + dyn_cast_or_null<CXXMethodDecl>(OpCall->getDirectCallee()); + if (!Method) + return false; + + if (Method->getOverloadedOperator() != OO_Subscript) + return false; + + const CXXRecordDecl *RecordDecl = Method->getParent(); + if (!isUniquePtrArray(RecordDecl)) + return false; + + const Expr *IndexExpr = OpCall->getArg(1); + clang::Expr::EvalResult Eval; + + // Allow [0] + if (IndexExpr->EvaluateAsInt(Eval, Ctx) && Eval.Val.getInt().isZero()) + return false; + + Result.addNode(AccessorTag, DynTypedNode::create(*OpCall)); + return true; + } + void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, + bool IsRelatedToDecl, + ASTContext &Ctx) const override { + Handler.handleUnsafeUniquePtrArrayAccess( + DynTypedNode::create(*AccessorExpr), IsRelatedToDecl, Ctx); + } + + SourceLocation getSourceLoc() const override { + if (AccessorExpr) + return AccessorExpr->getOperatorLoc(); + return SourceLocation(); + } + + DeclUseList getClaimedVarUseSites() const override { return {}; } + SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; } +}; + using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>; using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>; @@ -2632,10 +2724,13 @@ std::set<const Expr *> clang::findUnsafePointers(const FunctionDecl *FD) { const VariableGroupsManager &, FixItList &&, const Decl *, const FixitStrategy &) override {} - bool isSafeBufferOptOut(const SourceLocation &) const override { + void handleUnsafeUniquePtrArrayAccess(const DynTypedNode &Node, + bool IsRelatedToDecl, + ASTContext &Ctx) override {} + bool ignoreUnsafeBufferInContainer(const SourceLocation &) const override { return false; } - bool ignoreUnsafeBufferInContainer(const SourceLocation &) const override { + bool isSafeBufferOptOut(const SourceLocation &) const override { return false; } bool ignoreUnsafeBufferInLibcCall(const SourceLocation &) const override { diff --git a/clang/lib/Basic/Diagnostic.cpp b/clang/lib/Basic/Diagnostic.cpp index dc3778b..8ecbd3c 100644 --- a/clang/lib/Basic/Diagnostic.cpp +++ b/clang/lib/Basic/Diagnostic.cpp @@ -517,12 +517,6 @@ public: const SourceManager &SM) const; private: - // Find the longest glob pattern that matches FilePath amongst - // CategoriesToMatchers, return true iff the match exists and belongs to a - // positive category. - bool globsMatches(const llvm::StringMap<Matcher> &CategoriesToMatchers, - StringRef FilePath) const; - llvm::DenseMap<diag::kind, const Section *> DiagToSection; }; } // namespace @@ -537,33 +531,16 @@ WarningsSpecialCaseList::create(const llvm::MemoryBuffer &Input, } void WarningsSpecialCaseList::processSections(DiagnosticsEngine &Diags) { - // Drop the default section introduced by special case list, we only support - // exact diagnostic group names. - // FIXME: We should make this configurable in the parser instead. - // FIXME: C++20 can use std::erase_if(Sections, [](Section &sec) { return - // sec.SectionStr == "*"; }); - llvm::erase_if(Sections, [](Section &sec) { return sec.SectionStr == "*"; }); - // Make sure we iterate sections by their line numbers. - std::vector<std::pair<unsigned, const Section *>> LineAndSectionEntry; - LineAndSectionEntry.reserve(Sections.size()); - for (const auto &Entry : Sections) { - StringRef DiagName = Entry.SectionStr; - // Each section has a matcher with that section's name, attached to that - // line. - const auto &DiagSectionMatcher = Entry.SectionMatcher; - unsigned DiagLine = 0; - for (const auto &Glob : DiagSectionMatcher->Globs) - if (Glob->Name == DiagName) { - DiagLine = Glob->LineNo; - break; - } - LineAndSectionEntry.emplace_back(DiagLine, &Entry); - } - llvm::sort(LineAndSectionEntry); static constexpr auto WarningFlavor = clang::diag::Flavor::WarningOrError; - for (const auto &[_, SectionEntry] : LineAndSectionEntry) { + for (const auto &SectionEntry : sections()) { + StringRef DiagGroup = SectionEntry.SectionStr; + if (DiagGroup == "*") { + // Drop the default section introduced by special case list, we only + // support exact diagnostic group names. + // FIXME: We should make this configurable in the parser instead. + continue; + } SmallVector<diag::kind> GroupDiags; - StringRef DiagGroup = SectionEntry->SectionStr; if (Diags.getDiagnosticIDs()->getDiagnosticsInGroup( WarningFlavor, DiagGroup, GroupDiags)) { StringRef Suggestion = @@ -576,7 +553,7 @@ void WarningsSpecialCaseList::processSections(DiagnosticsEngine &Diags) { for (diag::kind Diag : GroupDiags) // We're intentionally overwriting any previous mappings here to make sure // latest one takes precedence. - DiagToSection[Diag] = SectionEntry; + DiagToSection[Diag] = &SectionEntry; } } @@ -601,43 +578,24 @@ void DiagnosticsEngine::setDiagSuppressionMapping(llvm::MemoryBuffer &Input) { bool WarningsSpecialCaseList::isDiagSuppressed(diag::kind DiagId, SourceLocation DiagLoc, const SourceManager &SM) const { + PresumedLoc PLoc = SM.getPresumedLoc(DiagLoc); + if (!PLoc.isValid()) + return false; const Section *DiagSection = DiagToSection.lookup(DiagId); if (!DiagSection) return false; - const SectionEntries &EntityTypeToCategories = DiagSection->Entries; - auto SrcEntriesIt = EntityTypeToCategories.find("src"); - if (SrcEntriesIt == EntityTypeToCategories.end()) + + StringRef F = llvm::sys::path::remove_leading_dotslash(PLoc.getFilename()); + + StringRef LongestSup = DiagSection->getLongestMatch("src", F, ""); + if (LongestSup.empty()) return false; - const llvm::StringMap<llvm::SpecialCaseList::Matcher> &CategoriesToMatchers = - SrcEntriesIt->getValue(); - // We also use presumed locations here to improve reproducibility for - // preprocessed inputs. - if (PresumedLoc PLoc = SM.getPresumedLoc(DiagLoc); PLoc.isValid()) - return globsMatches( - CategoriesToMatchers, - llvm::sys::path::remove_leading_dotslash(PLoc.getFilename())); - return false; -} -bool WarningsSpecialCaseList::globsMatches( - const llvm::StringMap<Matcher> &CategoriesToMatchers, - StringRef FilePath) const { - StringRef LongestMatch; - bool LongestIsPositive = false; - for (const auto &Entry : CategoriesToMatchers) { - StringRef Category = Entry.getKey(); - const llvm::SpecialCaseList::Matcher &Matcher = Entry.getValue(); - bool IsPositive = Category != "emit"; - for (const auto &Glob : Matcher.Globs) { - if (Glob->Name.size() < LongestMatch.size()) - continue; - if (!Glob->Pattern.match(FilePath)) - continue; - LongestMatch = Glob->Name; - LongestIsPositive = IsPositive; - } - } - return LongestIsPositive; + StringRef LongestEmit = DiagSection->getLongestMatch("src", F, "emit"); + if (LongestEmit.empty()) + return true; + + return LongestSup.size() > LongestEmit.size(); } bool DiagnosticsEngine::isSuppressedViaMapping(diag::kind DiagId, diff --git a/clang/lib/Basic/ProfileList.cpp b/clang/lib/Basic/ProfileList.cpp index 8481def..9cb1188 100644 --- a/clang/lib/Basic/ProfileList.cpp +++ b/clang/lib/Basic/ProfileList.cpp @@ -32,10 +32,10 @@ public: createOrDie(const std::vector<std::string> &Paths, llvm::vfs::FileSystem &VFS); - bool isEmpty() const { return Sections.empty(); } + bool isEmpty() const { return sections().empty(); } bool hasPrefix(StringRef Prefix) const { - for (const auto &It : Sections) + for (const auto &It : sections()) if (It.Entries.count(Prefix) > 0) return true; return false; diff --git a/clang/lib/Basic/SanitizerSpecialCaseList.cpp b/clang/lib/Basic/SanitizerSpecialCaseList.cpp index f7bc1d5..56f5516 100644 --- a/clang/lib/Basic/SanitizerSpecialCaseList.cpp +++ b/clang/lib/Basic/SanitizerSpecialCaseList.cpp @@ -38,11 +38,11 @@ SanitizerSpecialCaseList::createOrDie(const std::vector<std::string> &Paths, } void SanitizerSpecialCaseList::createSanitizerSections() { - for (auto &S : Sections) { + for (const auto &S : sections()) { SanitizerMask Mask; #define SANITIZER(NAME, ID) \ - if (S.SectionMatcher->match(NAME)) \ + if (S.SectionMatcher.matchAny(NAME)) \ Mask |= SanitizerKind::ID; #define SANITIZER_GROUP(NAME, ID, ALIAS) SANITIZER(NAME, ID) @@ -50,7 +50,7 @@ void SanitizerSpecialCaseList::createSanitizerSections() { #undef SANITIZER #undef SANITIZER_GROUP - SanitizerSections.emplace_back(Mask, S.Entries, S.FileIdx); + SanitizerSections.emplace_back(Mask, S); } } @@ -66,10 +66,9 @@ SanitizerSpecialCaseList::inSectionBlame(SanitizerMask Mask, StringRef Prefix, StringRef Category) const { for (const auto &S : llvm::reverse(SanitizerSections)) { if (S.Mask & Mask) { - unsigned LineNum = - SpecialCaseList::inSectionBlame(S.Entries, Prefix, Query, Category); + unsigned LineNum = S.S.getLastMatch(Prefix, Query, Category); if (LineNum > 0) - return {S.FileIdx, LineNum}; + return {S.S.FileIdx, LineNum}; } } return NotFound; diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 58345b4..25afe8b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -122,6 +122,11 @@ public: return getPointerTo(cir::VPtrType::get(getContext())); } + cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy, + bool isVarArg = false) { + return cir::FuncType::get(params, retTy, isVarArg); + } + /// Get a CIR record kind from a AST declaration tag. cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) { switch (kind) { @@ -372,6 +377,15 @@ public: return cir::BinOp::create(*this, loc, cir::BinOpKind::Div, lhs, rhs); } + mlir::Value createDynCast(mlir::Location loc, mlir::Value src, + cir::PointerType destType, bool isRefCast, + cir::DynamicCastInfoAttr info) { + auto castKind = + isRefCast ? cir::DynamicCastKind::Ref : cir::DynamicCastKind::Ptr; + return cir::DynamicCastOp::create(*this, loc, destType, castKind, src, info, + /*relative_layout=*/false); + } + Address createBaseClassAddr(mlir::Location loc, Address addr, mlir::Type destType, unsigned offset, bool assumeNotNull) { diff --git a/clang/lib/CIR/CodeGen/CIRGenCXX.cpp b/clang/lib/CIR/CodeGen/CIRGenCXX.cpp index d5b35c2..274d11b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXX.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCXX.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "CIRGenCXXABI.h" #include "CIRGenFunction.h" #include "CIRGenModule.h" @@ -95,7 +96,63 @@ static void emitDeclDestroy(CIRGenFunction &cgf, const VarDecl *vd, return; } - cgf.cgm.errorNYI(vd->getSourceRange(), "global with destructor"); + // If not constant storage we'll emit this regardless of NeedsDtor value. + CIRGenBuilderTy &builder = cgf.getBuilder(); + + // Prepare the dtor region. + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::Block *block = builder.createBlock(&addr.getDtorRegion()); + CIRGenFunction::LexicalScope lexScope{cgf, addr.getLoc(), + builder.getInsertionBlock()}; + lexScope.setAsGlobalInit(); + builder.setInsertionPointToStart(block); + + CIRGenModule &cgm = cgf.cgm; + QualType type = vd->getType(); + + // Special-case non-array C++ destructors, if they have the right signature. + // Under some ABIs, destructors return this instead of void, and cannot be + // passed directly to __cxa_atexit if the target does not allow this + // mismatch. + const CXXRecordDecl *record = type->getAsCXXRecordDecl(); + bool canRegisterDestructor = + record && (!cgm.getCXXABI().hasThisReturn( + GlobalDecl(record->getDestructor(), Dtor_Complete)) || + cgm.getCXXABI().canCallMismatchedFunctionType()); + + // If __cxa_atexit is disabled via a flag, a different helper function is + // generated elsewhere which uses atexit instead, and it takes the destructor + // directly. + cir::FuncOp fnOp; + if (record && (canRegisterDestructor || cgm.getCodeGenOpts().CXAAtExit)) { + if (vd->getTLSKind()) + cgm.errorNYI(vd->getSourceRange(), "TLS destructor"); + assert(!record->hasTrivialDestructor()); + assert(!cir::MissingFeatures::openCL()); + CXXDestructorDecl *dtor = record->getDestructor(); + // In LLVM OG codegen this is done in registerGlobalDtor, but CIRGen + // relies on LoweringPrepare for further decoupling, so build the + // call right here. + auto gd = GlobalDecl(dtor, Dtor_Complete); + fnOp = cgm.getAddrAndTypeOfCXXStructor(gd).second; + cgf.getBuilder().createCallOp( + cgf.getLoc(vd->getSourceRange()), + mlir::FlatSymbolRefAttr::get(fnOp.getSymNameAttr()), + mlir::ValueRange{cgm.getAddrOfGlobalVar(vd)}); + } else { + cgm.errorNYI(vd->getSourceRange(), "array destructor"); + } + assert(fnOp && "expected cir.func"); + cgm.getCXXABI().registerGlobalDtor(vd, fnOp, nullptr); + + builder.setInsertionPointToEnd(block); + if (block->empty()) { + block->erase(); + // Don't confuse lexical cleanup. + builder.clearInsertionPoint(); + } else { + builder.create<cir::YieldOp>(addr.getLoc()); + } } cir::FuncOp CIRGenModule::codegenCXXStructor(GlobalDecl gd) { diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp index 5f1faab..df42af8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp @@ -15,6 +15,7 @@ #include "CIRGenFunction.h" #include "clang/AST/Decl.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/GlobalDecl.h" using namespace clang; @@ -75,3 +76,20 @@ void CIRGenCXXABI::setCXXABIThisValue(CIRGenFunction &cgf, assert(getThisDecl(cgf) && "no 'this' variable for function"); cgf.cxxabiThisValue = thisPtr; } + +CharUnits CIRGenCXXABI::getArrayCookieSize(const CXXNewExpr *e) { + if (!requiresArrayCookie(e)) + return CharUnits::Zero(); + + cgm.errorNYI(e->getSourceRange(), "CIRGenCXXABI::getArrayCookieSize"); + return CharUnits::Zero(); +} + +bool CIRGenCXXABI::requiresArrayCookie(const CXXNewExpr *e) { + // If the class's usual deallocation function takes two arguments, + // it needs a cookie. + if (e->doesUsualArrayDeleteWantSize()) + return true; + + return e->getAllocatedType().isDestructedType(); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h index 1dee774..06f41cd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h +++ b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h @@ -28,6 +28,8 @@ protected: CIRGenModule &cgm; std::unique_ptr<clang::MangleContext> mangleContext; + virtual bool requiresArrayCookie(const CXXNewExpr *e); + public: // TODO(cir): make this protected when target-specific CIRGenCXXABIs are // implemented. @@ -52,6 +54,12 @@ public: Address thisAddr, const CXXRecordDecl *classDecl, const CXXRecordDecl *baseClassDecl) = 0; + virtual mlir::Value emitDynamicCast(CIRGenFunction &cgf, mlir::Location loc, + QualType srcRecordTy, + QualType destRecordTy, + cir::PointerType destCIRTy, + bool isRefCast, Address src) = 0; + public: /// Similar to AddedStructorArgs, but only notes the number of additional /// arguments. @@ -113,6 +121,7 @@ public: CIRGenFunction &cgf) = 0; virtual void emitRethrow(CIRGenFunction &cgf, bool isNoReturn) = 0; + virtual void emitThrow(CIRGenFunction &cgf, const CXXThrowExpr *e) = 0; virtual mlir::Attribute getAddrOfRTTIDescriptor(mlir::Location loc, QualType ty) = 0; @@ -146,6 +155,14 @@ public: /// Loads the incoming C++ this pointer as it was passed by the caller. mlir::Value loadIncomingCXXThis(CIRGenFunction &cgf); + /// Get the implicit (second) parameter that comes after the "this" pointer, + /// or nullptr if there is isn't one. + virtual mlir::Value getCXXDestructorImplicitParam(CIRGenFunction &cgf, + const CXXDestructorDecl *dd, + CXXDtorType type, + bool forVirtualBase, + bool delegating) = 0; + /// Emit constructor variants required by this ABI. virtual void emitCXXConstructors(const clang::CXXConstructorDecl *d) = 0; @@ -157,6 +174,14 @@ public: bool forVirtualBase, bool delegating, Address thisAddr, QualType thisTy) = 0; + /// Emit code to force the execution of a destructor during global + /// teardown. The default implementation of this uses atexit. + /// + /// \param dtor - a function taking a single pointer argument + /// \param addr - a pointer to pass to the destructor function. + virtual void registerGlobalDtor(const VarDecl *vd, cir::FuncOp dtor, + mlir::Value addr) = 0; + /// Checks if ABI requires extra virtual offset for vtable field. virtual bool isVirtualOffsetNeededForVTableField(CIRGenFunction &cgf, @@ -230,6 +255,16 @@ public: return false; } + /// Returns true if the target allows calling a function through a pointer + /// with a different signature than the actual function (or equivalently, + /// bitcasting a function or function pointer to a different function type). + /// In principle in the most general case this could depend on the target, the + /// calling convention, and the actual types of the arguments and return + /// value. Here it just means whether the signature mismatch could *ever* be + /// allowed; in other words, does the target do strict checking of signatures + /// for all calls. + virtual bool canCallMismatchedFunctionType() const { return true; } + /// Gets the mangle context. clang::MangleContext &getMangleContext() { return *mangleContext; } @@ -244,6 +279,19 @@ public: void setStructorImplicitParamValue(CIRGenFunction &cgf, mlir::Value val) { cgf.cxxStructorImplicitParamValue = val; } + + /**************************** Array cookies ******************************/ + + /// Returns the extra size required in order to store the array + /// cookie for the given new-expression. May return 0 to indicate that no + /// array cookie is required. + /// + /// Several cases are filtered out before this method is called: + /// - non-array allocations never need a cookie + /// - calls to \::operator new(size_t, void*) never need a cookie + /// + /// \param E - the new-expression being allocated. + virtual CharUnits getArrayCookieSize(const CXXNewExpr *e); }; /// Creates and Itanium-family ABI diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index 8f4377b..485b2c8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -126,6 +126,30 @@ static bool isInitializerOfDynamicClass(const CXXCtorInitializer *baseInit) { } namespace { +/// Call the destructor for a direct base class. +struct CallBaseDtor final : EHScopeStack::Cleanup { + const CXXRecordDecl *baseClass; + bool baseIsVirtual; + CallBaseDtor(const CXXRecordDecl *base, bool baseIsVirtual) + : baseClass(base), baseIsVirtual(baseIsVirtual) {} + + void emit(CIRGenFunction &cgf) override { + const CXXRecordDecl *derivedClass = + cast<CXXMethodDecl>(cgf.curFuncDecl)->getParent(); + + const CXXDestructorDecl *d = baseClass->getDestructor(); + // We are already inside a destructor, so presumably the object being + // destroyed should have the expected type. + QualType thisTy = d->getFunctionObjectParameterType(); + assert(cgf.currSrcLoc && "expected source location"); + Address addr = cgf.getAddressOfDirectBaseInCompleteClass( + *cgf.currSrcLoc, cgf.loadCXXThisAddress(), derivedClass, baseClass, + baseIsVirtual); + cgf.emitCXXDestructorCall(d, Dtor_Base, baseIsVirtual, + /*delegating=*/false, addr, thisTy); + } +}; + /// A visitor which checks whether an initializer uses 'this' in a /// way which requires the vtable to be properly set. struct DynamicThisUseChecker @@ -870,6 +894,116 @@ void CIRGenFunction::destroyCXXObject(CIRGenFunction &cgf, Address addr, /*delegating=*/false, addr, type); } +namespace { +class DestroyField final : public EHScopeStack::Cleanup { + const FieldDecl *field; + CIRGenFunction::Destroyer *destroyer; + +public: + DestroyField(const FieldDecl *field, CIRGenFunction::Destroyer *destroyer) + : field(field), destroyer(destroyer) {} + + void emit(CIRGenFunction &cgf) override { + // Find the address of the field. + Address thisValue = cgf.loadCXXThisAddress(); + CanQualType recordTy = + cgf.getContext().getCanonicalTagType(field->getParent()); + LValue thisLV = cgf.makeAddrLValue(thisValue, recordTy); + LValue lv = cgf.emitLValueForField(thisLV, field); + assert(lv.isSimple()); + + assert(!cir::MissingFeatures::ehCleanupFlags()); + cgf.emitDestroy(lv.getAddress(), field->getType(), destroyer); + } +}; +} // namespace + +/// Emit all code that comes at the end of class's destructor. This is to call +/// destructors on members and base classes in reverse order of their +/// construction. +/// +/// For a deleting destructor, this also handles the case where a destroying +/// operator delete completely overrides the definition. +void CIRGenFunction::enterDtorCleanups(const CXXDestructorDecl *dd, + CXXDtorType dtorType) { + assert((!dd->isTrivial() || dd->hasAttr<DLLExportAttr>()) && + "Should not emit dtor epilogue for non-exported trivial dtor!"); + + // The deleting-destructor phase just needs to call the appropriate + // operator delete that Sema picked up. + if (dtorType == Dtor_Deleting) { + cgm.errorNYI(dd->getSourceRange(), "deleting destructor cleanups"); + return; + } + + const CXXRecordDecl *classDecl = dd->getParent(); + + // Unions have no bases and do not call field destructors. + if (classDecl->isUnion()) + return; + + // The complete-destructor phase just destructs all the virtual bases. + if (dtorType == Dtor_Complete) { + assert(!cir::MissingFeatures::sanitizers()); + + // We push them in the forward order so that they'll be popped in + // the reverse order. + for (const CXXBaseSpecifier &base : classDecl->vbases()) { + auto *baseClassDecl = base.getType()->castAsCXXRecordDecl(); + + if (baseClassDecl->hasTrivialDestructor()) { + // Under SanitizeMemoryUseAfterDtor, poison the trivial base class + // memory. For non-trival base classes the same is done in the class + // destructor. + assert(!cir::MissingFeatures::sanitizers()); + } else { + ehStack.pushCleanup<CallBaseDtor>(NormalAndEHCleanup, baseClassDecl, + /*baseIsVirtual=*/true); + } + } + + return; + } + + assert(dtorType == Dtor_Base); + assert(!cir::MissingFeatures::sanitizers()); + + // Destroy non-virtual bases. + for (const CXXBaseSpecifier &base : classDecl->bases()) { + // Ignore virtual bases. + if (base.isVirtual()) + continue; + + CXXRecordDecl *baseClassDecl = base.getType()->getAsCXXRecordDecl(); + + if (baseClassDecl->hasTrivialDestructor()) + assert(!cir::MissingFeatures::sanitizers()); + else + ehStack.pushCleanup<CallBaseDtor>(NormalAndEHCleanup, baseClassDecl, + /*baseIsVirtual=*/false); + } + + assert(!cir::MissingFeatures::sanitizers()); + + // Destroy direct fields. + for (const FieldDecl *field : classDecl->fields()) { + QualType type = field->getType(); + QualType::DestructionKind dtorKind = type.isDestructedType(); + if (!dtorKind) + continue; + + // Anonymous union members do not have their destructors called. + const RecordType *rt = type->getAsUnionType(); + if (rt && rt->getOriginalDecl()->isAnonymousStructOrUnion()) + continue; + + CleanupKind cleanupKind = getCleanupKind(dtorKind); + assert(!cir::MissingFeatures::ehCleanupFlags()); + ehStack.pushCleanup<DestroyField>(cleanupKind, field, + getDestroyer(dtorKind)); + } +} + void CIRGenFunction::emitDelegatingCXXConstructorCall( const CXXConstructorDecl *ctor, const FunctionArgList &args) { assert(ctor->isDelegatingConstructor()); diff --git a/clang/lib/CIR/CodeGen/CIRGenCleanup.h b/clang/lib/CIR/CodeGen/CIRGenCleanup.h index a4ec8cc..30f5607 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCleanup.h +++ b/clang/lib/CIR/CodeGen/CIRGenCleanup.h @@ -104,6 +104,7 @@ public: bool isNormalCleanup() const { return cleanupBits.isNormalCleanup; } bool isActive() const { return cleanupBits.isActive; } + void setActive(bool isActive) { cleanupBits.isActive = isActive; } size_t getCleanupSize() const { return cleanupBits.cleanupSize; } void *getCleanupBuffer() { return this + 1; } @@ -138,5 +139,13 @@ inline EHScopeStack::iterator EHScopeStack::begin() const { return iterator(startOfData); } +inline EHScopeStack::iterator +EHScopeStack::find(stable_iterator savePoint) const { + assert(savePoint.isValid() && "finding invalid savepoint"); + assert(savePoint.size <= stable_begin().size && + "finding savepoint after pop"); + return iterator(endOfBuffer - savePoint.size); +} + } // namespace clang::CIRGen #endif // CLANG_LIB_CIR_CODEGEN_CIRGENCLEANUP_H diff --git a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp index 563a753..039d290 100644 --- a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp @@ -695,12 +695,6 @@ struct DestroyObject final : EHScopeStack::Cleanup { void emit(CIRGenFunction &cgf) override { cgf.emitDestroy(addr, type, destroyer); } - - // This is a placeholder until EHCleanupScope is implemented. - size_t getSize() const override { - assert(!cir::MissingFeatures::ehCleanupScope()); - return sizeof(DestroyObject); - } }; } // namespace diff --git a/clang/lib/CIR/CodeGen/CIRGenException.cpp b/clang/lib/CIR/CodeGen/CIRGenException.cpp index 7fcb39a..6453843 100644 --- a/clang/lib/CIR/CodeGen/CIRGenException.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenException.cpp @@ -31,11 +31,36 @@ void CIRGenFunction::emitCXXThrowExpr(const CXXThrowExpr *e) { if (throwType->isObjCObjectPointerType()) { cgm.errorNYI("emitCXXThrowExpr ObjCObjectPointerType"); return; - } else { - cgm.errorNYI("emitCXXThrowExpr with subExpr"); - return; } - } else { - cgm.getCXXABI().emitRethrow(*this, /*isNoReturn=*/true); + + cgm.getCXXABI().emitThrow(*this, e); + return; } + + cgm.getCXXABI().emitRethrow(*this, /*isNoReturn=*/true); +} + +void CIRGenFunction::emitAnyExprToExn(const Expr *e, Address addr) { + // Make sure the exception object is cleaned up if there's an + // exception during initialization. + assert(!cir::MissingFeatures::ehCleanupScope()); + + // __cxa_allocate_exception returns a void*; we need to cast this + // to the appropriate type for the object. + mlir::Type ty = convertTypeForMem(e->getType()); + Address typedAddr = addr.withElementType(builder, ty); + + // From LLVM's codegen: + // FIXME: this isn't quite right! If there's a final unelided call + // to a copy constructor, then according to [except.terminate]p1 we + // must call std::terminate() if that constructor throws, because + // technically that copy occurs after the exception expression is + // evaluated but before the exception is caught. But the best way + // to handle that is to teach EmitAggExpr to do the final copy + // differently if it can't be elided. + emitAnyExprToMem(e, typedAddr, e->getType().getQualifiers(), + /*isInitializer=*/true); + + // Deactivate the cleanup block. + assert(!cir::MissingFeatures::ehCleanupScope()); } diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index be94890..f416571 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -1185,10 +1185,16 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) { case CK_BuiltinFnToFnPtr: llvm_unreachable("builtin functions are handled elsewhere"); + case CK_Dynamic: { + LValue lv = emitLValue(e->getSubExpr()); + Address v = lv.getAddress(); + const auto *dce = cast<CXXDynamicCastExpr>(e); + return makeNaturalAlignAddrLValue(emitDynamicCast(v, dce), e->getType()); + } + // These are never l-values; just use the aggregate emission code. case CK_NonAtomicToAtomic: case CK_AtomicToNonAtomic: - case CK_Dynamic: case CK_ToUnion: case CK_BaseToDerived: case CK_AddressSpaceConversion: diff --git a/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp b/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp index 7989ad2..97c0944 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "CIRGenCXXABI.h" +#include "CIRGenConstantEmitter.h" #include "CIRGenFunction.h" #include "clang/AST/DeclCXX.h" @@ -210,6 +211,19 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorCall( return emitCall(fnInfo, callee, returnValue, args, nullptr, loc); } +static CharUnits calculateCookiePadding(CIRGenFunction &cgf, + const CXXNewExpr *e) { + if (!e->isArray()) + return CharUnits::Zero(); + + // No cookie is required if the operator new[] being used is the + // reserved placement operator new[]. + if (e->getOperatorNew()->isReservedGlobalPlacementOperator()) + return CharUnits::Zero(); + + return cgf.cgm.getCXXABI().getArrayCookieSize(e); +} + static mlir::Value emitCXXNewAllocSize(CIRGenFunction &cgf, const CXXNewExpr *e, unsigned minElements, mlir::Value &numElements, @@ -224,8 +238,98 @@ static mlir::Value emitCXXNewAllocSize(CIRGenFunction &cgf, const CXXNewExpr *e, return sizeWithoutCookie; } - cgf.cgm.errorNYI(e->getSourceRange(), "emitCXXNewAllocSize: array"); - return {}; + // The width of size_t. + unsigned sizeWidth = cgf.cgm.getDataLayout().getTypeSizeInBits(cgf.SizeTy); + + // The number of elements can be have an arbitrary integer type; + // essentially, we need to multiply it by a constant factor, add a + // cookie size, and verify that the result is representable as a + // size_t. That's just a gloss, though, and it's wrong in one + // important way: if the count is negative, it's an error even if + // the cookie size would bring the total size >= 0. + // + // If the array size is constant, Sema will have prevented negative + // values and size overflow. + + // Compute the constant factor. + llvm::APInt arraySizeMultiplier(sizeWidth, 1); + while (const ConstantArrayType *cat = + cgf.getContext().getAsConstantArrayType(type)) { + type = cat->getElementType(); + arraySizeMultiplier *= cat->getSize(); + } + + CharUnits typeSize = cgf.getContext().getTypeSizeInChars(type); + llvm::APInt typeSizeMultiplier(sizeWidth, typeSize.getQuantity()); + typeSizeMultiplier *= arraySizeMultiplier; + + // Figure out the cookie size. + llvm::APInt cookieSize(sizeWidth, + calculateCookiePadding(cgf, e).getQuantity()); + + // This will be a size_t. + mlir::Value size; + + // Emit the array size expression. + // We multiply the size of all dimensions for NumElements. + // e.g for 'int[2][3]', ElemType is 'int' and NumElements is 6. + const Expr *arraySize = *e->getArraySize(); + mlir::Attribute constNumElements = + ConstantEmitter(cgf.cgm, &cgf) + .emitAbstract(arraySize, arraySize->getType()); + if (constNumElements) { + // Get an APInt from the constant + const llvm::APInt &count = + mlir::cast<cir::IntAttr>(constNumElements).getValue(); + + unsigned numElementsWidth = count.getBitWidth(); + + // The equivalent code in CodeGen/CGExprCXX.cpp handles these cases as + // overflow, but that should never happen. The size argument is implicitly + // cast to a size_t, so it can never be negative and numElementsWidth will + // always equal sizeWidth. + assert(!count.isNegative() && "Expected non-negative array size"); + assert(numElementsWidth == sizeWidth && + "Expected a size_t array size constant"); + + // Okay, compute a count at the right width. + llvm::APInt adjustedCount = count.zextOrTrunc(sizeWidth); + + // Scale numElements by that. This might overflow, but we don't + // care because it only overflows if allocationSize does too, and + // if that overflows then we shouldn't use this. + // This emits a constant that may not be used, but we can't tell here + // whether it will be needed or not. + numElements = + cgf.getBuilder().getConstInt(loc, adjustedCount * arraySizeMultiplier); + + // Compute the size before cookie, and track whether it overflowed. + bool overflow; + llvm::APInt allocationSize = + adjustedCount.umul_ov(typeSizeMultiplier, overflow); + + // Sema prevents us from hitting this case + assert(!overflow && "Overflow in array allocation size"); + + // Add in the cookie, and check whether it's overflowed. + if (cookieSize != 0) { + cgf.cgm.errorNYI(e->getSourceRange(), + "emitCXXNewAllocSize: array cookie"); + } + + size = cgf.getBuilder().getConstInt(loc, allocationSize); + } else { + // TODO: Handle the variable size case + cgf.cgm.errorNYI(e->getSourceRange(), + "emitCXXNewAllocSize: variable array size"); + } + + if (cookieSize == 0) + sizeWithoutCookie = size; + else + assert(sizeWithoutCookie && "didn't set sizeWithoutCookie?"); + + return size; } static void storeAnyExprIntoOneUnit(CIRGenFunction &cgf, const Expr *init, @@ -254,13 +358,26 @@ static void storeAnyExprIntoOneUnit(CIRGenFunction &cgf, const Expr *init, llvm_unreachable("bad evaluation kind"); } +void CIRGenFunction::emitNewArrayInitializer( + const CXXNewExpr *e, QualType elementType, mlir::Type elementTy, + Address beginPtr, mlir::Value numElements, + mlir::Value allocSizeWithoutCookie) { + // If we have a type with trivial initialization and no initializer, + // there's nothing to do. + if (!e->hasInitializer()) + return; + + cgm.errorNYI(e->getSourceRange(), "emitNewArrayInitializer"); +} + static void emitNewInitializer(CIRGenFunction &cgf, const CXXNewExpr *e, QualType elementType, mlir::Type elementTy, Address newPtr, mlir::Value numElements, mlir::Value allocSizeWithoutCookie) { assert(!cir::MissingFeatures::generateDebugInfo()); if (e->isArray()) { - cgf.cgm.errorNYI(e->getSourceRange(), "emitNewInitializer: array"); + cgf.emitNewArrayInitializer(e, elementType, elementTy, newPtr, numElements, + allocSizeWithoutCookie); } else if (const Expr *init = e->getInitializer()) { storeAnyExprIntoOneUnit(cgf, init, e->getAllocatedType(), newPtr, AggValueSlot::DoesNotOverlap); @@ -346,12 +463,6 @@ struct CallObjectDelete final : EHScopeStack::Cleanup { void emit(CIRGenFunction &cgf) override { cgf.emitDeleteCall(operatorDelete, ptr, elementType); } - - // This is a placeholder until EHCleanupScope is implemented. - size_t getSize() const override { - assert(!cir::MissingFeatures::ehCleanupScope()); - return sizeof(CallObjectDelete); - } }; } // namespace @@ -536,7 +647,14 @@ mlir::Value CIRGenFunction::emitCXXNewExpr(const CXXNewExpr *e) { if (allocSize != allocSizeWithoutCookie) cgm.errorNYI(e->getSourceRange(), "emitCXXNewExpr: array with cookies"); - mlir::Type elementTy = convertTypeForMem(allocType); + mlir::Type elementTy; + if (e->isArray()) { + // For array new, use the allocated type to handle multidimensional arrays + // correctly + elementTy = convertTypeForMem(e->getAllocatedType()); + } else { + elementTy = convertTypeForMem(allocType); + } Address result = builder.createElementBitCast(getLoc(e->getSourceRange()), allocation, elementTy); @@ -604,3 +722,43 @@ void CIRGenFunction::emitDeleteCall(const FunctionDecl *deleteFD, // Emit the call to delete. emitNewDeleteCall(*this, deleteFD, deleteFTy, deleteArgs); } + +mlir::Value CIRGenFunction::emitDynamicCast(Address thisAddr, + const CXXDynamicCastExpr *dce) { + mlir::Location loc = getLoc(dce->getSourceRange()); + + cgm.emitExplicitCastExprType(dce, this); + QualType destTy = dce->getTypeAsWritten(); + QualType srcTy = dce->getSubExpr()->getType(); + + // C++ [expr.dynamic.cast]p7: + // If T is "pointer to cv void," then the result is a pointer to the most + // derived object pointed to by v. + bool isDynCastToVoid = destTy->isVoidPointerType(); + bool isRefCast = destTy->isReferenceType(); + + QualType srcRecordTy; + QualType destRecordTy; + if (isDynCastToVoid) { + srcRecordTy = srcTy->getPointeeType(); + // No destRecordTy. + } else if (const PointerType *destPTy = destTy->getAs<PointerType>()) { + srcRecordTy = srcTy->castAs<PointerType>()->getPointeeType(); + destRecordTy = destPTy->getPointeeType(); + } else { + srcRecordTy = srcTy; + destRecordTy = destTy->castAs<ReferenceType>()->getPointeeType(); + } + + assert(srcRecordTy->isRecordType() && "source type must be a record type!"); + assert(!cir::MissingFeatures::emitTypeCheck()); + + if (dce->isAlwaysNull()) { + cgm.errorNYI(dce->getSourceRange(), "emitDynamicCastToNull"); + return {}; + } + + auto destCirTy = mlir::cast<cir::PointerType>(convertType(destTy)); + return cgm.getCXXABI().emitDynamicCast(*this, loc, srcRecordTy, destRecordTy, + destCirTy, isRefCast, thisAddr); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index e20a4fc..89e9ec4 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -118,6 +118,9 @@ class ConstantAggregateBuilder : private ConstantAggregateBuilderUtils { /// non-packed LLVM struct will give the correct layout. bool naturalLayout = true; + bool split(size_t index, CharUnits hint); + std::optional<size_t> splitAt(CharUnits pos); + static mlir::Attribute buildFrom(CIRGenModule &cgm, ArrayRef<Element> elems, CharUnits startOffset, CharUnits size, bool naturalLayout, mlir::Type desiredTy, @@ -137,6 +140,10 @@ public: /// Update or overwrite the bits starting at \p offsetInBits with \p bits. bool addBits(llvm::APInt bits, uint64_t offsetInBits, bool allowOverwrite); + /// Attempt to condense the value starting at \p offset to a constant of type + /// \p desiredTy. + void condense(CharUnits offset, mlir::Type desiredTy); + /// Produce a constant representing the entire accumulated value, ideally of /// the specified type. If \p allowOversized, the constant might be larger /// than implied by \p desiredTy (eg, if there is a flexible array member). @@ -176,6 +183,195 @@ bool ConstantAggregateBuilder::add(mlir::TypedAttr typedAttr, CharUnits offset, return false; } +bool ConstantAggregateBuilder::addBits(llvm::APInt bits, uint64_t offsetInBits, + bool allowOverwrite) { + const ASTContext &astContext = cgm.getASTContext(); + const uint64_t charWidth = astContext.getCharWidth(); + mlir::Type charTy = cgm.getBuilder().getUIntNTy(charWidth); + + // Offset of where we want the first bit to go within the bits of the + // current char. + unsigned offsetWithinChar = offsetInBits % charWidth; + + // We split bit-fields up into individual bytes. Walk over the bytes and + // update them. + for (CharUnits offsetInChars = + astContext.toCharUnitsFromBits(offsetInBits - offsetWithinChar); + /**/; ++offsetInChars) { + // Number of bits we want to fill in this char. + unsigned wantedBits = + std::min((uint64_t)bits.getBitWidth(), charWidth - offsetWithinChar); + + // Get a char containing the bits we want in the right places. The other + // bits have unspecified values. + llvm::APInt bitsThisChar = bits; + if (bitsThisChar.getBitWidth() < charWidth) + bitsThisChar = bitsThisChar.zext(charWidth); + if (cgm.getDataLayout().isBigEndian()) { + // Figure out how much to shift by. We may need to left-shift if we have + // less than one byte of Bits left. + int shift = bits.getBitWidth() - charWidth + offsetWithinChar; + if (shift > 0) + bitsThisChar.lshrInPlace(shift); + else if (shift < 0) + bitsThisChar = bitsThisChar.shl(-shift); + } else { + bitsThisChar = bitsThisChar.shl(offsetWithinChar); + } + if (bitsThisChar.getBitWidth() > charWidth) + bitsThisChar = bitsThisChar.trunc(charWidth); + + if (wantedBits == charWidth) { + // Got a full byte: just add it directly. + add(cir::IntAttr::get(charTy, bitsThisChar), offsetInChars, + allowOverwrite); + } else { + // Partial byte: update the existing integer if there is one. If we + // can't split out a 1-CharUnit range to update, then we can't add + // these bits and fail the entire constant emission. + std::optional<size_t> firstElemToUpdate = splitAt(offsetInChars); + if (!firstElemToUpdate) + return false; + std::optional<size_t> lastElemToUpdate = + splitAt(offsetInChars + CharUnits::One()); + if (!lastElemToUpdate) + return false; + assert(*lastElemToUpdate - *firstElemToUpdate < 2 && + "should have at most one element covering one byte"); + + // Figure out which bits we want and discard the rest. + llvm::APInt updateMask(charWidth, 0); + if (cgm.getDataLayout().isBigEndian()) + updateMask.setBits(charWidth - offsetWithinChar - wantedBits, + charWidth - offsetWithinChar); + else + updateMask.setBits(offsetWithinChar, offsetWithinChar + wantedBits); + bitsThisChar &= updateMask; + bool isNull = false; + if (*firstElemToUpdate < elements.size()) { + auto firstEltToUpdate = + mlir::dyn_cast<cir::IntAttr>(elements[*firstElemToUpdate].element); + isNull = firstEltToUpdate && firstEltToUpdate.isNullValue(); + } + + if (*firstElemToUpdate == *lastElemToUpdate || isNull) { + // All existing bits are either zero or undef. + add(cir::IntAttr::get(charTy, bitsThisChar), offsetInChars, + /*allowOverwrite*/ true); + } else { + cir::IntAttr ci = + mlir::dyn_cast<cir::IntAttr>(elements[*firstElemToUpdate].element); + // In order to perform a partial update, we need the existing bitwise + // value, which we can only extract for a constant int. + if (!ci) + return false; + // Because this is a 1-CharUnit range, the constant occupying it must + // be exactly one CharUnit wide. + assert(ci.getBitWidth() == charWidth && "splitAt failed"); + assert((!(ci.getValue() & updateMask) || allowOverwrite) && + "unexpectedly overwriting bitfield"); + bitsThisChar |= (ci.getValue() & ~updateMask); + elements[*firstElemToUpdate].element = + cir::IntAttr::get(charTy, bitsThisChar); + } + } + + // Stop if we've added all the bits. + if (wantedBits == bits.getBitWidth()) + break; + + // Remove the consumed bits from Bits. + if (!cgm.getDataLayout().isBigEndian()) + bits.lshrInPlace(wantedBits); + bits = bits.trunc(bits.getBitWidth() - wantedBits); + + // The remaining bits go at the start of the following bytes. + offsetWithinChar = 0; + } + + return true; +} + +/// Returns a position within elements such that all elements +/// before the returned index end before pos and all elements at or after +/// the returned index begin at or after pos. Splits elements as necessary +/// to ensure this. Returns std::nullopt if we find something we can't split. +std::optional<size_t> ConstantAggregateBuilder::splitAt(CharUnits pos) { + if (pos >= size) + return elements.size(); + + while (true) { + // Find the first element that starts after pos. + Element *iter = + llvm::upper_bound(elements, pos, [](CharUnits pos, const Element &elt) { + return pos < elt.offset; + }); + + if (iter == elements.begin()) + return 0; + + size_t index = iter - elements.begin() - 1; + const Element &elt = elements[index]; + + // If we already have an element starting at pos, we're done. + if (elt.offset == pos) + return index; + + // Check for overlap with the element that starts before pos. + CharUnits eltEnd = elt.offset + getSize(elt.element); + if (eltEnd <= pos) + return index + 1; + + // Try to decompose it into smaller constants. + if (!split(index, pos)) + return std::nullopt; + } +} + +/// Split the constant at index, if possible. Return true if we did. +/// Hint indicates the location at which we'd like to split, but may be +/// ignored. +bool ConstantAggregateBuilder::split(size_t index, CharUnits hint) { + cgm.errorNYI("split constant at index"); + return false; +} + +void ConstantAggregateBuilder::condense(CharUnits offset, + mlir::Type desiredTy) { + CharUnits desiredSize = getSize(desiredTy); + + std::optional<size_t> firstElemToReplace = splitAt(offset); + if (!firstElemToReplace) + return; + size_t first = *firstElemToReplace; + + std::optional<size_t> lastElemToReplace = splitAt(offset + desiredSize); + if (!lastElemToReplace) + return; + size_t last = *lastElemToReplace; + + size_t length = last - first; + if (length == 0) + return; + + if (length == 1 && elements[first].offset == offset && + getSize(elements[first].element) == desiredSize) { + cgm.errorNYI("re-wrapping single element records"); + return; + } + + // Build a new constant from the elements in the range. + SmallVector<Element> subElems(elements.begin() + first, + elements.begin() + last); + mlir::Attribute replacement = + buildFrom(cgm, subElems, offset, desiredSize, + /*naturalLayout=*/false, desiredTy, false); + + // Replace the range with the condensed constant. + Element newElt(mlir::cast<mlir::TypedAttr>(replacement), offset); + replace(elements, first, last, {newElt}); +} + mlir::Attribute ConstantAggregateBuilder::buildFrom(CIRGenModule &cgm, ArrayRef<Element> elems, CharUnits startOffset, CharUnits size, @@ -301,6 +497,29 @@ private: bool appendBytes(CharUnits fieldOffsetInChars, mlir::TypedAttr initCst, bool allowOverwrite = false); + bool appendBitField(const FieldDecl *field, uint64_t fieldOffset, + cir::IntAttr ci, bool allowOverwrite = false); + + /// Applies zero-initialization to padding bytes before and within a field. + /// \param layout The record layout containing field offset information. + /// \param fieldNo The field index in the record. + /// \param field The field declaration. + /// \param allowOverwrite Whether to allow overwriting existing values. + /// \param sizeSoFar The current size processed, updated by this function. + /// \param zeroFieldSize Set to true if the field has zero size. + /// \returns true on success, false if padding could not be applied. + bool applyZeroInitPadding(const ASTRecordLayout &layout, unsigned fieldNo, + const FieldDecl &field, bool allowOverwrite, + CharUnits &sizeSoFar, bool &zeroFieldSize); + + /// Applies zero-initialization to trailing padding bytes in a record. + /// \param layout The record layout containing size information. + /// \param allowOverwrite Whether to allow overwriting existing values. + /// \param sizeSoFar The current size processed. + /// \returns true on success, false if padding could not be applied. + bool applyZeroInitPadding(const ASTRecordLayout &layout, bool allowOverwrite, + CharUnits &sizeSoFar); + bool build(InitListExpr *ile, bool allowOverwrite); bool build(const APValue &val, const RecordDecl *rd, bool isPrimaryBase, const CXXRecordDecl *vTableClass, CharUnits baseOffset); @@ -325,6 +544,73 @@ bool ConstRecordBuilder::appendBytes(CharUnits fieldOffsetInChars, return builder.add(initCst, startOffset + fieldOffsetInChars, allowOverwrite); } +bool ConstRecordBuilder::appendBitField(const FieldDecl *field, + uint64_t fieldOffset, cir::IntAttr ci, + bool allowOverwrite) { + const CIRGenRecordLayout &rl = + cgm.getTypes().getCIRGenRecordLayout(field->getParent()); + const CIRGenBitFieldInfo &info = rl.getBitFieldInfo(field); + llvm::APInt fieldValue = ci.getValue(); + + // Promote the size of FieldValue if necessary + // FIXME: This should never occur, but currently it can because initializer + // constants are cast to bool, and because clang is not enforcing bitfield + // width limits. + if (info.size > fieldValue.getBitWidth()) + fieldValue = fieldValue.zext(info.size); + + // Truncate the size of FieldValue to the bit field size. + if (info.size < fieldValue.getBitWidth()) + fieldValue = fieldValue.trunc(info.size); + + return builder.addBits(fieldValue, + cgm.getASTContext().toBits(startOffset) + fieldOffset, + allowOverwrite); +} + +bool ConstRecordBuilder::applyZeroInitPadding( + const ASTRecordLayout &layout, unsigned fieldNo, const FieldDecl &field, + bool allowOverwrite, CharUnits &sizeSoFar, bool &zeroFieldSize) { + uint64_t startBitOffset = layout.getFieldOffset(fieldNo); + CharUnits startOffset = + cgm.getASTContext().toCharUnitsFromBits(startBitOffset); + if (sizeSoFar < startOffset) { + if (!appendBytes(sizeSoFar, computePadding(cgm, startOffset - sizeSoFar), + allowOverwrite)) + return false; + } + + if (!field.isBitField()) { + CharUnits fieldSize = + cgm.getASTContext().getTypeSizeInChars(field.getType()); + sizeSoFar = startOffset + fieldSize; + zeroFieldSize = fieldSize.isZero(); + } else { + const CIRGenRecordLayout &rl = + cgm.getTypes().getCIRGenRecordLayout(field.getParent()); + const CIRGenBitFieldInfo &info = rl.getBitFieldInfo(&field); + uint64_t endBitOffset = startBitOffset + info.size; + sizeSoFar = cgm.getASTContext().toCharUnitsFromBits(endBitOffset); + if (endBitOffset % cgm.getASTContext().getCharWidth() != 0) + sizeSoFar++; + zeroFieldSize = info.size == 0; + } + return true; +} + +bool ConstRecordBuilder::applyZeroInitPadding(const ASTRecordLayout &layout, + bool allowOverwrite, + CharUnits &sizeSoFar) { + CharUnits totalSize = layout.getSize(); + if (sizeSoFar < totalSize) { + if (!appendBytes(sizeSoFar, computePadding(cgm, totalSize - sizeSoFar), + allowOverwrite)) + return false; + } + sizeSoFar = totalSize; + return true; +} + bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { RecordDecl *rd = ile->getType() ->castAs<clang::RecordType>() @@ -339,11 +625,9 @@ bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { if (cxxrd->getNumBases()) return false; - if (cgm.shouldZeroInitPadding()) { - assert(!cir::MissingFeatures::recordZeroInitPadding()); - cgm.errorNYI(rd->getSourceRange(), "zero init padding"); - return false; - } + const bool zeroInitPadding = cgm.shouldZeroInitPadding(); + bool zeroFieldSize = false; + CharUnits sizeSoFar = CharUnits::Zero(); unsigned elementNo = 0; for (auto [index, field] : llvm::enumerate(rd->fields())) { @@ -373,7 +657,10 @@ bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { continue; } - assert(!cir::MissingFeatures::recordZeroInitPadding()); + if (zeroInitPadding && + !applyZeroInitPadding(layout, index, *field, allowOverwrite, sizeSoFar, + zeroFieldSize)) + return false; // When emitting a DesignatedInitUpdateExpr, a nested InitListExpr // represents additional overwriting of our current constant value, and not @@ -407,17 +694,19 @@ bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { } else { // Otherwise we have a bitfield. if (auto constInt = dyn_cast<cir::IntAttr>(eltInit)) { - assert(!cir::MissingFeatures::bitfields()); - cgm.errorNYI(field->getSourceRange(), "bitfields"); + if (!appendBitField(field, layout.getFieldOffset(index), constInt, + allowOverwrite)) + return false; + } else { + // We are trying to initialize a bitfield with a non-trivial constant, + // this must require run-time code. + return false; } - // We are trying to initialize a bitfield with a non-trivial constant, - // this must require run-time code. - return false; } } - assert(!cir::MissingFeatures::recordZeroInitPadding()); - return true; + return !zeroInitPadding || + applyZeroInitPadding(layout, allowOverwrite, sizeSoFar); } namespace { @@ -510,8 +799,16 @@ bool ConstRecordBuilder::build(const APValue &val, const RecordDecl *rd, if (field->hasAttr<NoUniqueAddressAttr>()) allowOverwrite = true; } else { - assert(!cir::MissingFeatures::bitfields()); - cgm.errorNYI(field->getSourceRange(), "bitfields"); + // Otherwise we have a bitfield. + if (auto constInt = dyn_cast<cir::IntAttr>(eltInit)) { + if (!appendBitField(field, layout.getFieldOffset(index) + offsetBits, + constInt, allowOverwrite)) + return false; + } else { + // We are trying to initialize a bitfield with a non-trivial constant, + // this must require run-time code. + return false; + } } } diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 5d3496a..637f9ef 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1893,7 +1893,34 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) { } return v; } + case CK_IntegralToPointer: { + mlir::Type destCIRTy = cgf.convertType(destTy); + mlir::Value src = Visit(const_cast<Expr *>(subExpr)); + + // Properly resize by casting to an int of the same size as the pointer. + // Clang's IntegralToPointer includes 'bool' as the source, but in CIR + // 'bool' is not an integral type. So check the source type to get the + // correct CIR conversion. + mlir::Type middleTy = cgf.cgm.getDataLayout().getIntPtrType(destCIRTy); + mlir::Value middleVal = builder.createCast( + subExpr->getType()->isBooleanType() ? cir::CastKind::bool_to_int + : cir::CastKind::integral, + src, middleTy); + + if (cgf.cgm.getCodeGenOpts().StrictVTablePointers) { + cgf.cgm.errorNYI(subExpr->getSourceRange(), + "IntegralToPointer: strict vtable pointers"); + return {}; + } + return builder.createIntToPtr(middleVal, destCIRTy); + } + + case CK_Dynamic: { + Address v = cgf.emitPointerWithAlignment(subExpr); + const auto *dce = cast<CXXDynamicCastExpr>(ce); + return cgf.emitDynamicCast(v, dce); + } case CK_ArrayToPointerDecay: return cgf.emitArrayToPointerDecay(subExpr).getPointer(); diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index 52fb0d7..7a774e0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -689,7 +689,9 @@ void CIRGenFunction::emitDestructorBody(FunctionArgList &args) { cgm.errorNYI(dtor->getSourceRange(), "function-try-block destructor"); assert(!cir::MissingFeatures::sanitizers()); - assert(!cir::MissingFeatures::dtorCleanups()); + + // Enter the epilogue cleanups. + RunCleanupsScope dtorEpilogue(*this); // If this is the complete variant, just invoke the base variant; // the epilogue will destruct the virtual bases. But we can't do @@ -708,7 +710,8 @@ void CIRGenFunction::emitDestructorBody(FunctionArgList &args) { assert((body || getTarget().getCXXABI().isMicrosoft()) && "can't emit a dtor without a body for non-Microsoft ABIs"); - assert(!cir::MissingFeatures::dtorCleanups()); + // Enter the cleanup scopes for virtual bases. + enterDtorCleanups(dtor, Dtor_Complete); if (!isTryBody) { QualType thisTy = dtor->getFunctionObjectParameterType(); @@ -723,7 +726,9 @@ void CIRGenFunction::emitDestructorBody(FunctionArgList &args) { case Dtor_Base: assert(body); - assert(!cir::MissingFeatures::dtorCleanups()); + // Enter the cleanup scopes for fields and non-virtual bases. + enterDtorCleanups(dtor, Dtor_Base); + assert(!cir::MissingFeatures::vtableInitialization()); if (isTryBody) { @@ -741,7 +746,8 @@ void CIRGenFunction::emitDestructorBody(FunctionArgList &args) { break; } - assert(!cir::MissingFeatures::dtorCleanups()); + // Jump out through the epilogue cleanups. + dtorEpilogue.forceCleanup(); // Exit the try if applicable. if (isTryBody) diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index cbc0f4a..7a606ee 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -556,6 +556,33 @@ public: cir::GlobalOp gv, cir::GetGlobalOp gvAddr); + /// Enter the cleanups necessary to complete the given phase of destruction + /// for a destructor. The end result should call destructors on members and + /// base classes in reverse order of their construction. + void enterDtorCleanups(const CXXDestructorDecl *dtor, CXXDtorType type); + + /// Determines whether an EH cleanup is required to destroy a type + /// with the given destruction kind. + /// TODO(cir): could be shared with Clang LLVM codegen + bool needsEHCleanup(QualType::DestructionKind kind) { + switch (kind) { + case QualType::DK_none: + return false; + case QualType::DK_cxx_destructor: + case QualType::DK_objc_weak_lifetime: + case QualType::DK_nontrivial_c_struct: + return getLangOpts().Exceptions; + case QualType::DK_objc_strong_lifetime: + return getLangOpts().Exceptions && + cgm.getCodeGenOpts().ObjCAutoRefCountExceptions; + } + llvm_unreachable("bad destruction kind"); + } + + CleanupKind getCleanupKind(QualType::DestructionKind kind) { + return needsEHCleanup(kind) ? NormalAndEHCleanup : NormalCleanup; + } + /// Set the address of a local variable. void setAddrOfLocalVar(const clang::VarDecl *vd, Address addr) { assert(!localDeclMap.count(vd) && "Decl already exists in LocalDeclMap!"); @@ -1090,6 +1117,8 @@ public: /// even if no aggregate location is provided. RValue emitAnyExprToTemp(const clang::Expr *e); + void emitAnyExprToExn(const Expr *e, Address addr); + void emitArrayDestroy(mlir::Value begin, mlir::Value numElements, QualType elementType, CharUnits elementAlign, Destroyer *destroyer); @@ -1252,6 +1281,11 @@ public: mlir::Value emitCXXNewExpr(const CXXNewExpr *e); + void emitNewArrayInitializer(const CXXNewExpr *E, QualType ElementType, + mlir::Type ElementTy, Address BeginPtr, + mlir::Value NumElements, + mlir::Value AllocSizeWithoutCookie); + RValue emitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *e, const CXXMethodDecl *md, ReturnValueSlot returnValue); @@ -1278,6 +1312,8 @@ public: mlir::LogicalResult emitDoStmt(const clang::DoStmt &s); + mlir::Value emitDynamicCast(Address thisAddr, const CXXDynamicCastExpr *dce); + /// Emit an expression as an initializer for an object (variable, field, etc.) /// at the given location. The expression is not necessarily the normal /// initializer for the object, and the address is not necessarily diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp index debea8af..9e490c6d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp @@ -59,7 +59,11 @@ public: void addImplicitStructorParams(CIRGenFunction &cgf, QualType &resTy, FunctionArgList ¶ms) override; - + mlir::Value getCXXDestructorImplicitParam(CIRGenFunction &cgf, + const CXXDestructorDecl *dd, + CXXDtorType type, + bool forVirtualBase, + bool delegating) override; void emitCXXConstructors(const clang::CXXConstructorDecl *d) override; void emitCXXDestructors(const clang::CXXDestructorDecl *d) override; void emitCXXStructor(clang::GlobalDecl gd) override; @@ -68,8 +72,11 @@ public: CXXDtorType type, bool forVirtualBase, bool delegating, Address thisAddr, QualType thisTy) override; + void registerGlobalDtor(const VarDecl *vd, cir::FuncOp dtor, + mlir::Value addr) override; void emitRethrow(CIRGenFunction &cgf, bool isNoReturn) override; + void emitThrow(CIRGenFunction &cgf, const CXXThrowExpr *e) override; bool useThunkForDtorVariant(const CXXDestructorDecl *dtor, CXXDtorType dt) const override { @@ -115,6 +122,16 @@ public: Address thisAddr, const CXXRecordDecl *classDecl, const CXXRecordDecl *baseClassDecl) override; + // The traditional clang CodeGen emits calls to `__dynamic_cast` directly into + // LLVM in the `emitDynamicCastCall` function. In CIR, `dynamic_cast` + // expressions are lowered to `cir.dyn_cast` ops instead of calls to runtime + // functions. So during CIRGen we don't need the `emitDynamicCastCall` + // function that clang CodeGen has. + mlir::Value emitDynamicCast(CIRGenFunction &cgf, mlir::Location loc, + QualType srcRecordTy, QualType destRecordTy, + cir::PointerType destCIRTy, bool isRefCast, + Address src) override; + /**************************** RTTI Uniqueness ******************************/ protected: /// Returns true if the ABI requires RTTI type_info objects to be unique @@ -1491,11 +1508,8 @@ void CIRGenItaniumCXXABI::emitDestructorCall( CIRGenFunction &cgf, const CXXDestructorDecl *dd, CXXDtorType type, bool forVirtualBase, bool delegating, Address thisAddr, QualType thisTy) { GlobalDecl gd(dd, type); - if (needsVTTParameter(gd)) { - cgm.errorNYI(dd->getSourceRange(), "emitDestructorCall: VTT"); - } - - mlir::Value vtt = nullptr; + mlir::Value vtt = + getCXXDestructorImplicitParam(cgf, dd, type, forVirtualBase, delegating); ASTContext &astContext = cgm.getASTContext(); QualType vttTy = astContext.getPointerType(astContext.VoidPtrTy); assert(!cir::MissingFeatures::appleKext()); @@ -1506,6 +1520,34 @@ void CIRGenItaniumCXXABI::emitDestructorCall( vttTy, nullptr); } +void CIRGenItaniumCXXABI::registerGlobalDtor(const VarDecl *vd, + cir::FuncOp dtor, + mlir::Value addr) { + if (vd->isNoDestroy(cgm.getASTContext())) + return; + + if (vd->getTLSKind()) { + cgm.errorNYI(vd->getSourceRange(), "registerGlobalDtor: TLS"); + return; + } + + // HLSL doesn't support atexit. + if (cgm.getLangOpts().HLSL) { + cgm.errorNYI(vd->getSourceRange(), "registerGlobalDtor: HLSL"); + return; + } + + // The default behavior is to use atexit. This is handled in lowering + // prepare. Nothing to be done for CIR here. +} + +mlir::Value CIRGenItaniumCXXABI::getCXXDestructorImplicitParam( + CIRGenFunction &cgf, const CXXDestructorDecl *dd, CXXDtorType type, + bool forVirtualBase, bool delegating) { + GlobalDecl gd(dd, type); + return cgf.getVTTParameter(gd, forVirtualBase, delegating); +} + // The idea here is creating a separate block for the throw with an // `UnreachableOp` as the terminator. So, we branch from the current block // to the throw block and create a block for the remaining operations. @@ -1544,6 +1586,59 @@ void CIRGenItaniumCXXABI::emitRethrow(CIRGenFunction &cgf, bool isNoReturn) { } } +void CIRGenItaniumCXXABI::emitThrow(CIRGenFunction &cgf, + const CXXThrowExpr *e) { + // This differs a bit from LLVM codegen, CIR has native operations for some + // cxa functions, and defers allocation size computation, always pass the dtor + // symbol, etc. CIRGen also does not use getAllocateExceptionFn / getThrowFn. + + // Now allocate the exception object. + CIRGenBuilderTy &builder = cgf.getBuilder(); + QualType clangThrowType = e->getSubExpr()->getType(); + cir::PointerType throwTy = + builder.getPointerTo(cgf.convertType(clangThrowType)); + uint64_t typeSize = + cgf.getContext().getTypeSizeInChars(clangThrowType).getQuantity(); + mlir::Location subExprLoc = cgf.getLoc(e->getSubExpr()->getSourceRange()); + + // Defer computing allocation size to some later lowering pass. + mlir::TypedValue<cir::PointerType> exceptionPtr = + cir::AllocExceptionOp::create(builder, subExprLoc, throwTy, + builder.getI64IntegerAttr(typeSize)) + .getAddr(); + + // Build expression and store its result into exceptionPtr. + CharUnits exnAlign = cgf.getContext().getExnObjectAlignment(); + cgf.emitAnyExprToExn(e->getSubExpr(), Address(exceptionPtr, exnAlign)); + + // Get the RTTI symbol address. + auto typeInfo = mlir::cast<cir::GlobalViewAttr>( + cgm.getAddrOfRTTIDescriptor(subExprLoc, clangThrowType, + /*forEH=*/true)); + assert(!typeInfo.getIndices() && "expected no indirection"); + + // The address of the destructor. + // + // Note: LLVM codegen already optimizes out the dtor if the + // type is a record with trivial dtor (by passing down a + // null dtor). In CIR, we forward this info and allow for + // Lowering pass to skip passing the trivial function. + // + if (const RecordType *recordTy = clangThrowType->getAs<RecordType>()) { + CXXRecordDecl *rec = + cast<CXXRecordDecl>(recordTy->getOriginalDecl()->getDefinition()); + assert(!cir::MissingFeatures::isTrivialCtorOrDtor()); + if (!rec->hasTrivialDestructor()) { + cgm.errorNYI("emitThrow: non-trivial destructor"); + return; + } + } + + // Now throw the exception. + mlir::Location loc = cgf.getLoc(e->getSourceRange()); + insertThrowAndSplit(builder, loc, exceptionPtr, typeInfo.getSymbol()); +} + CIRGenCXXABI *clang::CIRGen::CreateCIRGenItaniumCXXABI(CIRGenModule &cgm) { switch (cgm.getASTContext().getCXXABIKind()) { case TargetCXXABI::GenericItanium: @@ -1742,3 +1837,143 @@ mlir::Value CIRGenItaniumCXXABI::getVirtualBaseClassOffset( } return vbaseOffset; } + +static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) { + // Prototype: void __cxa_bad_cast(); + + // TODO(cir): set the calling convention of the runtime function. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cir::FuncType fnTy = + cgf.getBuilder().getFuncType({}, cgf.getBuilder().getVoidTy()); + return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast"); +} + +// TODO(cir): This could be shared with classic codegen. +static CharUnits computeOffsetHint(ASTContext &astContext, + const CXXRecordDecl *src, + const CXXRecordDecl *dst) { + CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true, + /*DetectVirtual=*/false); + + // If Dst is not derived from Src we can skip the whole computation below and + // return that Src is not a public base of Dst. Record all inheritance paths. + if (!dst->isDerivedFrom(src, paths)) + return CharUnits::fromQuantity(-2ULL); + + unsigned numPublicPaths = 0; + CharUnits offset; + + // Now walk all possible inheritance paths. + for (const CXXBasePath &path : paths) { + if (path.Access != AS_public) // Ignore non-public inheritance. + continue; + + ++numPublicPaths; + + for (const CXXBasePathElement &pathElement : path) { + // If the path contains a virtual base class we can't give any hint. + // -1: no hint. + if (pathElement.Base->isVirtual()) + return CharUnits::fromQuantity(-1ULL); + + if (numPublicPaths > 1) // Won't use offsets, skip computation. + continue; + + // Accumulate the base class offsets. + const ASTRecordLayout &L = + astContext.getASTRecordLayout(pathElement.Class); + offset += L.getBaseClassOffset( + pathElement.Base->getType()->getAsCXXRecordDecl()); + } + } + + // -2: Src is not a public base of Dst. + if (numPublicPaths == 0) + return CharUnits::fromQuantity(-2ULL); + + // -3: Src is a multiple public base type but never a virtual base type. + if (numPublicPaths > 1) + return CharUnits::fromQuantity(-3ULL); + + // Otherwise, the Src type is a unique public nonvirtual base type of Dst. + // Return the offset of Src from the origin of Dst. + return offset; +} + +static cir::FuncOp getItaniumDynamicCastFn(CIRGenFunction &cgf) { + // Prototype: + // void *__dynamic_cast(const void *sub, + // global_as const abi::__class_type_info *src, + // global_as const abi::__class_type_info *dst, + // std::ptrdiff_t src2dst_offset); + + mlir::Type voidPtrTy = cgf.getBuilder().getVoidPtrTy(); + mlir::Type rttiPtrTy = cgf.getBuilder().getUInt8PtrTy(); + mlir::Type ptrDiffTy = cgf.convertType(cgf.getContext().getPointerDiffType()); + + // TODO(cir): mark the function as nowind willreturn readonly. + assert(!cir::MissingFeatures::opFuncNoUnwind()); + assert(!cir::MissingFeatures::opFuncWillReturn()); + assert(!cir::MissingFeatures::opFuncReadOnly()); + + // TODO(cir): set the calling convention of the runtime function. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cir::FuncType FTy = cgf.getBuilder().getFuncType( + {voidPtrTy, rttiPtrTy, rttiPtrTy, ptrDiffTy}, voidPtrTy); + return cgf.cgm.createRuntimeFunction(FTy, "__dynamic_cast"); +} + +static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf, + mlir::Location loc, + QualType srcRecordTy, + QualType destRecordTy) { + auto srcRtti = mlir::cast<cir::GlobalViewAttr>( + cgf.cgm.getAddrOfRTTIDescriptor(loc, srcRecordTy)); + auto destRtti = mlir::cast<cir::GlobalViewAttr>( + cgf.cgm.getAddrOfRTTIDescriptor(loc, destRecordTy)); + + cir::FuncOp runtimeFuncOp = getItaniumDynamicCastFn(cgf); + cir::FuncOp badCastFuncOp = getBadCastFn(cgf); + auto runtimeFuncRef = mlir::FlatSymbolRefAttr::get(runtimeFuncOp); + auto badCastFuncRef = mlir::FlatSymbolRefAttr::get(badCastFuncOp); + + const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl(); + const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl(); + CharUnits offsetHint = computeOffsetHint(cgf.getContext(), srcDecl, destDecl); + + mlir::Type ptrdiffTy = cgf.convertType(cgf.getContext().getPointerDiffType()); + auto offsetHintAttr = cir::IntAttr::get(ptrdiffTy, offsetHint.getQuantity()); + + return cir::DynamicCastInfoAttr::get(srcRtti, destRtti, runtimeFuncRef, + badCastFuncRef, offsetHintAttr); +} + +mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf, + mlir::Location loc, + QualType srcRecordTy, + QualType destRecordTy, + cir::PointerType destCIRTy, + bool isRefCast, Address src) { + bool isCastToVoid = destRecordTy.isNull(); + assert((!isCastToVoid || !isRefCast) && "cannot cast to void reference"); + + if (isCastToVoid) { + cgm.errorNYI(loc, "emitDynamicCastToVoid"); + return {}; + } + + // If the destination is effectively final, the cast succeeds if and only + // if the dynamic type of the pointer is exactly the destination type. + if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() && + cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) { + cgm.errorNYI(loc, "emitExactDynamicCast"); + return {}; + } + + cir::DynamicCastInfoAttr castInfo = + emitDynamicCastInfo(cgf, loc, srcRecordTy, destRecordTy); + return cgf.getBuilder().createDynCast(loc, src.getPointer(), destCIRTy, + isRefCast, castInfo); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index 910c8a9..fe1ea56 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -2079,6 +2079,29 @@ CIRGenModule::createCIRBuiltinFunction(mlir::Location loc, StringRef name, return fnOp; } +cir::FuncOp CIRGenModule::createRuntimeFunction(cir::FuncType ty, + StringRef name, mlir::ArrayAttr, + [[maybe_unused]] bool isLocal, + bool assumeConvergent) { + if (assumeConvergent) + errorNYI("createRuntimeFunction: assumeConvergent"); + if (isLocal) + errorNYI("createRuntimeFunction: local"); + + cir::FuncOp entry = getOrCreateCIRFunction(name, ty, GlobalDecl(), + /*forVtable=*/false); + + if (entry) { + // TODO(cir): set the attributes of the function. + assert(!cir::MissingFeatures::setLLVMFunctionFEnvAttributes()); + assert(!cir::MissingFeatures::opFuncCallingConv()); + assert(!cir::MissingFeatures::opGlobalDLLImportExport()); + entry.setDSOLocal(true); + } + + return entry; +} + mlir::SymbolTable::Visibility CIRGenModule::getMLIRVisibility(cir::GlobalOp op) { // MLIR doesn't accept public symbols declarations (only diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.h b/clang/lib/CIR/CodeGen/CIRGenModule.h index c6a6681..f627bae 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.h +++ b/clang/lib/CIR/CodeGen/CIRGenModule.h @@ -480,6 +480,10 @@ public: cir::FuncType ty, const clang::FunctionDecl *fd); + cir::FuncOp createRuntimeFunction(cir::FuncType ty, llvm::StringRef name, + mlir::ArrayAttr = {}, bool isLocal = false, + bool assumeConvergent = false); + static constexpr const char *builtinCoroId = "__builtin_coro_id"; /// Given a builtin id for a function like "__builtin_fabsf", return a diff --git a/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp b/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp index a9af753..4cf2237 100644 --- a/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp @@ -87,7 +87,10 @@ CIRGenFunction::getOpenACCDataOperandInfo(const Expr *e) { if (const auto *section = dyn_cast<ArraySectionExpr>(curVarExpr)) { QualType baseTy = ArraySectionExpr::getBaseOriginalType( section->getBase()->IgnoreParenImpCasts()); - boundTypes.push_back(QualType(baseTy->getPointeeOrArrayElementType(), 0)); + if (auto *at = getContext().getAsArrayType(baseTy)) + boundTypes.push_back(at->getElementType()); + else + boundTypes.push_back(baseTy->getPointeeType()); } else { boundTypes.push_back(curVarExpr->getType()); } diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp index 94d856b..84f5977 100644 --- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp @@ -327,9 +327,40 @@ cir::GlobalLinkageKind CIRGenModule::getVTableLinkage(const CXXRecordDecl *rd) { llvm_unreachable("Should not have been asked to emit this"); } } + // -fapple-kext mode does not support weak linkage, so we must use + // internal linkage. + if (astContext.getLangOpts().AppleKext) + return cir::GlobalLinkageKind::InternalLinkage; + + auto discardableODRLinkage = cir::GlobalLinkageKind::LinkOnceODRLinkage; + auto nonDiscardableODRLinkage = cir::GlobalLinkageKind::WeakODRLinkage; + if (rd->hasAttr<DLLExportAttr>()) { + // Cannot discard exported vtables. + discardableODRLinkage = nonDiscardableODRLinkage; + } else if (rd->hasAttr<DLLImportAttr>()) { + // Imported vtables are available externally. + discardableODRLinkage = cir::GlobalLinkageKind::AvailableExternallyLinkage; + nonDiscardableODRLinkage = + cir::GlobalLinkageKind::AvailableExternallyLinkage; + } + + switch (rd->getTemplateSpecializationKind()) { + case TSK_Undeclared: + case TSK_ExplicitSpecialization: + case TSK_ImplicitInstantiation: + return discardableODRLinkage; + + case TSK_ExplicitInstantiationDeclaration: { + errorNYI(rd->getSourceRange(), + "getVTableLinkage: explicit instantiation declaration"); + return cir::GlobalLinkageKind::ExternalLinkage; + } + + case TSK_ExplicitInstantiationDefinition: + return nonDiscardableODRLinkage; + } - errorNYI(rd->getSourceRange(), "getVTableLinkage: no key function"); - return cir::GlobalLinkageKind::ExternalLinkage; + llvm_unreachable("Invalid TemplateSpecializationKind!"); } cir::GlobalOp CIRGenVTables::getAddrOfVTT(const CXXRecordDecl *rd) { diff --git a/clang/lib/CIR/CodeGen/EHScopeStack.h b/clang/lib/CIR/CodeGen/EHScopeStack.h index c87a6ef..67a72f5 100644 --- a/clang/lib/CIR/CodeGen/EHScopeStack.h +++ b/clang/lib/CIR/CodeGen/EHScopeStack.h @@ -108,9 +108,6 @@ public: /// // \param flags cleanup kind. virtual void emit(CIRGenFunction &cgf) = 0; - - // This is a placeholder until EHScope is implemented. - virtual size_t getSize() const = 0; }; private: @@ -175,6 +172,10 @@ public: return stable_iterator(endOfBuffer - startOfData); } + /// Turn a stable reference to a scope depth into a unstable pointer + /// to the EH stack. + iterator find(stable_iterator savePoint) const; + /// Create a stable reference to the bottom of the EH stack. static stable_iterator stable_end() { return stable_iterator(0); } }; diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index 3484c59..64ac970 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -473,6 +473,49 @@ LogicalResult cir::VTableAttr::verify( } //===----------------------------------------------------------------------===// +// DynamicCastInfoAtttr definitions +//===----------------------------------------------------------------------===// + +std::string DynamicCastInfoAttr::getAlias() const { + // The alias looks like: `dyn_cast_info_<src>_<dest>` + + std::string alias = "dyn_cast_info_"; + + alias.append(getSrcRtti().getSymbol().getValue()); + alias.push_back('_'); + alias.append(getDestRtti().getSymbol().getValue()); + + return alias; +} + +LogicalResult DynamicCastInfoAttr::verify( + function_ref<InFlightDiagnostic()> emitError, cir::GlobalViewAttr srcRtti, + cir::GlobalViewAttr destRtti, mlir::FlatSymbolRefAttr runtimeFunc, + mlir::FlatSymbolRefAttr badCastFunc, cir::IntAttr offsetHint) { + auto isRttiPtr = [](mlir::Type ty) { + // RTTI pointers are !cir.ptr<!u8i>. + + auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty); + if (!ptrTy) + return false; + + auto pointeeIntTy = mlir::dyn_cast<cir::IntType>(ptrTy.getPointee()); + if (!pointeeIntTy) + return false; + + return pointeeIntTy.isUnsigned() && pointeeIntTy.getWidth() == 8; + }; + + if (!isRttiPtr(srcRtti.getType())) + return emitError() << "srcRtti must be an RTTI pointer"; + + if (!isRttiPtr(destRtti.getType())) + return emitError() << "destRtti must be an RTTI pointer"; + + return success(); +} + +//===----------------------------------------------------------------------===// // CIR Dialect //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index cdd4e3c..5f88590 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -71,6 +71,10 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface { os << "bfi_" << bitfield.getName().str(); return AliasResult::FinalAlias; } + if (auto dynCastInfoAttr = mlir::dyn_cast<cir::DynamicCastInfoAttr>(attr)) { + os << dynCastInfoAttr.getAlias(); + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } }; diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index 2eeef81..706e54f 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -41,6 +41,16 @@ static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) { return fileName; } +/// Return the FuncOp called by `callOp`. +static cir::FuncOp getCalledFunction(cir::CallOp callOp) { + mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>( + callOp.getCallableForCallee()); + if (!sym) + return nullptr; + return dyn_cast_or_null<cir::FuncOp>( + mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + namespace { struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { LoweringPreparePass() = default; @@ -61,11 +71,20 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { /// Build a module init function that calls all the dynamic initializers. void buildCXXGlobalInitFunc(); + /// Materialize global ctor/dtor list + void buildGlobalCtorDtorList(); + cir::FuncOp buildRuntimeFunction( mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, cir::FuncType type, cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage); + cir::GlobalOp buildRuntimeVariable( + mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, + mlir::Type type, + cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage, + cir::VisibilityKind visibility = cir::VisibilityKind::Default); + /// /// AST related /// ----------- @@ -79,11 +98,33 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { llvm::StringMap<uint32_t> dynamicInitializerNames; llvm::SmallVector<cir::FuncOp> dynamicInitializers; + /// List of ctors and their priorities to be called before main() + llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList; + void setASTContext(clang::ASTContext *c) { astCtx = c; } }; } // namespace +cir::GlobalOp LoweringPreparePass::buildRuntimeVariable( + mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, + mlir::Type type, cir::GlobalLinkageKind linkage, + cir::VisibilityKind visibility) { + cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>( + mlir::SymbolTable::lookupNearestSymbolFrom( + mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name))); + if (!g) { + g = cir::GlobalOp::create(builder, loc, name, type); + g.setLinkageAttr( + cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); + mlir::SymbolTable::setSymbolVisibility( + g, mlir::SymbolTable::Visibility::Private); + g.setGlobalVisibilityAttr( + cir::VisibilityAttr::get(builder.getContext(), visibility)); + } + return g; +} + cir::FuncOp LoweringPreparePass::buildRuntimeFunction( mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, cir::FuncType type, cir::GlobalLinkageKind linkage) { @@ -634,7 +675,8 @@ LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) { // Create a variable initialization function. CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op); - auto fnType = cir::FuncType::get({}, builder.getVoidTy()); + cir::VoidType voidTy = builder.getVoidTy(); + auto fnType = cir::FuncType::get({}, voidTy); FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType, cir::GlobalLinkageKind::InternalLinkage); @@ -649,8 +691,57 @@ LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) { // Register the destructor call with __cxa_atexit mlir::Region &dtorRegion = op.getDtorRegion(); if (!dtorRegion.empty()) { - assert(!cir::MissingFeatures::opGlobalDtorLowering()); - llvm_unreachable("dtor region lowering is NYI"); + assert(!cir::MissingFeatures::astVarDeclInterface()); + assert(!cir::MissingFeatures::opGlobalThreadLocal()); + // Create a variable that binds the atexit to this shared object. + builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front()); + cir::GlobalOp handle = buildRuntimeVariable( + builder, "__dso_handle", op.getLoc(), builder.getI8Type(), + cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden); + + // Look for the destructor call in dtorBlock + mlir::Block &dtorBlock = dtorRegion.front(); + cir::CallOp dtorCall; + for (auto op : reverse(dtorBlock.getOps<cir::CallOp>())) { + dtorCall = op; + break; + } + assert(dtorCall && "Expected a dtor call"); + cir::FuncOp dtorFunc = getCalledFunction(dtorCall); + assert(dtorFunc && "Expected a dtor call"); + + // Create a runtime helper function: + // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d); + auto voidPtrTy = cir::PointerType::get(voidTy); + auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy); + auto voidFnPtrTy = cir::PointerType::get(voidFnTy); + auto handlePtrTy = cir::PointerType::get(handle.getSymType()); + auto fnAtExitType = + cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy); + const char *nameAtExit = "__cxa_atexit"; + cir::FuncOp fnAtExit = + buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType); + + // Replace the dtor call with a call to __cxa_atexit(&dtor, &var, + // &__dso_handle) + builder.setInsertionPointAfter(dtorCall); + mlir::Value args[3]; + auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType()); + // dtorPtrTy + args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy, + dtorFunc.getSymName()); + args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy, + cir::CastKind::bitcast, args[0]); + args[1] = + cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy, + cir::CastKind::bitcast, dtorCall.getArgOperand(0)); + args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy, + handle.getSymName()); + builder.createCallOp(dtorCall.getLoc(), fnAtExit, args); + dtorCall->erase(); + entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(), + dtorBlock.begin(), + std::prev(dtorBlock.end())); } // Replace cir.yield with cir.return @@ -660,11 +751,12 @@ LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) { mlir::Block &block = op.getCtorRegion().front(); yieldOp = &block.getOperations().back(); } else { - assert(!cir::MissingFeatures::opGlobalDtorLowering()); - llvm_unreachable("dtor region lowering is NYI"); + assert(!dtorRegion.empty()); + mlir::Block &block = dtorRegion.front(); + yieldOp = &block.getOperations().back(); } - assert(isa<YieldOp>(*yieldOp)); + assert(isa<cir::YieldOp>(*yieldOp)); cir::ReturnOp::create(builder, yieldOp->getLoc()); return f; } @@ -689,11 +781,39 @@ void LoweringPreparePass::lowerGlobalOp(GlobalOp op) { assert(!cir::MissingFeatures::opGlobalAnnotations()); } +template <typename AttributeTy> +static llvm::SmallVector<mlir::Attribute> +prepareCtorDtorAttrList(mlir::MLIRContext *context, + llvm::ArrayRef<std::pair<std::string, uint32_t>> list) { + llvm::SmallVector<mlir::Attribute> attrs; + for (const auto &[name, priority] : list) + attrs.push_back(AttributeTy::get(context, name, priority)); + return attrs; +} + +void LoweringPreparePass::buildGlobalCtorDtorList() { + if (!globalCtorList.empty()) { + llvm::SmallVector<mlir::Attribute> globalCtors = + prepareCtorDtorAttrList<cir::GlobalCtorAttr>(&getContext(), + globalCtorList); + + mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(), + mlir::ArrayAttr::get(&getContext(), globalCtors)); + } + + // We will eventual need to populate a global_dtor list, but that's not + // needed for globals with destructors. It will only be needed for functions + // that are marked as global destructors with an attribute. + assert(!cir::MissingFeatures::opGlobalDtorList()); +} + void LoweringPreparePass::buildCXXGlobalInitFunc() { if (dynamicInitializers.empty()) return; - assert(!cir::MissingFeatures::opGlobalCtorList()); + // TODO: handle globals with a user-specified initialzation priority. + // TODO: handle default priority more nicely. + assert(!cir::MissingFeatures::opGlobalCtorPriority()); SmallString<256> fnName; // Include the filename in the symbol name. Including "sub_" matches gcc @@ -722,6 +842,10 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() { builder.setInsertionPointToStart(f.addEntryBlock()); for (cir::FuncOp &f : dynamicInitializers) builder.createCallOp(f.getLoc(), f, {}); + // Add the global init function (not the individual ctor functions) to the + // global ctor list. + globalCtorList.emplace_back(fnName, + cir::GlobalCtorAttr::getDefaultPriority()); cir::ReturnOp::create(builder, f.getLoc()); } @@ -852,6 +976,7 @@ void LoweringPreparePass::runOnOperation() { runOnOp(o); buildCXXGlobalInitFunc(); + buildGlobalCtorDtorList(); } std::unique_ptr<Pass> mlir::createLoweringPreparePass() { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 3a3c631..a1ecfc7 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1771,9 +1771,13 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( } // Rewrite op. - rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( + auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()), alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes); + newOp.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get( + getContext(), lowerCIRVisibilityToLLVMVisibility( + op.getGlobalVisibilityAttr().getValue()))); + return mlir::success(); } @@ -2413,6 +2417,73 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, }); } +static void buildCtorDtorList( + mlir::ModuleOp module, StringRef globalXtorName, StringRef llvmXtorName, + llvm::function_ref<std::pair<StringRef, int>(mlir::Attribute)> createXtor) { + llvm::SmallVector<std::pair<StringRef, int>> globalXtors; + for (const mlir::NamedAttribute namedAttr : module->getAttrs()) { + if (namedAttr.getName() == globalXtorName) { + for (auto attr : mlir::cast<mlir::ArrayAttr>(namedAttr.getValue())) + globalXtors.emplace_back(createXtor(attr)); + break; + } + } + + if (globalXtors.empty()) + return; + + mlir::OpBuilder builder(module.getContext()); + builder.setInsertionPointToEnd(&module.getBodyRegion().back()); + + // Create a global array llvm.global_ctors with element type of + // struct { i32, ptr, ptr } + auto ctorPFTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + llvm::SmallVector<mlir::Type> ctorStructFields; + ctorStructFields.push_back(builder.getI32Type()); + ctorStructFields.push_back(ctorPFTy); + ctorStructFields.push_back(ctorPFTy); + + auto ctorStructTy = mlir::LLVM::LLVMStructType::getLiteral( + builder.getContext(), ctorStructFields); + auto ctorStructArrayTy = + mlir::LLVM::LLVMArrayType::get(ctorStructTy, globalXtors.size()); + + mlir::Location loc = module.getLoc(); + auto newGlobalOp = mlir::LLVM::GlobalOp::create( + builder, loc, ctorStructArrayTy, /*constant=*/false, + mlir::LLVM::Linkage::Appending, llvmXtorName, mlir::Attribute()); + + builder.createBlock(&newGlobalOp.getRegion()); + builder.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); + + mlir::Value result = + mlir::LLVM::UndefOp::create(builder, loc, ctorStructArrayTy); + + for (auto [index, fn] : llvm::enumerate(globalXtors)) { + mlir::Value structInit = + mlir::LLVM::UndefOp::create(builder, loc, ctorStructTy); + mlir::Value initPriority = mlir::LLVM::ConstantOp::create( + builder, loc, ctorStructFields[0], fn.second); + mlir::Value initFuncAddr = mlir::LLVM::AddressOfOp::create( + builder, loc, ctorStructFields[1], fn.first); + mlir::Value initAssociate = + mlir::LLVM::ZeroOp::create(builder, loc, ctorStructFields[2]); + // Literal zero makes the InsertValueOp::create ambiguous. + llvm::SmallVector<int64_t> zero{0}; + structInit = mlir::LLVM::InsertValueOp::create(builder, loc, structInit, + initPriority, zero); + structInit = mlir::LLVM::InsertValueOp::create(builder, loc, structInit, + initFuncAddr, 1); + // TODO: handle associated data for initializers. + structInit = mlir::LLVM::InsertValueOp::create(builder, loc, structInit, + initAssociate, 2); + result = mlir::LLVM::InsertValueOp::create(builder, loc, result, structInit, + index); + } + + builder.create<mlir::LLVM::ReturnOp>(loc, result); +} + // The applyPartialConversion function traverses blocks in the dominance order, // so it does not lower and operations that are not reachachable from the // operations passed in as arguments. Since we do need to lower such code in @@ -2519,6 +2590,15 @@ void ConvertCIRToLLVMPass::runOnOperation() { if (failed(applyPartialConversion(ops, target, std::move(patterns)))) signalPassFailure(); + + // Emit the llvm.global_ctors array. + buildCtorDtorList(module, cir::CIRDialect::getGlobalCtorsAttrName(), + "llvm.global_ctors", [](mlir::Attribute attr) { + auto ctorAttr = mlir::cast<cir::GlobalCtorAttr>(attr); + return std::make_pair(ctorAttr.getName(), + ctorAttr.getPriority()); + }); + assert(!cir::MissingFeatures::opGlobalDtorList()); } mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite( @@ -2581,22 +2661,69 @@ void createLLVMFuncOpIfNotExist(mlir::ConversionPatternRewriter &rewriter, mlir::LogicalResult CIRToLLVMThrowOpLowering::matchAndRewrite( cir::ThrowOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - if (op.rethrows()) { - auto voidTy = mlir::LLVM::LLVMVoidType::get(getContext()); - auto funcTy = - mlir::LLVM::LLVMFunctionType::get(getContext(), voidTy, {}, false); + mlir::Location loc = op.getLoc(); + auto voidTy = mlir::LLVM::LLVMVoidType::get(getContext()); - auto mlirModule = op->getParentOfType<mlir::ModuleOp>(); - rewriter.setInsertionPointToStart(&mlirModule.getBodyRegion().front()); + if (op.rethrows()) { + auto funcTy = mlir::LLVM::LLVMFunctionType::get(voidTy, {}); + // Get or create `declare void @__cxa_rethrow()` const llvm::StringRef functionName = "__cxa_rethrow"; createLLVMFuncOpIfNotExist(rewriter, op, functionName, funcTy); - rewriter.setInsertionPointAfter(op.getOperation()); - rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( - op, mlir::TypeRange{}, functionName, mlir::ValueRange{}); + auto cxaRethrow = mlir::LLVM::CallOp::create( + rewriter, loc, mlir::TypeRange{}, functionName); + + rewriter.replaceOp(op, cxaRethrow); + return mlir::success(); + } + + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto fnTy = mlir::LLVM::LLVMFunctionType::get( + voidTy, {llvmPtrTy, llvmPtrTy, llvmPtrTy}); + + // Get or create `declare void @__cxa_throw(ptr, ptr, ptr)` + const llvm::StringRef fnName = "__cxa_throw"; + createLLVMFuncOpIfNotExist(rewriter, op, fnName, fnTy); + + mlir::Value typeInfo = mlir::LLVM::AddressOfOp::create( + rewriter, loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + adaptor.getTypeInfoAttr()); + + mlir::Value dtor; + if (op.getDtor()) { + dtor = mlir::LLVM::AddressOfOp::create(rewriter, loc, llvmPtrTy, + adaptor.getDtorAttr()); + } else { + dtor = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPtrTy); } + auto cxaThrowCall = mlir::LLVM::CallOp::create( + rewriter, loc, mlir::TypeRange{}, fnName, + mlir::ValueRange{adaptor.getExceptionPtr(), typeInfo, dtor}); + + rewriter.replaceOp(op, cxaThrowCall); + return mlir::success(); +} + +mlir::LogicalResult CIRToLLVMAllocExceptionOpLowering::matchAndRewrite( + cir::AllocExceptionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Get or create `declare ptr @__cxa_allocate_exception(i64)` + StringRef fnName = "__cxa_allocate_exception"; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto int64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); + auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {int64Ty}); + + createLLVMFuncOpIfNotExist(rewriter, op, fnName, fnTy); + auto exceptionSize = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), + adaptor.getSizeAttr()); + + auto allocaExceptionCall = mlir::LLVM::CallOp::create( + rewriter, op.getLoc(), mlir::TypeRange{llvmPtrTy}, fnName, + mlir::ValueRange{exceptionSize}); + + rewriter.replaceOp(op, allocaExceptionCall); return mlir::success(); } diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp index 64f1917..2d95982 100644 --- a/clang/lib/CodeGen/BackendUtil.cpp +++ b/clang/lib/CodeGen/BackendUtil.cpp @@ -60,11 +60,13 @@ #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/HipStdPar/HipStdPar.h" #include "llvm/Transforms/IPO/EmbedBitcodePass.h" +#include "llvm/Transforms/IPO/InferFunctionAttrs.h" #include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" +#include "llvm/Transforms/Instrumentation/AllocToken.h" #include "llvm/Transforms/Instrumentation/BoundsChecking.h" #include "llvm/Transforms/Instrumentation/DataFlowSanitizer.h" #include "llvm/Transforms/Instrumentation/GCOVProfiler.h" @@ -232,6 +234,14 @@ public: }; } // namespace +static AllocTokenOptions getAllocTokenOptions(const CodeGenOptions &CGOpts) { + AllocTokenOptions Opts; + Opts.MaxTokens = CGOpts.AllocTokenMax; + Opts.Extended = CGOpts.SanitizeAllocTokenExtended; + Opts.FastABI = CGOpts.SanitizeAllocTokenFastABI; + return Opts; +} + static SanitizerCoverageOptions getSancovOptsFromCGOpts(const CodeGenOptions &CGOpts) { SanitizerCoverageOptions Opts; @@ -789,6 +799,16 @@ static void addSanitizers(const Triple &TargetTriple, MPM.addPass(DataFlowSanitizerPass(LangOpts.NoSanitizeFiles, PB.getVirtualFileSystemPtr())); } + + if (LangOpts.Sanitize.has(SanitizerKind::AllocToken)) { + if (Level == OptimizationLevel::O0) { + // The default pass builder only infers libcall function attrs when + // optimizing, so we insert it here because we need it for accurate + // memory allocation function detection. + MPM.addPass(InferFunctionAttrsPass()); + } + MPM.addPass(AllocTokenPass(getAllocTokenOptions(CodeGenOpts))); + } }; if (ClSanitizeOnOptimizerEarlyEP) { PB.registerOptimizerEarlyEPCallback( diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp index a931ce4..c5371e4 100644 --- a/clang/lib/CodeGen/CGCall.cpp +++ b/clang/lib/CodeGen/CGCall.cpp @@ -3018,8 +3018,7 @@ void CodeGenModule::ConstructAttributeList(StringRef Name, ArgNo = 0; if (AddedPotentialArgAccess && MemAttrForPtrArgs) { - llvm::FunctionType *FunctionType = FunctionType = - getTypes().GetFunctionType(FI); + llvm::FunctionType *FunctionType = getTypes().GetFunctionType(FI); for (CGFunctionInfo::const_arg_iterator I = FI.arg_begin(), E = FI.arg_end(); I != E; ++I, ++ArgNo) { diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp index fee6bc0..9fe9a13 100644 --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -787,7 +787,8 @@ void CGDebugInfo::CreateCompileUnit() { // Create new compile unit. TheCU = DBuilder.createCompileUnit( - LangTag, CUFile, CGOpts.EmitVersionIdentMetadata ? Producer : "", + llvm::DISourceLanguageName(LangTag), CUFile, + CGOpts.EmitVersionIdentMetadata ? Producer : "", CGOpts.OptimizationLevel != 0 || CGOpts.PrepareForLTO || CGOpts.PrepareForThinLTO, CGOpts.DwarfDebugFlags, RuntimeVers, CGOpts.SplitDwarfFile, EmissionKind, @@ -899,10 +900,13 @@ llvm::DIType *CGDebugInfo::CreateType(const BuiltinType *BT) { assert((BT->getKind() != BuiltinType::SveCount || Info.NumVectors == 1) && "Unsupported number of vectors for svcount_t"); - // Debuggers can't extract 1bit from a vector, so will display a - // bitpattern for predicates instead. unsigned NumElems = Info.EC.getKnownMinValue() * Info.NumVectors; - if (Info.ElementType == CGM.getContext().BoolTy) { + llvm::Metadata *BitStride = nullptr; + if (BT->getKind() == BuiltinType::SveBool) { + Info.ElementType = CGM.getContext().UnsignedCharTy; + BitStride = llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned( + llvm::Type::getInt64Ty(CGM.getLLVMContext()), 1)); + } else if (BT->getKind() == BuiltinType::SveCount) { NumElems /= 8; Info.ElementType = CGM.getContext().UnsignedCharTy; } @@ -928,7 +932,7 @@ llvm::DIType *CGDebugInfo::CreateType(const BuiltinType *BT) { getOrCreateType(Info.ElementType, TheCU->getFile()); auto Align = getTypeAlignIfRequired(BT, CGM.getContext()); return DBuilder.createVectorType(/*Size*/ 0, Align, ElemTy, - SubscriptArray); + SubscriptArray, BitStride); } // It doesn't make sense to generate debug info for PowerPC MMA vector types. // So we return a safe type here to avoid generating an error. @@ -1232,7 +1236,7 @@ llvm::DIType *CGDebugInfo::CreateType(const PointerType *Ty, /// \return whether a C++ mangling exists for the type defined by TD. static bool hasCXXMangling(const TagDecl *TD, llvm::DICompileUnit *TheCU) { - switch (TheCU->getSourceLanguage()) { + switch (TheCU->getSourceLanguage().getUnversionedName()) { case llvm::dwarf::DW_LANG_C_plus_plus: case llvm::dwarf::DW_LANG_C_plus_plus_11: case llvm::dwarf::DW_LANG_C_plus_plus_14: @@ -3211,8 +3215,8 @@ llvm::DIType *CGDebugInfo::CreateType(const ObjCInterfaceType *Ty, if (!ID) return nullptr; - auto RuntimeLang = - static_cast<llvm::dwarf::SourceLanguage>(TheCU->getSourceLanguage()); + auto RuntimeLang = static_cast<llvm::dwarf::SourceLanguage>( + TheCU->getSourceLanguage().getUnversionedName()); // Return a forward declaration if this type was imported from a clang module, // and this is not the compile unit with the implementation of the type (which @@ -3348,7 +3352,8 @@ llvm::DIType *CGDebugInfo::CreateTypeDefinition(const ObjCInterfaceType *Ty, ObjCInterfaceDecl *ID = Ty->getDecl(); llvm::DIFile *DefUnit = getOrCreateFile(ID->getLocation()); unsigned Line = getLineNumber(ID->getLocation()); - unsigned RuntimeLang = TheCU->getSourceLanguage(); + + unsigned RuntimeLang = TheCU->getSourceLanguage().getUnversionedName(); // Bit size, align and offset of the type. uint64_t Size = CGM.getContext().getTypeSize(Ty); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 9f30287..e8255b0 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -30,6 +30,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclObjC.h" #include "clang/AST/NSAPI.h" +#include "clang/AST/ParentMapContext.h" #include "clang/AST/StmtVisitor.h" #include "clang/Basic/Builtins.h" #include "clang/Basic/CodeGenOptions.h" @@ -1272,6 +1273,196 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound, EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index); } +static bool +typeContainsPointer(QualType T, + llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD, + bool &IncompleteType) { + QualType CanonicalType = T.getCanonicalType(); + if (CanonicalType->isPointerType()) + return true; // base case + + // Look through typedef chain to check for special types. + for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>(); + CurrentT = TT->getDecl()->getUnderlyingType()) { + const IdentifierInfo *II = TT->getDecl()->getIdentifier(); + // Special Case: Syntactically uintptr_t is not a pointer; semantically, + // however, very likely used as such. Therefore, classify uintptr_t as a + // pointer, too. + if (II && II->isStr("uintptr_t")) + return true; + } + + // The type is an array; check the element type. + if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType)) + return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType); + // The type is a struct, class, or union. + if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) { + if (!RD->isCompleteDefinition()) { + IncompleteType = true; + return false; + } + if (!VisitedRD.insert(RD).second) + return false; // already visited + // Check all fields. + for (const FieldDecl *Field : RD->fields()) { + if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType)) + return true; + } + // For C++ classes, also check base classes. + if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) { + // Polymorphic types require a vptr. + if (CXXRD->isDynamicClass()) + return true; + for (const CXXBaseSpecifier &Base : CXXRD->bases()) { + if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType)) + return true; + } + } + } + return false; +} + +void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) { + assert(SanOpts.has(SanitizerKind::AllocToken) && + "Only needed with -fsanitize=alloc-token"); + + llvm::MDBuilder MDB(getLLVMContext()); + + // Get unique type name. + PrintingPolicy Policy(CGM.getContext().getLangOpts()); + Policy.SuppressTagKeyword = true; + Policy.FullyQualifiedName = true; + SmallString<64> TypeName; + llvm::raw_svector_ostream TypeNameOS(TypeName); + AllocType.getCanonicalType().print(TypeNameOS, Policy); + auto *TypeNameMD = MDB.createString(TypeNameOS.str()); + + // Check if QualType contains a pointer. Implements a simple DFS to + // recursively check if a type contains a pointer type. + llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD; + bool IncompleteType = false; + const bool ContainsPtr = + typeContainsPointer(AllocType, VisitedRD, IncompleteType); + if (!ContainsPtr && IncompleteType) + return; + auto *ContainsPtrC = Builder.getInt1(ContainsPtr); + auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC); + + // Format: !{<type-name>, <contains-pointer>} + auto *MDN = + llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD}); + CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN); +} + +namespace { +/// Infer type from a simple sizeof expression. +QualType inferTypeFromSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) { + if (UET->getKind() == UETT_SizeOf) { + if (UET->isArgumentType()) + return UET->getArgumentTypeInfo()->getType(); + else + return UET->getArgumentExpr()->getType(); + } + } + return QualType(); +} + +/// Infer type from an arithmetic expression involving a sizeof. For example: +/// +/// malloc(sizeof(MyType) + padding); // infers 'MyType' +/// malloc(sizeof(MyType) * 32); // infers 'MyType' +/// malloc(32 * sizeof(MyType)); // infers 'MyType' +/// malloc(sizeof(MyType) << 1); // infers 'MyType' +/// ... +/// +/// More complex arithmetic expressions are supported, but are a heuristic, e.g. +/// when considering allocations for structs with flexible array members: +/// +/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray' +/// +QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + // The argument is a lone sizeof expression. + if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull()) + return T; + if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) { + // Argument is an arithmetic expression. Cover common arithmetic patterns + // involving sizeof. + switch (BO->getOpcode()) { + case BO_Add: + case BO_Div: + case BO_Mul: + case BO_Shl: + case BO_Shr: + case BO_Sub: + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS()); + !T.isNull()) + return T; + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS()); + !T.isNull()) + return T; + break; + default: + break; + } + } + return QualType(); +} + +/// If the expression E is a reference to a variable, infer the type from a +/// variable's initializer if it contains a sizeof. Beware, this is a heuristic +/// and ignores if a variable is later reassigned. For example: +/// +/// size_t my_size = sizeof(MyType); +/// void *x = malloc(my_size); // infers 'MyType' +/// +QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) { + if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { + if (const Expr *Init = VD->getInit()) + return inferPossibleTypeFromArithSizeofExpr(Init); + } + } + return QualType(); +} + +/// Deduces the allocated type by checking if the allocation call's result +/// is immediately used in a cast expression. For example: +/// +/// MyType *x = (MyType *)malloc(4096); // infers 'MyType' +/// +QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE, + const CastExpr *CastE) { + if (!CastE) + return QualType(); + QualType PtrType = CastE->getType(); + if (PtrType->isPointerType()) + return PtrType->getPointeeType(); + return QualType(); +} +} // end anonymous namespace + +void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) { + QualType AllocType; + // First check arguments. + for (const Expr *Arg : E->arguments()) { + AllocType = inferPossibleTypeFromArithSizeofExpr(Arg); + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg); + if (!AllocType.isNull()) + break; + } + // Then check later casts. + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromCastExpr(E, CurCast); + // Emit if we were able to infer the type. + if (!AllocType.isNull()) + EmitAllocToken(CB, AllocType); +} + CodeGenFunction::ComplexPairTy CodeGenFunction:: EmitComplexPrePostIncDec(const UnaryOperator *E, LValue LV, bool isInc, bool isPre) { @@ -5642,6 +5833,9 @@ LValue CodeGenFunction::EmitConditionalOperatorLValue( /// are permitted with aggregate result, including noop aggregate casts, and /// cast from scalar to union. LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { + auto RestoreCurCast = + llvm::make_scope_exit([this, Prev = CurCast] { CurCast = Prev; }); + CurCast = E; switch (E->getCastKind()) { case CK_ToVoid: case CK_BitCast: @@ -6587,16 +6781,24 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); - // Generate function declaration DISuprogram in order to be used - // in debug info about call sites. - if (CGDebugInfo *DI = getDebugInfo()) { - if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { + if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { + // Generate function declaration DISuprogram in order to be used + // in debug info about call sites. + if (CGDebugInfo *DI = getDebugInfo()) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } + if (CalleeDecl->hasAttr<RestrictAttr>() || + CalleeDecl->hasAttr<AllocSizeAttr>()) { + // Function has 'malloc' (aka. 'restrict') or 'alloc_size' attribute. + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(LocalCallOrInvoke, E); + } + } } if (CallOrInvoke) *CallOrInvoke = LocalCallOrInvoke; diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index c52526c..31ac266 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -1371,8 +1371,16 @@ RValue CodeGenFunction::EmitBuiltinNewDeleteCall(const FunctionProtoType *Type, for (auto *Decl : Ctx.getTranslationUnitDecl()->lookup(Name)) if (auto *FD = dyn_cast<FunctionDecl>(Decl)) - if (Ctx.hasSameType(FD->getType(), QualType(Type, 0))) - return EmitNewDeleteCall(*this, FD, Type, Args); + if (Ctx.hasSameType(FD->getType(), QualType(Type, 0))) { + RValue RV = EmitNewDeleteCall(*this, FD, Type, Args); + if (auto *CB = dyn_cast_if_present<llvm::CallBase>(RV.getScalarVal())) { + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(CB, TheCall); + } + } + return RV; + } llvm_unreachable("predeclared global operator new/delete is missing"); } @@ -1655,11 +1663,16 @@ llvm::Value *CodeGenFunction::EmitCXXNewExpr(const CXXNewExpr *E) { RValue RV = EmitNewDeleteCall(*this, allocator, allocatorType, allocatorArgs); - // Set !heapallocsite metadata on the call to operator new. - if (getDebugInfo()) - if (auto *newCall = dyn_cast<llvm::CallBase>(RV.getScalarVal())) - getDebugInfo()->addHeapAllocSiteMetadata(newCall, allocType, - E->getExprLoc()); + if (auto *newCall = dyn_cast<llvm::CallBase>(RV.getScalarVal())) { + if (auto *CGDI = getDebugInfo()) { + // Set !heapallocsite metadata on the call to operator new. + CGDI->addHeapAllocSiteMetadata(newCall, allocType, E->getExprLoc()); + } + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(newCall, allocType); + } + } // If this was a call to a global replaceable allocation function that does // not take an alignment argument, the allocator is known to produce diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 06d9d81..715160d 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -33,6 +33,7 @@ #include "clang/Basic/DiagnosticTrap.h" #include "clang/Basic/TargetInfo.h" #include "llvm/ADT/APFixedPoint.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/IR/Argument.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -2434,6 +2435,10 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal, // have to handle a more broad range of conversions than explicit casts, as they // handle things like function to ptr-to-function decay etc. Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { + auto RestoreCurCast = + llvm::make_scope_exit([this, Prev = CGF.CurCast] { CGF.CurCast = Prev; }); + CGF.CurCast = CE; + Expr *E = CE->getSubExpr(); QualType DestTy = CE->getType(); CastKind Kind = CE->getCastKind(); diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 6c0fc8d..4f2f5a76 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -352,6 +352,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, SmallVector<Value *> Args{OrderID, SpaceOp, RangeOp, IndexOp, Name}; return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args); } + case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: { + Value *MainHandle = EmitScalarExpr(E->getArg(0)); + if (!CGM.getTriple().isSPIRV()) + return MainHandle; + + llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType()); + Value *OrderID = EmitScalarExpr(E->getArg(1)); + Value *SpaceOp = EmitScalarExpr(E->getArg(2)); + llvm::Intrinsic::ID IntrinsicID = + llvm::Intrinsic::spv_resource_counterhandlefromimplicitbinding; + SmallVector<Value *> Args{MainHandle, OrderID, SpaceOp}; + return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args); + } case Builtin::BI__builtin_hlsl_resource_nonuniformindex: { Value *IndexOp = EmitScalarExpr(E->getArg(0)); llvm::Type *RetTy = ConvertType(E->getType()); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index ede1780..603cef9 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -145,19 +145,29 @@ static CXXMethodDecl *lookupResourceInitMethodAndSetupArgs( // explicit binding auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, Binding.getSlot()); Args.add(RValue::get(RegSlot), AST.UnsignedIntTy); - CreateMethod = lookupMethod(ResourceDecl, "__createFromBinding", SC_Static); + const char *Name = Binding.hasCounterImplicitOrderID() + ? "__createFromBindingWithImplicitCounter" + : "__createFromBinding"; + CreateMethod = lookupMethod(ResourceDecl, Name, SC_Static); } else { // implicit binding auto *OrderID = llvm::ConstantInt::get(CGM.IntTy, Binding.getImplicitOrderID()); Args.add(RValue::get(OrderID), AST.UnsignedIntTy); - CreateMethod = - lookupMethod(ResourceDecl, "__createFromImplicitBinding", SC_Static); + const char *Name = Binding.hasCounterImplicitOrderID() + ? "__createFromImplicitBindingWithImplicitCounter" + : "__createFromImplicitBinding"; + CreateMethod = lookupMethod(ResourceDecl, Name, SC_Static); } Args.add(RValue::get(Space), AST.UnsignedIntTy); Args.add(RValue::get(Range), AST.IntTy); Args.add(RValue::get(Index), AST.UnsignedIntTy); Args.add(RValue::get(NameStr), AST.getPointerType(AST.CharTy.withConst())); + if (Binding.hasCounterImplicitOrderID()) { + uint32_t CounterBinding = Binding.getCounterImplicitOrderID(); + auto *CounterOrderID = llvm::ConstantInt::get(CGM.IntTy, CounterBinding); + Args.add(RValue::get(CounterOrderID), AST.UnsignedIntTy); + } return CreateMethod; } diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp index 4272d8b..3613b6a 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -869,6 +869,8 @@ CGOpenMPRuntimeGPU::CGOpenMPRuntimeGPU(CodeGenModule &CGM) CGM.getLangOpts().OpenMPOffloadMandatory, /*HasRequiresReverseOffload*/ false, /*HasRequiresUnifiedAddress*/ false, hasRequiresUnifiedSharedMemory(), /*HasRequiresDynamicAllocators*/ false); + Config.setDefaultTargetAS( + CGM.getContext().getTargetInfo().getTargetAddressSpace(LangAS::Default)); OMPBuilder.setConfig(Config); if (!CGM.getLangOpts().OpenMPIsTargetDevice) @@ -1243,7 +1245,10 @@ void CGOpenMPRuntimeGPU::emitParallelCall( llvm::Value *ID = llvm::ConstantPointerNull::get(CGM.Int8PtrTy); if (WFn) ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy); - llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, CGM.Int8PtrTy); + llvm::Type *FnPtrTy = llvm::PointerType::get( + CGF.getLLVMContext(), CGM.getDataLayout().getProgramAddressSpace()); + + llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, FnPtrTy); // Create a private scope that will globalize the arguments // passed from the outside of the target region. diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h index 810d6aa..3a7ee54 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h @@ -163,12 +163,14 @@ public: SourceLocation Loc) override; // Currently unsupported on the device. + using CGOpenMPRuntime::emitMessageClause; llvm::Value *emitMessageClause(CodeGenFunction &CGF, const Expr *Message, SourceLocation Loc) override; // Currently unsupported on the device. - virtual llvm::Value *emitSeverityClause(OpenMPSeverityClauseKind Severity, - SourceLocation Loc) override; + using CGOpenMPRuntime::emitSeverityClause; + llvm::Value *emitSeverityClause(OpenMPSeverityClauseKind Severity, + SourceLocation Loc) override; /// Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32 /// global_tid, kmp_int32 num_threads) to generate code for 'num_threads' diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index b2fe917..acf8de4 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -846,6 +846,8 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy, Fn->addFnAttr(llvm::Attribute::SanitizeNumericalStability); if (SanOpts.hasOneOf(SanitizerKind::Memory | SanitizerKind::KernelMemory)) Fn->addFnAttr(llvm::Attribute::SanitizeMemory); + if (SanOpts.has(SanitizerKind::AllocToken)) + Fn->addFnAttr(llvm::Attribute::SanitizeAllocToken); } if (SanOpts.has(SanitizerKind::SafeStack)) Fn->addFnAttr(llvm::Attribute::SafeStack); diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 99de6e1..1f0be2d 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -346,6 +346,10 @@ public: QualType FnRetTy; llvm::Function *CurFn = nullptr; + /// If a cast expression is being visited, this holds the current cast's + /// expression. + const CastExpr *CurCast = nullptr; + /// Save Parameter Decl for coroutine. llvm::SmallVector<const ParmVarDecl *, 4> FnArgs; @@ -3348,6 +3352,12 @@ public: SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals, SanitizerHandler Handler); + /// Emit additional metadata used by the AllocToken instrumentation. + void EmitAllocToken(llvm::CallBase *CB, QualType AllocType); + /// Emit additional metadata used by the AllocToken instrumentation, + /// inferring the type from an allocation call expression. + void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E); + llvm::Value *GetCountedByFieldExprGEP(const Expr *Base, const FieldDecl *FD, const FieldDecl *CountDecl); diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp index 4aa6314..3f6d4e0 100644 --- a/clang/lib/CodeGen/Targets/SPIR.cpp +++ b/clang/lib/CodeGen/Targets/SPIR.cpp @@ -61,6 +61,9 @@ public: QualType SampledType, CodeGenModule &CGM) const; void setOCLKernelStubCallingConvention(const FunctionType *&FT) const override; + llvm::Constant *getNullPointer(const CodeGen::CodeGenModule &CGM, + llvm::PointerType *T, + QualType QT) const override; }; class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo { public: @@ -240,6 +243,29 @@ void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention( FT, FT->getExtInfo().withCallingConv(CC_SpirFunction)); } +// LLVM currently assumes a null pointer has the bit pattern 0, but some GPU +// targets use a non-zero encoding for null in certain address spaces. +// Because SPIR(-V) is a generic target and the bit pattern of null in +// non-generic AS is unspecified, materialize null in non-generic AS via an +// addrspacecast from null in generic AS. This allows later lowering to +// substitute the target's real sentinel value. +llvm::Constant * +CommonSPIRTargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM, + llvm::PointerType *PT, + QualType QT) const { + LangAS AS = QT->getUnqualifiedDesugaredType()->isNullPtrType() + ? LangAS::Default + : QT->getPointeeType().getAddressSpace(); + if (AS == LangAS::Default || AS == LangAS::opencl_generic) + return llvm::ConstantPointerNull::get(PT); + + auto &Ctx = CGM.getContext(); + auto NPT = llvm::PointerType::get( + PT->getContext(), Ctx.getTargetAddressSpace(LangAS::opencl_generic)); + return llvm::ConstantExpr::getAddrSpaceCast( + llvm::ConstantPointerNull::get(NPT), PT); +} + LangAS SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM, const VarDecl *D) const { diff --git a/clang/lib/Driver/Action.cpp b/clang/lib/Driver/Action.cpp index e19daa9..72a42a6 100644 --- a/clang/lib/Driver/Action.cpp +++ b/clang/lib/Driver/Action.cpp @@ -43,7 +43,7 @@ const char *Action::getClassName(ActionClass AC) { case OffloadUnbundlingJobClass: return "clang-offload-unbundler"; case OffloadPackagerJobClass: - return "clang-offload-packager"; + return "llvm-offload-binary"; case LinkerWrapperJobClass: return "clang-linker-wrapper"; case StaticLibJobClass: diff --git a/clang/lib/Driver/SanitizerArgs.cpp b/clang/lib/Driver/SanitizerArgs.cpp index 7ce1afe..5dd48f5 100644 --- a/clang/lib/Driver/SanitizerArgs.cpp +++ b/clang/lib/Driver/SanitizerArgs.cpp @@ -61,8 +61,9 @@ static const SanitizerMask RecoverableByDefault = SanitizerKind::ImplicitConversion | SanitizerKind::Nullability | SanitizerKind::FloatDivideByZero | SanitizerKind::ObjCCast | SanitizerKind::Vptr; -static const SanitizerMask Unrecoverable = - SanitizerKind::Unreachable | SanitizerKind::Return; +static const SanitizerMask Unrecoverable = SanitizerKind::Unreachable | + SanitizerKind::Return | + SanitizerKind::AllocToken; static const SanitizerMask AlwaysRecoverable = SanitizerKind::KernelAddress | SanitizerKind::KernelHWAddress | SanitizerKind::KCFI; @@ -84,7 +85,8 @@ static const SanitizerMask CFIClasses = static const SanitizerMask CompatibleWithMinimalRuntime = TrappingSupported | SanitizerKind::Scudo | SanitizerKind::ShadowCallStack | SanitizerKind::MemtagStack | SanitizerKind::MemtagHeap | - SanitizerKind::MemtagGlobals | SanitizerKind::KCFI; + SanitizerKind::MemtagGlobals | SanitizerKind::KCFI | + SanitizerKind::AllocToken; enum CoverageFeature { CoverageFunc = 1 << 0, @@ -203,6 +205,7 @@ static void addDefaultIgnorelists(const Driver &D, SanitizerMask Kinds, {"tysan_blacklist.txt", SanitizerKind::Type}, {"dfsan_abilist.txt", SanitizerKind::DataFlow}, {"cfi_ignorelist.txt", SanitizerKind::CFI}, + {"alloc_token_ignorelist.txt", SanitizerKind::AllocToken}, {"ubsan_ignorelist.txt", SanitizerKind::Undefined | SanitizerKind::Vptr | SanitizerKind::Integer | SanitizerKind::Nullability | @@ -650,7 +653,12 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, std::make_pair(SanitizerKind::KCFI, SanitizerKind::Function), std::make_pair(SanitizerKind::Realtime, SanitizerKind::Address | SanitizerKind::Thread | - SanitizerKind::Undefined | SanitizerKind::Memory)}; + SanitizerKind::Undefined | SanitizerKind::Memory), + std::make_pair(SanitizerKind::AllocToken, + SanitizerKind::Address | SanitizerKind::HWAddress | + SanitizerKind::KernelAddress | + SanitizerKind::KernelHWAddress | + SanitizerKind::Memory)}; // Enable toolchain specific default sanitizers if not explicitly disabled. SanitizerMask Default = TC.getDefaultSanitizers() & ~AllRemove; @@ -1159,6 +1167,15 @@ SanitizerArgs::SanitizerArgs(const ToolChain &TC, !TC.getTriple().isAndroid() && !TC.getTriple().isOSFuchsia(); } + if (AllAddedKinds & SanitizerKind::AllocToken) { + AllocTokenFastABI = Args.hasFlag( + options::OPT_fsanitize_alloc_token_fast_abi, + options::OPT_fno_sanitize_alloc_token_fast_abi, AllocTokenFastABI); + AllocTokenExtended = Args.hasFlag( + options::OPT_fsanitize_alloc_token_extended, + options::OPT_fno_sanitize_alloc_token_extended, AllocTokenExtended); + } + LinkRuntimes = Args.hasFlag(options::OPT_fsanitize_link_runtime, options::OPT_fno_sanitize_link_runtime, !Args.hasArg(options::OPT_r)); @@ -1527,6 +1544,12 @@ void SanitizerArgs::addArgs(const ToolChain &TC, const llvm::opt::ArgList &Args, Sanitizers.has(SanitizerKind::Address)) CmdArgs.push_back("-fno-assume-sane-operator-new"); + // Flags for -fsanitize=alloc-token. + if (AllocTokenFastABI) + CmdArgs.push_back("-fsanitize-alloc-token-fast-abi"); + if (AllocTokenExtended) + CmdArgs.push_back("-fsanitize-alloc-token-extended"); + // libFuzzer wants to intercept calls to certain library functions, so the // following -fno-builtin-* flags force the compiler to emit interposable // libcalls to these functions. Other sanitizers effectively do the same thing diff --git a/clang/lib/Driver/ToolChain.cpp b/clang/lib/Driver/ToolChain.cpp index a9041d2..3d5cac6 100644 --- a/clang/lib/Driver/ToolChain.cpp +++ b/clang/lib/Driver/ToolChain.cpp @@ -1623,7 +1623,8 @@ SanitizerMask ToolChain::getSupportedSanitizers() const { SanitizerKind::CFICastStrict | SanitizerKind::FloatDivideByZero | SanitizerKind::KCFI | SanitizerKind::UnsignedIntegerOverflow | SanitizerKind::UnsignedShiftBase | SanitizerKind::ImplicitConversion | - SanitizerKind::Nullability | SanitizerKind::LocalBounds; + SanitizerKind::Nullability | SanitizerKind::LocalBounds | + SanitizerKind::AllocToken; if (getTriple().getArch() == llvm::Triple::x86 || getTriple().getArch() == llvm::Triple::x86_64 || getTriple().getArch() == llvm::Triple::arm || diff --git a/clang/lib/Driver/ToolChains/Arch/AArch64.cpp b/clang/lib/Driver/ToolChains/Arch/AArch64.cpp index 98f5efb..eb5d542 100644 --- a/clang/lib/Driver/ToolChains/Arch/AArch64.cpp +++ b/clang/lib/Driver/ToolChains/Arch/AArch64.cpp @@ -57,6 +57,9 @@ std::string aarch64::getAArch64TargetCPU(const ArgList &Args, // iOS 26 only runs on apple-a12 and later CPUs. if (!Triple.isOSVersionLT(26)) return "apple-a12"; + // arm64 (non-e) iOS 18 only runs on apple-a10 and later CPUs. + if (!Triple.isOSVersionLT(18) && !Triple.isArm64e()) + return "apple-a10"; } if (Triple.isWatchOS()) { @@ -64,8 +67,8 @@ std::string aarch64::getAArch64TargetCPU(const ArgList &Args, // arm64_32/arm64e watchOS requires S4 before watchOS 26, S6 after. if (Triple.getArch() == llvm::Triple::aarch64_32 || Triple.isArm64e()) return Triple.isOSVersionLT(26) ? "apple-s4" : "apple-s6"; - // arm64 (non-e, non-32) watchOS comes later, and requires S6 anyway. - return "apple-s6"; + // arm64 (non-e, non-32) watchOS comes later, and requires S9 anyway. + return "apple-s9"; } if (Triple.isXROS()) { diff --git a/clang/lib/Driver/ToolChains/Arch/RISCV.cpp b/clang/lib/Driver/ToolChains/Arch/RISCV.cpp index 76dde0d..f2e79e7 100644 --- a/clang/lib/Driver/ToolChains/Arch/RISCV.cpp +++ b/clang/lib/Driver/ToolChains/Arch/RISCV.cpp @@ -49,11 +49,8 @@ static bool getArchFeatures(const Driver &D, StringRef Arch, return true; } -// Get features except standard extension feature -static void getRISCFeaturesFromMcpu(const Driver &D, const Arg *A, - const llvm::Triple &Triple, - StringRef Mcpu, - std::vector<StringRef> &Features) { +static bool isValidRISCVCPU(const Driver &D, const Arg *A, + const llvm::Triple &Triple, StringRef Mcpu) { bool Is64Bit = Triple.isRISCV64(); if (!llvm::RISCV::parseCPU(Mcpu, Is64Bit)) { // Try inverting Is64Bit in case the CPU is valid, but for the wrong target. @@ -63,7 +60,9 @@ static void getRISCFeaturesFromMcpu(const Driver &D, const Arg *A, else D.Diag(clang::diag::err_drv_unsupported_option_argument) << A->getSpelling() << Mcpu; + return false; } + return true; } void riscv::getRISCVTargetFeatures(const Driver &D, const llvm::Triple &Triple, @@ -84,7 +83,8 @@ void riscv::getRISCVTargetFeatures(const Driver &D, const llvm::Triple &Triple, if (CPU == "native") CPU = llvm::sys::getHostCPUName(); - getRISCFeaturesFromMcpu(D, A, Triple, CPU, Features); + if (!isValidRISCVCPU(D, A, Triple, CPU)) + return; if (llvm::RISCV::hasFastScalarUnalignedAccess(CPU)) CPUFastScalarUnaligned = true; diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp index 107b9ff..d326a81 100644 --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -7618,6 +7618,8 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA, // features enabled through -Xclang -target-feature flags. SanitizeArgs.addArgs(TC, Args, CmdArgs, InputType); + Args.AddLastArg(CmdArgs, options::OPT_falloc_token_max_EQ); + #if CLANG_ENABLE_CIR // Forward -mmlir arguments to to the MLIR option parser. for (const Arg *A : Args.filtered(options::OPT_mmlir)) { diff --git a/clang/lib/Driver/ToolChains/Clang.h b/clang/lib/Driver/ToolChains/Clang.h index c227895..9adad5c 100644 --- a/clang/lib/Driver/ToolChains/Clang.h +++ b/clang/lib/Driver/ToolChains/Clang.h @@ -163,7 +163,7 @@ public: class LLVM_LIBRARY_VISIBILITY OffloadPackager final : public Tool { public: OffloadPackager(const ToolChain &TC) - : Tool("Offload::Packager", "clang-offload-packager", TC) {} + : Tool("Offload::Packager", "llvm-offload-binary", TC) {} bool hasIntegratedCPP() const override { return false; } void ConstructJob(Compilation &C, const JobAction &JA, diff --git a/clang/lib/Driver/ToolChains/CommonArgs.cpp b/clang/lib/Driver/ToolChains/CommonArgs.cpp index 49ee53f..16cc1db 100644 --- a/clang/lib/Driver/ToolChains/CommonArgs.cpp +++ b/clang/lib/Driver/ToolChains/CommonArgs.cpp @@ -2231,7 +2231,7 @@ static unsigned ParseDebugDefaultVersion(const ToolChain &TC, return 0; unsigned Value = 0; - if (StringRef(A->getValue()).getAsInteger(10, Value) || Value > 5 || + if (StringRef(A->getValue()).getAsInteger(10, Value) || Value > 6 || Value < 2) TC.getDriver().Diag(diag::err_drv_invalid_int_value) << A->getAsString(Args) << A->getValue(); @@ -2244,13 +2244,14 @@ unsigned tools::DwarfVersionNum(StringRef ArgValue) { .Case("-gdwarf-3", 3) .Case("-gdwarf-4", 4) .Case("-gdwarf-5", 5) + .Case("-gdwarf-6", 6) .Default(0); } const Arg *tools::getDwarfNArg(const ArgList &Args) { return Args.getLastArg(options::OPT_gdwarf_2, options::OPT_gdwarf_3, options::OPT_gdwarf_4, options::OPT_gdwarf_5, - options::OPT_gdwarf); + options::OPT_gdwarf_6, options::OPT_gdwarf); } unsigned tools::getDwarfVersion(const ToolChain &TC, diff --git a/clang/lib/Driver/ToolChains/Darwin.cpp b/clang/lib/Driver/ToolChains/Darwin.cpp index 234683f..d2356eb 100644 --- a/clang/lib/Driver/ToolChains/Darwin.cpp +++ b/clang/lib/Driver/ToolChains/Darwin.cpp @@ -1609,7 +1609,12 @@ void DarwinClang::AddLinkRuntimeLibArgs(const ArgList &Args, if (Sanitize.needsFuzzer() && !Args.hasArg(options::OPT_dynamiclib)) { AddLinkSanitizerLibArgs(Args, CmdArgs, "fuzzer", /*shared=*/false); - // Libfuzzer is written in C++ and requires libcxx. + // Libfuzzer is written in C++ and requires libcxx. + // Since darwin::Linker::ConstructJob already adds -lc++ for clang++ + // by default if ShouldLinkCXXStdlib(Args), we only add the option if + // !ShouldLinkCXXStdlib(Args). This avoids duplicate library errors + // on Darwin. + if (!ShouldLinkCXXStdlib(Args)) AddCXXStdlibLibArgs(Args, CmdArgs); } if (Sanitize.needsStatsRt()) { diff --git a/clang/lib/Driver/ToolChains/HIPAMD.cpp b/clang/lib/Driver/ToolChains/HIPAMD.cpp index 5f3fbea..c0c8afe 100644 --- a/clang/lib/Driver/ToolChains/HIPAMD.cpp +++ b/clang/lib/Driver/ToolChains/HIPAMD.cpp @@ -168,9 +168,12 @@ void AMDGCN::Linker::constructLinkAndEmitSpirvCommand( const InputInfo &Output, const llvm::opt::ArgList &Args) const { assert(!Inputs.empty() && "Must have at least one input."); - constructLlvmLinkCommand(C, JA, Inputs, Output, Args); + std::string LinkedBCFilePrefix( + Twine(llvm::sys::path::stem(Output.getFilename()), "-linked").str()); + const char *LinkedBCFilePath = HIP::getTempFile(C, LinkedBCFilePrefix, "bc"); + InputInfo LinkedBCFile(&JA, LinkedBCFilePath, Output.getBaseInput()); - // Linked BC is now in Output + constructLlvmLinkCommand(C, JA, Inputs, LinkedBCFile, Args); // Emit SPIR-V binary. llvm::opt::ArgStringList TrArgs{ @@ -180,7 +183,7 @@ void AMDGCN::Linker::constructLinkAndEmitSpirvCommand( "--spirv-lower-const-expr", "--spirv-preserve-auxdata", "--spirv-debug-info-version=nonsemantic-shader-200"}; - SPIRV::constructTranslateCommand(C, *this, JA, Output, Output, TrArgs); + SPIRV::constructTranslateCommand(C, *this, JA, Output, LinkedBCFile, TrArgs); } // For amdgcn the inputs of the linker job are device bitcode and output is diff --git a/clang/lib/Driver/ToolChains/HIPSPV.cpp b/clang/lib/Driver/ToolChains/HIPSPV.cpp index 62bca04..bce7f46 100644 --- a/clang/lib/Driver/ToolChains/HIPSPV.cpp +++ b/clang/lib/Driver/ToolChains/HIPSPV.cpp @@ -22,17 +22,6 @@ using namespace clang::driver::tools; using namespace clang; using namespace llvm::opt; -// Convenience function for creating temporary file for both modes of -// isSaveTempsEnabled(). -static const char *getTempFile(Compilation &C, StringRef Prefix, - StringRef Extension) { - if (C.getDriver().isSaveTempsEnabled()) { - return C.getArgs().MakeArgString(Prefix + "." + Extension); - } - auto TmpFile = C.getDriver().GetTemporaryPath(Prefix, Extension); - return C.addTempFile(C.getArgs().MakeArgString(TmpFile)); -} - // Locates HIP pass plugin. static std::string findPassPlugin(const Driver &D, const llvm::opt::ArgList &Args) { @@ -65,7 +54,7 @@ void HIPSPV::Linker::constructLinkAndEmitSpirvCommand( assert(!Inputs.empty() && "Must have at least one input."); std::string Name = std::string(llvm::sys::path::stem(Output.getFilename())); - const char *TempFile = getTempFile(C, Name + "-link", "bc"); + const char *TempFile = HIP::getTempFile(C, Name + "-link", "bc"); // Link LLVM bitcode. ArgStringList LinkArgs{}; @@ -93,7 +82,7 @@ void HIPSPV::Linker::constructLinkAndEmitSpirvCommand( auto PassPluginPath = findPassPlugin(C.getDriver(), Args); if (!PassPluginPath.empty()) { const char *PassPathCStr = C.getArgs().MakeArgString(PassPluginPath); - const char *OptOutput = getTempFile(C, Name + "-lower", "bc"); + const char *OptOutput = HIP::getTempFile(C, Name + "-lower", "bc"); ArgStringList OptArgs{TempFile, "-load-pass-plugin", PassPathCStr, "-passes=hip-post-link-passes", "-o", OptOutput}; diff --git a/clang/lib/Driver/ToolChains/HIPUtility.cpp b/clang/lib/Driver/ToolChains/HIPUtility.cpp index cb061ff..732403e 100644 --- a/clang/lib/Driver/ToolChains/HIPUtility.cpp +++ b/clang/lib/Driver/ToolChains/HIPUtility.cpp @@ -472,3 +472,14 @@ void HIP::constructGenerateObjFileFromHIPFatBinary( D.getClangProgramPath(), ClangArgs, Inputs, Output, D.getPrependArg())); } + +// Convenience function for creating temporary file for both modes of +// isSaveTempsEnabled(). +const char *HIP::getTempFile(Compilation &C, StringRef Prefix, + StringRef Extension) { + if (C.getDriver().isSaveTempsEnabled()) { + return C.getArgs().MakeArgString(Prefix + "." + Extension); + } + auto TmpFile = C.getDriver().GetTemporaryPath(Prefix, Extension); + return C.addTempFile(C.getArgs().MakeArgString(TmpFile)); +} diff --git a/clang/lib/Driver/ToolChains/HIPUtility.h b/clang/lib/Driver/ToolChains/HIPUtility.h index 29e5a92..55c155e 100644 --- a/clang/lib/Driver/ToolChains/HIPUtility.h +++ b/clang/lib/Driver/ToolChains/HIPUtility.h @@ -16,6 +16,8 @@ namespace driver { namespace tools { namespace HIP { +const char *getTempFile(Compilation &C, StringRef Prefix, StringRef Extension); + // Construct command for creating HIP fatbin. void constructHIPFatbinCommand(Compilation &C, const JobAction &JA, StringRef OutputFileName, diff --git a/clang/lib/Driver/ToolChains/UEFI.cpp b/clang/lib/Driver/ToolChains/UEFI.cpp index 75adbf1..d2be147 100644 --- a/clang/lib/Driver/ToolChains/UEFI.cpp +++ b/clang/lib/Driver/ToolChains/UEFI.cpp @@ -24,7 +24,9 @@ using namespace clang; using namespace llvm::opt; UEFI::UEFI(const Driver &D, const llvm::Triple &Triple, const ArgList &Args) - : ToolChain(D, Triple, Args) {} + : ToolChain(D, Triple, Args) { + getProgramPaths().push_back(getDriver().Dir); +} Tool *UEFI::buildLinker() const { return new tools::uefi::Linker(*this); } diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp index 50fd50a..292adce 100644 --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -1833,6 +1833,10 @@ void CompilerInvocationBase::GenerateCodeGenArgs(const CodeGenOptions &Opts, serializeSanitizerKinds(Opts.SanitizeAnnotateDebugInfo)) GenerateArg(Consumer, OPT_fsanitize_annotate_debug_info_EQ, Sanitizer); + if (Opts.AllocTokenMax) + GenerateArg(Consumer, OPT_falloc_token_max_EQ, + std::to_string(*Opts.AllocTokenMax)); + if (!Opts.EmitVersionIdentMetadata) GenerateArg(Consumer, OPT_Qn); @@ -2346,6 +2350,15 @@ bool CompilerInvocation::ParseCodeGenArgs(CodeGenOptions &Opts, ArgList &Args, } } + if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_max_EQ)) { + StringRef S = Arg->getValue(); + uint64_t Value = 0; + if (S.getAsInteger(0, Value)) + Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S; + else + Opts.AllocTokenMax = Value; + } + Opts.EmitVersionIdentMetadata = Args.hasFlag(OPT_Qy, OPT_Qn, true); if (!LangOpts->CUDAIsDevice) diff --git a/clang/lib/Frontend/InitPreprocessor.cpp b/clang/lib/Frontend/InitPreprocessor.cpp index 877ab02..b899fb9 100644 --- a/clang/lib/Frontend/InitPreprocessor.cpp +++ b/clang/lib/Frontend/InitPreprocessor.cpp @@ -1530,6 +1530,8 @@ static void InitializePredefinedMacros(const TargetInfo &TI, Builder.defineMacro("__SANITIZE_HWADDRESS__"); if (LangOpts.Sanitize.has(SanitizerKind::Thread)) Builder.defineMacro("__SANITIZE_THREAD__"); + if (LangOpts.Sanitize.has(SanitizerKind::AllocToken)) + Builder.defineMacro("__SANITIZE_ALLOC_TOKEN__"); // Target OS macro definitions. if (PPOpts.DefineTargetOSMacros) { diff --git a/clang/lib/Headers/avx2intrin.h b/clang/lib/Headers/avx2intrin.h index 31759c5..4aaca2d 100644 --- a/clang/lib/Headers/avx2intrin.h +++ b/clang/lib/Headers/avx2intrin.h @@ -1035,10 +1035,9 @@ _mm256_hsubs_epi16(__m256i __a, __m256i __b) /// \param __b /// A 256-bit vector containing one of the source operands. /// \returns A 256-bit vector of [16 x i16] containing the result. -static __inline__ __m256i __DEFAULT_FN_ATTRS256 -_mm256_maddubs_epi16(__m256i __a, __m256i __b) -{ - return (__m256i)__builtin_ia32_pmaddubsw256((__v32qi)__a, (__v32qi)__b); +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR +_mm256_maddubs_epi16(__m256i __a, __m256i __b) { + return (__m256i)__builtin_ia32_pmaddubsw256((__v32qi)__a, (__v32qi)__b); } /// Multiplies corresponding 16-bit elements of two 256-bit vectors of @@ -1067,9 +1066,8 @@ _mm256_maddubs_epi16(__m256i __a, __m256i __b) /// \param __b /// A 256-bit vector of [16 x i16] containing one of the source operands. /// \returns A 256-bit vector of [8 x i32] containing the result. -static __inline__ __m256i __DEFAULT_FN_ATTRS256 -_mm256_madd_epi16(__m256i __a, __m256i __b) -{ +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR +_mm256_madd_epi16(__m256i __a, __m256i __b) { return (__m256i)__builtin_ia32_pmaddwd256((__v16hi)__a, (__v16hi)__b); } diff --git a/clang/lib/Headers/avx512bwintrin.h b/clang/lib/Headers/avx512bwintrin.h index c36bd81..473fe94 100644 --- a/clang/lib/Headers/avx512bwintrin.h +++ b/clang/lib/Headers/avx512bwintrin.h @@ -1064,12 +1064,12 @@ _mm512_maskz_mulhi_epu16(__mmask32 __U, __m512i __A, __m512i __B) { (__v32hi)_mm512_setzero_si512()); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_maddubs_epi16(__m512i __X, __m512i __Y) { return (__m512i)__builtin_ia32_pmaddubsw512((__v64qi)__X, (__v64qi)__Y); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_mask_maddubs_epi16(__m512i __W, __mmask32 __U, __m512i __X, __m512i __Y) { return (__m512i)__builtin_ia32_selectw_512((__mmask32) __U, @@ -1077,26 +1077,26 @@ _mm512_mask_maddubs_epi16(__m512i __W, __mmask32 __U, __m512i __X, (__v32hi)__W); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_maskz_maddubs_epi16(__mmask32 __U, __m512i __X, __m512i __Y) { return (__m512i)__builtin_ia32_selectw_512((__mmask32) __U, (__v32hi)_mm512_maddubs_epi16(__X, __Y), (__v32hi)_mm512_setzero_si512()); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_madd_epi16(__m512i __A, __m512i __B) { return (__m512i)__builtin_ia32_pmaddwd512((__v32hi)__A, (__v32hi)__B); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_mask_madd_epi16(__m512i __W, __mmask16 __U, __m512i __A, __m512i __B) { return (__m512i)__builtin_ia32_selectd_512((__mmask16)__U, (__v16si)_mm512_madd_epi16(__A, __B), (__v16si)__W); } -static __inline__ __m512i __DEFAULT_FN_ATTRS512 +static __inline__ __m512i __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_maskz_madd_epi16(__mmask16 __U, __m512i __A, __m512i __B) { return (__m512i)__builtin_ia32_selectd_512((__mmask16)__U, (__v16si)_mm512_madd_epi16(__A, __B), diff --git a/clang/lib/Headers/avx512fp16intrin.h b/clang/lib/Headers/avx512fp16intrin.h index d951ba0..142cc07 100644 --- a/clang/lib/Headers/avx512fp16intrin.h +++ b/clang/lib/Headers/avx512fp16intrin.h @@ -112,7 +112,7 @@ static __inline__ __m512h __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_setr_ph( e9, e8, e7, e6, e5, e4, e3, e2, e1, e0); } -static __inline __m512h __DEFAULT_FN_ATTRS512 +static __inline __m512h __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_set1_pch(_Float16 _Complex __h) { return (__m512h)_mm512_set1_ps(__builtin_bit_cast(float, __h)); } @@ -193,17 +193,17 @@ _mm512_castsi512_ph(__m512i __a) { return (__m512h)__a; } -static __inline__ __m128h __DEFAULT_FN_ATTRS256 +static __inline__ __m128h __DEFAULT_FN_ATTRS256_CONSTEXPR _mm256_castph256_ph128(__m256h __a) { return __builtin_shufflevector(__a, __a, 0, 1, 2, 3, 4, 5, 6, 7); } -static __inline__ __m128h __DEFAULT_FN_ATTRS512 +static __inline__ __m128h __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_castph512_ph128(__m512h __a) { return __builtin_shufflevector(__a, __a, 0, 1, 2, 3, 4, 5, 6, 7); } -static __inline__ __m256h __DEFAULT_FN_ATTRS512 +static __inline__ __m256h __DEFAULT_FN_ATTRS512_CONSTEXPR _mm512_castph512_ph256(__m512h __a) { return __builtin_shufflevector(__a, __a, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); diff --git a/clang/lib/Headers/avx512vlbwintrin.h b/clang/lib/Headers/avx512vlbwintrin.h index 5e6daa8..81e4cbb9 100644 --- a/clang/lib/Headers/avx512vlbwintrin.h +++ b/clang/lib/Headers/avx512vlbwintrin.h @@ -1295,21 +1295,21 @@ _mm256_maskz_permutex2var_epi16 (__mmask16 __U, __m256i __A, __m256i __I, (__v16hi)_mm256_setzero_si256()); } -static __inline__ __m128i __DEFAULT_FN_ATTRS128 +static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR _mm_mask_maddubs_epi16(__m128i __W, __mmask8 __U, __m128i __X, __m128i __Y) { return (__m128i)__builtin_ia32_selectw_128((__mmask8)__U, (__v8hi)_mm_maddubs_epi16(__X, __Y), (__v8hi)__W); } -static __inline__ __m128i __DEFAULT_FN_ATTRS128 +static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR _mm_maskz_maddubs_epi16(__mmask8 __U, __m128i __X, __m128i __Y) { return (__m128i)__builtin_ia32_selectw_128((__mmask8)__U, (__v8hi)_mm_maddubs_epi16(__X, __Y), (__v8hi)_mm_setzero_si128()); } -static __inline__ __m256i __DEFAULT_FN_ATTRS256 +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR _mm256_mask_maddubs_epi16(__m256i __W, __mmask16 __U, __m256i __X, __m256i __Y) { return (__m256i)__builtin_ia32_selectw_256((__mmask16)__U, @@ -1317,35 +1317,35 @@ _mm256_mask_maddubs_epi16(__m256i __W, __mmask16 __U, __m256i __X, (__v16hi)__W); } -static __inline__ __m256i __DEFAULT_FN_ATTRS256 +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR _mm256_maskz_maddubs_epi16(__mmask16 __U, __m256i __X, __m256i __Y) { return (__m256i)__builtin_ia32_selectw_256((__mmask16)__U, (__v16hi)_mm256_maddubs_epi16(__X, __Y), (__v16hi)_mm256_setzero_si256()); } -static __inline__ __m128i __DEFAULT_FN_ATTRS128 +static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR _mm_mask_madd_epi16(__m128i __W, __mmask8 __U, __m128i __A, __m128i __B) { return (__m128i)__builtin_ia32_selectd_128((__mmask8)__U, (__v4si)_mm_madd_epi16(__A, __B), (__v4si)__W); } -static __inline__ __m128i __DEFAULT_FN_ATTRS128 +static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR _mm_maskz_madd_epi16(__mmask8 __U, __m128i __A, __m128i __B) { return (__m128i)__builtin_ia32_selectd_128((__mmask8)__U, (__v4si)_mm_madd_epi16(__A, __B), (__v4si)_mm_setzero_si128()); } -static __inline__ __m256i __DEFAULT_FN_ATTRS256 +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR _mm256_mask_madd_epi16(__m256i __W, __mmask8 __U, __m256i __A, __m256i __B) { return (__m256i)__builtin_ia32_selectd_256((__mmask8)__U, (__v8si)_mm256_madd_epi16(__A, __B), (__v8si)__W); } -static __inline__ __m256i __DEFAULT_FN_ATTRS256 +static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR _mm256_maskz_madd_epi16(__mmask8 __U, __m256i __A, __m256i __B) { return (__m256i)__builtin_ia32_selectd_256((__mmask8)__U, (__v8si)_mm256_madd_epi16(__A, __B), diff --git a/clang/lib/Headers/avx512vlfp16intrin.h b/clang/lib/Headers/avx512vlfp16intrin.h index c0bcc08..5b2b3f0 100644 --- a/clang/lib/Headers/avx512vlfp16intrin.h +++ b/clang/lib/Headers/avx512vlfp16intrin.h @@ -34,11 +34,13 @@ #define __DEFAULT_FN_ATTRS128_CONSTEXPR __DEFAULT_FN_ATTRS128 #endif -static __inline__ _Float16 __DEFAULT_FN_ATTRS128 _mm_cvtsh_h(__m128h __a) { +static __inline__ _Float16 __DEFAULT_FN_ATTRS128_CONSTEXPR +_mm_cvtsh_h(__m128h __a) { return __a[0]; } -static __inline__ _Float16 __DEFAULT_FN_ATTRS256 _mm256_cvtsh_h(__m256h __a) { +static __inline__ _Float16 __DEFAULT_FN_ATTRS256_CONSTEXPR +_mm256_cvtsh_h(__m256h __a) { return __a[0]; } diff --git a/clang/lib/Headers/emmintrin.h b/clang/lib/Headers/emmintrin.h index 6597e7e..454e9a2 100644 --- a/clang/lib/Headers/emmintrin.h +++ b/clang/lib/Headers/emmintrin.h @@ -2290,8 +2290,8 @@ _mm_avg_epu16(__m128i __a, __m128i __b) { /// A 128-bit signed [8 x i16] vector. /// \returns A 128-bit signed [4 x i32] vector containing the sums of products /// of both parameters. -static __inline__ __m128i __DEFAULT_FN_ATTRS _mm_madd_epi16(__m128i __a, - __m128i __b) { +static __inline__ __m128i __DEFAULT_FN_ATTRS_CONSTEXPR +_mm_madd_epi16(__m128i __a, __m128i __b) { return (__m128i)__builtin_ia32_pmaddwd128((__v8hi)__a, (__v8hi)__b); } diff --git a/clang/lib/Headers/mmintrin.h b/clang/lib/Headers/mmintrin.h index 5f61753..aca78e6 100644 --- a/clang/lib/Headers/mmintrin.h +++ b/clang/lib/Headers/mmintrin.h @@ -679,11 +679,10 @@ _mm_subs_pu16(__m64 __m1, __m64 __m2) { /// A 64-bit integer vector of [4 x i16]. /// \returns A 64-bit integer vector of [2 x i32] containing the sums of /// products of both parameters. -static __inline__ __m64 __DEFAULT_FN_ATTRS_SSE2 -_mm_madd_pi16(__m64 __m1, __m64 __m2) -{ - return __trunc64(__builtin_ia32_pmaddwd128((__v8hi)__anyext128(__m1), - (__v8hi)__anyext128(__m2))); +static __inline__ __m64 __DEFAULT_FN_ATTRS_SSE2_CONSTEXPR +_mm_madd_pi16(__m64 __m1, __m64 __m2) { + return __trunc64(__builtin_ia32_pmaddwd128((__v8hi)__zext128(__m1), + (__v8hi)__zext128(__m2))); } /// Multiplies each 16-bit signed integer element of the first 64-bit diff --git a/clang/lib/Headers/tmmintrin.h b/clang/lib/Headers/tmmintrin.h index d40f0c5..3fc9f98 100644 --- a/clang/lib/Headers/tmmintrin.h +++ b/clang/lib/Headers/tmmintrin.h @@ -23,6 +23,9 @@ #define __trunc64(x) \ (__m64) __builtin_shufflevector((__v2di)(x), __extension__(__v2di){}, 0) +#define __zext128(x) \ + (__m128i) __builtin_shufflevector((__v2si)(x), __extension__(__v2si){}, 0, \ + 1, 2, 3) #define __anyext128(x) \ (__m128i) __builtin_shufflevector((__v2si)(x), __extension__(__v2si){}, 0, \ 1, -1, -1) @@ -504,10 +507,9 @@ _mm_hsubs_pi16(__m64 __a, __m64 __b) /// \a R5 := (\a __a10 * \a __b10) + (\a __a11 * \a __b11) \n /// \a R6 := (\a __a12 * \a __b12) + (\a __a13 * \a __b13) \n /// \a R7 := (\a __a14 * \a __b14) + (\a __a15 * \a __b15) -static __inline__ __m128i __DEFAULT_FN_ATTRS -_mm_maddubs_epi16(__m128i __a, __m128i __b) -{ - return (__m128i)__builtin_ia32_pmaddubsw128((__v16qi)__a, (__v16qi)__b); +static __inline__ __m128i __DEFAULT_FN_ATTRS_CONSTEXPR +_mm_maddubs_epi16(__m128i __a, __m128i __b) { + return (__m128i)__builtin_ia32_pmaddubsw128((__v16qi)__a, (__v16qi)__b); } /// Multiplies corresponding pairs of packed 8-bit unsigned integer @@ -534,11 +536,10 @@ _mm_maddubs_epi16(__m128i __a, __m128i __b) /// \a R1 := (\a __a2 * \a __b2) + (\a __a3 * \a __b3) \n /// \a R2 := (\a __a4 * \a __b4) + (\a __a5 * \a __b5) \n /// \a R3 := (\a __a6 * \a __b6) + (\a __a7 * \a __b7) -static __inline__ __m64 __DEFAULT_FN_ATTRS -_mm_maddubs_pi16(__m64 __a, __m64 __b) -{ - return __trunc64(__builtin_ia32_pmaddubsw128((__v16qi)__anyext128(__a), - (__v16qi)__anyext128(__b))); +static __inline__ __m64 __DEFAULT_FN_ATTRS_CONSTEXPR +_mm_maddubs_pi16(__m64 __a, __m64 __b) { + return __trunc64(__builtin_ia32_pmaddubsw128((__v16qi)__zext128(__a), + (__v16qi)__zext128(__b))); } /// Multiplies packed 16-bit signed integer values, truncates the 32-bit @@ -796,6 +797,7 @@ _mm_sign_pi32(__m64 __a, __m64 __b) } #undef __anyext128 +#undef __zext128 #undef __trunc64 #undef __DEFAULT_FN_ATTRS #undef __DEFAULT_FN_ATTRS_CONSTEXPR diff --git a/clang/lib/Parse/ParseExprCXX.cpp b/clang/lib/Parse/ParseExprCXX.cpp index a2c6957..90191b0 100644 --- a/clang/lib/Parse/ParseExprCXX.cpp +++ b/clang/lib/Parse/ParseExprCXX.cpp @@ -3200,6 +3200,8 @@ ExprResult Parser::ParseRequiresExpression() { BalancedDelimiterTracker ExprBraces(*this, tok::l_brace); ExprBraces.consumeOpen(); ExprResult Expression = ParseExpression(); + if (Expression.isUsable()) + Expression = Actions.CheckPlaceholderExpr(Expression.get()); if (!Expression.isUsable()) { ExprBraces.skipToEnd(); SkipUntil(tok::semi, tok::r_brace, SkipUntilFlags::StopBeforeMatch); @@ -3369,6 +3371,8 @@ ExprResult Parser::ParseRequiresExpression() { // expression ';' SourceLocation StartLoc = Tok.getLocation(); ExprResult Expression = ParseExpression(); + if (Expression.isUsable()) + Expression = Actions.CheckPlaceholderExpr(Expression.get()); if (!Expression.isUsable()) { SkipUntil(tok::semi, tok::r_brace, SkipUntilFlags::StopBeforeMatch); break; diff --git a/clang/lib/Sema/AnalysisBasedWarnings.cpp b/clang/lib/Sema/AnalysisBasedWarnings.cpp index 8606227..e9ca8ce 100644 --- a/clang/lib/Sema/AnalysisBasedWarnings.cpp +++ b/clang/lib/Sema/AnalysisBasedWarnings.cpp @@ -2605,6 +2605,17 @@ public: #endif } + void handleUnsafeUniquePtrArrayAccess(const DynTypedNode &Node, + bool IsRelatedToDecl, + ASTContext &Ctx) override { + SourceLocation Loc; + std::string Message; + + Loc = Node.get<Stmt>()->getBeginLoc(); + S.Diag(Loc, diag::warn_unsafe_buffer_usage_unique_ptr_array_access) + << Node.getSourceRange(); + } + bool isSafeBufferOptOut(const SourceLocation &Loc) const override { return S.PP.isSafeBufferOptOut(S.getSourceManager(), Loc); } diff --git a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp index 3c20ccd..40c318a 100644 --- a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp +++ b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp @@ -144,6 +144,7 @@ private: _2, _3, _4, + _5, Handle = 128, CounterHandle, LastStmt @@ -190,6 +191,9 @@ public: template <typename T> BuiltinTypeMethodBuilder & accessCounterHandleFieldOnResource(T ResourceRecord); + template <typename ResourceT, typename ValueT> + BuiltinTypeMethodBuilder & + setCounterHandleFieldOnResource(ResourceT ResourceRecord, ValueT HandleValue); template <typename T> BuiltinTypeMethodBuilder &returnValue(T ReturnValue); BuiltinTypeMethodBuilder &returnThis(); BuiltinTypeDeclBuilder &finalize(); @@ -205,6 +209,11 @@ private: if (!Method) createDecl(); } + + template <typename ResourceT, typename ValueT> + BuiltinTypeMethodBuilder &setFieldOnResource(ResourceT ResourceRecord, + ValueT HandleValue, + FieldDecl *HandleField); }; TemplateParameterListBuilder::~TemplateParameterListBuilder() { @@ -592,13 +601,27 @@ template <typename ResourceT, typename ValueT> BuiltinTypeMethodBuilder & BuiltinTypeMethodBuilder::setHandleFieldOnResource(ResourceT ResourceRecord, ValueT HandleValue) { + return setFieldOnResource(ResourceRecord, HandleValue, + DeclBuilder.getResourceHandleField()); +} + +template <typename ResourceT, typename ValueT> +BuiltinTypeMethodBuilder & +BuiltinTypeMethodBuilder::setCounterHandleFieldOnResource( + ResourceT ResourceRecord, ValueT HandleValue) { + return setFieldOnResource(ResourceRecord, HandleValue, + DeclBuilder.getResourceCounterHandleField()); +} + +template <typename ResourceT, typename ValueT> +BuiltinTypeMethodBuilder &BuiltinTypeMethodBuilder::setFieldOnResource( + ResourceT ResourceRecord, ValueT HandleValue, FieldDecl *HandleField) { ensureCompleteDecl(); Expr *ResourceExpr = convertPlaceholder(ResourceRecord); Expr *HandleValueExpr = convertPlaceholder(HandleValue); ASTContext &AST = DeclBuilder.SemaRef.getASTContext(); - FieldDecl *HandleField = DeclBuilder.getResourceHandleField(); MemberExpr *HandleMemberExpr = MemberExpr::CreateImplicit( AST, ResourceExpr, false, HandleField, HandleField->getType(), VK_LValue, OK_Ordinary); @@ -829,6 +852,18 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDefaultHandleConstructor() { .finalize(); } +BuiltinTypeDeclBuilder & +BuiltinTypeDeclBuilder::addStaticInitializationFunctions(bool HasCounter) { + if (HasCounter) { + addCreateFromBindingWithImplicitCounter(); + addCreateFromImplicitBindingWithImplicitCounter(); + } else { + addCreateFromBinding(); + addCreateFromImplicitBinding(); + } + return *this; +} + // Adds static method that initializes resource from binding: // // static Resource<T> __createFromBinding(unsigned registerNo, @@ -903,6 +938,102 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCreateFromImplicitBinding() { .finalize(); } +// Adds static method that initializes resource from binding: +// +// static Resource<T> +// __createFromBindingWithImplicitCounter(unsigned registerNo, +// unsigned spaceNo, int range, +// unsigned index, const char *name, +// unsigned counterOrderId) { +// Resource<T> tmp; +// tmp.__handle = __builtin_hlsl_resource_handlefrombinding( +// tmp.__handle, registerNo, spaceNo, range, index, name); +// tmp.__counter_handle = +// __builtin_hlsl_resource_counterhandlefromimplicitbinding( +// tmp.__handle, counterOrderId, spaceNo); +// return tmp; +// } +BuiltinTypeDeclBuilder & +BuiltinTypeDeclBuilder::addCreateFromBindingWithImplicitCounter() { + assert(!Record->isCompleteDefinition() && "record is already complete"); + + using PH = BuiltinTypeMethodBuilder::PlaceHolder; + ASTContext &AST = SemaRef.getASTContext(); + QualType HandleType = getResourceHandleField()->getType(); + QualType RecordType = AST.getTypeDeclType(cast<TypeDecl>(Record)); + BuiltinTypeMethodBuilder::LocalVar TmpVar("tmp", RecordType); + + return BuiltinTypeMethodBuilder(*this, + "__createFromBindingWithImplicitCounter", + RecordType, false, false, SC_Static) + .addParam("registerNo", AST.UnsignedIntTy) + .addParam("spaceNo", AST.UnsignedIntTy) + .addParam("range", AST.IntTy) + .addParam("index", AST.UnsignedIntTy) + .addParam("name", AST.getPointerType(AST.CharTy.withConst())) + .addParam("counterOrderId", AST.UnsignedIntTy) + .declareLocalVar(TmpVar) + .accessHandleFieldOnResource(TmpVar) + .callBuiltin("__builtin_hlsl_resource_handlefrombinding", HandleType, + PH::LastStmt, PH::_0, PH::_1, PH::_2, PH::_3, PH::_4) + .setHandleFieldOnResource(TmpVar, PH::LastStmt) + .accessHandleFieldOnResource(TmpVar) + .callBuiltin("__builtin_hlsl_resource_counterhandlefromimplicitbinding", + HandleType, PH::LastStmt, PH::_5, PH::_1) + .setCounterHandleFieldOnResource(TmpVar, PH::LastStmt) + .returnValue(TmpVar) + .finalize(); +} + +// Adds static method that initializes resource from binding: +// +// static Resource<T> +// __createFromImplicitBindingWithImplicitCounter(unsigned orderId, +// unsigned spaceNo, int range, +// unsigned index, +// const char *name, +// unsigned counterOrderId) { +// Resource<T> tmp; +// tmp.__handle = __builtin_hlsl_resource_handlefromimplicitbinding( +// tmp.__handle, orderId, spaceNo, range, index, name); +// tmp.__counter_handle = +// __builtin_hlsl_resource_counterhandlefromimplicitbinding( +// tmp.__handle, counterOrderId, spaceNo); +// return tmp; +// } +BuiltinTypeDeclBuilder & +BuiltinTypeDeclBuilder::addCreateFromImplicitBindingWithImplicitCounter() { + assert(!Record->isCompleteDefinition() && "record is already complete"); + + using PH = BuiltinTypeMethodBuilder::PlaceHolder; + ASTContext &AST = SemaRef.getASTContext(); + QualType HandleType = getResourceHandleField()->getType(); + QualType RecordType = AST.getTypeDeclType(cast<TypeDecl>(Record)); + BuiltinTypeMethodBuilder::LocalVar TmpVar("tmp", RecordType); + + return BuiltinTypeMethodBuilder( + *this, "__createFromImplicitBindingWithImplicitCounter", + RecordType, false, false, SC_Static) + .addParam("orderId", AST.UnsignedIntTy) + .addParam("spaceNo", AST.UnsignedIntTy) + .addParam("range", AST.IntTy) + .addParam("index", AST.UnsignedIntTy) + .addParam("name", AST.getPointerType(AST.CharTy.withConst())) + .addParam("counterOrderId", AST.UnsignedIntTy) + .declareLocalVar(TmpVar) + .accessHandleFieldOnResource(TmpVar) + .callBuiltin("__builtin_hlsl_resource_handlefromimplicitbinding", + HandleType, PH::LastStmt, PH::_0, PH::_1, PH::_2, PH::_3, + PH::_4) + .setHandleFieldOnResource(TmpVar, PH::LastStmt) + .accessHandleFieldOnResource(TmpVar) + .callBuiltin("__builtin_hlsl_resource_counterhandlefromimplicitbinding", + HandleType, PH::LastStmt, PH::_5, PH::_1) + .setCounterHandleFieldOnResource(TmpVar, PH::LastStmt) + .returnValue(TmpVar) + .finalize(); +} + BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCopyConstructor() { assert(!Record->isCompleteDefinition() && "record is already complete"); @@ -1048,7 +1179,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() { return BuiltinTypeMethodBuilder(*this, "IncrementCounter", SemaRef.getASTContext().UnsignedIntTy) .callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(), - PH::Handle, getConstantIntExpr(1)) + PH::CounterHandle, getConstantIntExpr(1)) .finalize(); } @@ -1057,7 +1188,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDecrementCounterMethod() { return BuiltinTypeMethodBuilder(*this, "DecrementCounter", SemaRef.getASTContext().UnsignedIntTy) .callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(), - PH::Handle, getConstantIntExpr(-1)) + PH::CounterHandle, getConstantIntExpr(-1)) .finalize(); } @@ -1102,7 +1233,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() { return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy) .addParam("value", ElemTy) .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy, - PH::Handle, getConstantIntExpr(1)) + PH::CounterHandle, getConstantIntExpr(1)) .callBuiltin("__builtin_hlsl_resource_getpointer", AST.getPointerType(AddrSpaceElemTy), PH::Handle, PH::LastStmt) @@ -1119,7 +1250,7 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() { AST.getAddrSpaceQualType(ElemTy, LangAS::hlsl_device); return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy) .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy, - PH::Handle, getConstantIntExpr(-1)) + PH::CounterHandle, getConstantIntExpr(-1)) .callBuiltin("__builtin_hlsl_resource_getpointer", AST.getPointerType(AddrSpaceElemTy), PH::Handle, PH::LastStmt) diff --git a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h index a981602..86cbd10 100644 --- a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h +++ b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h @@ -83,8 +83,7 @@ public: BuiltinTypeDeclBuilder &addCopyAssignmentOperator(); // Static create methods - BuiltinTypeDeclBuilder &addCreateFromBinding(); - BuiltinTypeDeclBuilder &addCreateFromImplicitBinding(); + BuiltinTypeDeclBuilder &addStaticInitializationFunctions(bool HasCounter); // Builtin types methods BuiltinTypeDeclBuilder &addLoadMethods(); @@ -96,6 +95,10 @@ public: BuiltinTypeDeclBuilder &addConsumeMethod(); private: + BuiltinTypeDeclBuilder &addCreateFromBinding(); + BuiltinTypeDeclBuilder &addCreateFromImplicitBinding(); + BuiltinTypeDeclBuilder &addCreateFromBindingWithImplicitCounter(); + BuiltinTypeDeclBuilder &addCreateFromImplicitBindingWithImplicitCounter(); BuiltinTypeDeclBuilder &addResourceMember(StringRef MemberName, ResourceClass RC, bool IsROV, bool RawBuffer, bool IsCounter, diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index cc43e94..e118dda 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -236,8 +236,7 @@ static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S, .addDefaultHandleConstructor() .addCopyConstructor() .addCopyAssignmentOperator() - .addCreateFromBinding() - .addCreateFromImplicitBinding(); + .addStaticInitializationFunctions(HasCounter); } // This function is responsible for constructing the constraint expression for diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp index e09c352..1c7c832d 100644 --- a/clang/lib/Sema/SemaARM.cpp +++ b/clang/lib/Sema/SemaARM.cpp @@ -603,8 +603,8 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall, bool SatisfiesSME = Builtin::evaluateRequiredTargetFeatures( StreamingBuiltinGuard, CallerFeatures); - if ((SatisfiesSVE && SatisfiesSME) || - (SatisfiesSVE && FnType == SemaARM::ArmStreamingCompatible)) + if (SatisfiesSVE && SatisfiesSME) + // Function type is irrelevant for streaming-agnostic builtins. return false; else if (SatisfiesSVE) BuiltinType = SemaARM::ArmNonStreaming; diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 3cc61b1..063db05 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -8811,8 +8811,10 @@ CheckPrintfHandler::checkFormatExpr(const analyze_printf::PrintfSpecifier &FS, case ArgType::Match: case ArgType::MatchPromotion: case ArgType::NoMatchPromotionTypeConfusion: - case ArgType::NoMatchSignedness: llvm_unreachable("expected non-matching"); + case ArgType::NoMatchSignedness: + Diag = diag::warn_format_conversion_argument_type_mismatch_signedness; + break; case ArgType::NoMatchPedantic: Diag = diag::warn_format_conversion_argument_type_mismatch_pedantic; break; diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp index 999e302c..f4df63c 100644 --- a/clang/lib/Sema/SemaConcept.cpp +++ b/clang/lib/Sema/SemaConcept.cpp @@ -280,6 +280,11 @@ public: if (T->getDepth() >= TemplateArgs.getNumLevels()) return true; + // There might not be a corresponding template argument before substituting + // into the parameter mapping, e.g. a sizeof... expression. + if (!TemplateArgs.hasTemplateArgument(T->getDepth(), T->getIndex())) + return true; + TemplateArgument Arg = TemplateArgs(T->getDepth(), T->getIndex()); if (T->isParameterPack() && SemaRef.ArgPackSubstIndex) { @@ -300,6 +305,12 @@ public: if (!NTTP) return TraverseDecl(D); + if (NTTP->getDepth() >= TemplateArgs.getNumLevels()) + return true; + + if (!TemplateArgs.hasTemplateArgument(NTTP->getDepth(), NTTP->getIndex())) + return true; + TemplateArgument Arg = TemplateArgs(NTTP->getDepth(), NTTP->getPosition()); if (NTTP->isParameterPack() && SemaRef.ArgPackSubstIndex) { assert(Arg.getKind() == TemplateArgument::Pack && @@ -326,17 +337,25 @@ public: return inherited::TraverseDecl(D); } + bool TraverseCallExpr(CallExpr *CE) { + inherited::TraverseStmt(CE->getCallee()); + + for (Expr *Arg : CE->arguments()) + inherited::TraverseStmt(Arg); + + return true; + } + bool TraverseTypeLoc(TypeLoc TL, bool TraverseQualifier = true) { // We don't care about TypeLocs. So traverse Types instead. - return TraverseType(TL.getType(), TraverseQualifier); + return TraverseType(TL.getType().getCanonicalType(), TraverseQualifier); } bool TraverseTagType(const TagType *T, bool TraverseQualifier) { // T's parent can be dependent while T doesn't have any template arguments. // We should have already traversed its qualifier. // FIXME: Add an assert to catch cases where we failed to profile the - // concept. assert(!T->isDependentType() && "We missed a case in profiling - // concepts!"); + // concept. return true; } @@ -701,7 +720,6 @@ ExprResult ConstraintSatisfactionChecker::Evaluate( if (auto Iter = S.UnsubstitutedConstraintSatisfactionCache.find(ID); Iter != S.UnsubstitutedConstraintSatisfactionCache.end()) { - auto &Cached = Iter->second.Satisfaction; Satisfaction.ContainsErrors = Cached.ContainsErrors; Satisfaction.IsSatisfied = Cached.IsSatisfied; diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index 4d3c7d6..4230ea7 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -9014,24 +9014,6 @@ bool Sema::IsInvalidSMECallConversion(QualType FromType, QualType ToType) { return FromAttributes != ToAttributes; } -// Check if we have a conversion between incompatible cmse function pointer -// types, that is, a conversion between a function pointer with the -// cmse_nonsecure_call attribute and one without. -static bool IsInvalidCmseNSCallConversion(Sema &S, QualType FromType, - QualType ToType) { - if (const auto *ToFn = - dyn_cast<FunctionType>(S.Context.getCanonicalType(ToType))) { - if (const auto *FromFn = - dyn_cast<FunctionType>(S.Context.getCanonicalType(FromType))) { - FunctionType::ExtInfo ToEInfo = ToFn->getExtInfo(); - FunctionType::ExtInfo FromEInfo = FromFn->getExtInfo(); - - return ToEInfo.getCmseNSCall() != FromEInfo.getCmseNSCall(); - } - } - return false; -} - // checkPointerTypesForAssignment - This is a very tricky routine (despite // being closely modeled after the C99 spec:-). The odd characteristic of this // routine is it effectively iqnores the qualifiers on the top level pointee. @@ -9187,18 +9169,43 @@ static AssignConvertType checkPointerTypesForAssignment(Sema &S, return AssignConvertType::IncompatibleFunctionPointer; return AssignConvertType::IncompatiblePointer; } - bool DiscardingCFIUncheckedCallee, AddingCFIUncheckedCallee; - if (!S.getLangOpts().CPlusPlus && - S.IsFunctionConversion(ltrans, rtrans, &DiscardingCFIUncheckedCallee, - &AddingCFIUncheckedCallee)) { - // Allow conversions between CFIUncheckedCallee-ness. - if (!DiscardingCFIUncheckedCallee && !AddingCFIUncheckedCallee) + // Note: in C++, typesAreCompatible(ltrans, rtrans) will have guaranteed + // hasSameType, so we can skip further checks. + const auto *LFT = ltrans->getAs<FunctionType>(); + const auto *RFT = rtrans->getAs<FunctionType>(); + if (!S.getLangOpts().CPlusPlus && LFT && RFT) { + // The invocation of IsFunctionConversion below will try to transform rtrans + // to obtain an exact match for ltrans. This should not fail because of + // mismatches in result type and parameter types, they were already checked + // by typesAreCompatible above. So we will recreate rtrans (or where + // appropriate ltrans) using the result type and parameter types from ltrans + // (respectively rtrans), but keeping its ExtInfo/ExtProtoInfo. + const auto *LFPT = dyn_cast<FunctionProtoType>(LFT); + const auto *RFPT = dyn_cast<FunctionProtoType>(RFT); + if (LFPT && RFPT) { + rtrans = S.Context.getFunctionType(LFPT->getReturnType(), + LFPT->getParamTypes(), + RFPT->getExtProtoInfo()); + } else if (LFPT) { + FunctionProtoType::ExtProtoInfo EPI; + EPI.ExtInfo = RFT->getExtInfo(); + rtrans = S.Context.getFunctionType(LFPT->getReturnType(), + LFPT->getParamTypes(), EPI); + } else if (RFPT) { + // In this case, we want to retain rtrans as a FunctionProtoType, to keep + // all of its ExtProtoInfo. Transform ltrans instead. + FunctionProtoType::ExtProtoInfo EPI; + EPI.ExtInfo = LFT->getExtInfo(); + ltrans = S.Context.getFunctionType(RFPT->getReturnType(), + RFPT->getParamTypes(), EPI); + } else { + rtrans = S.Context.getFunctionNoProtoType(LFT->getReturnType(), + RFT->getExtInfo()); + } + if (!S.Context.hasSameUnqualifiedType(rtrans, ltrans) && + !S.IsFunctionConversion(rtrans, ltrans)) return AssignConvertType::IncompatibleFunctionPointer; } - if (IsInvalidCmseNSCallConversion(S, ltrans, rtrans)) - return AssignConvertType::IncompatibleFunctionPointer; - if (S.IsInvalidSMECallConversion(rtrans, ltrans)) - return AssignConvertType::IncompatibleFunctionPointer; return ConvTy; } diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index a662b72..17cb1e4 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -598,18 +598,17 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { validatePackoffset(SemaRef, BufDecl); - // create buffer layout struct createHostLayoutStructForBuffer(SemaRef, BufDecl); - HLSLVkBindingAttr *VkBinding = Dcl->getAttr<HLSLVkBindingAttr>(); - HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>(); - if (!VkBinding && (!RBA || !RBA->hasRegisterSlot())) { + // Handle implicit binding if needed. + ResourceBindingAttrs ResourceAttrs(Dcl); + if (!ResourceAttrs.isExplicit()) { SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding); // Use HLSLResourceBindingAttr to transfer implicit binding order_ID // to codegen. If it does not exist, create an implicit attribute. uint32_t OrderID = getNextImplicitBindingOrderID(); - if (RBA) - RBA->setImplicitBindingOrderID(OrderID); + if (ResourceAttrs.hasBinding()) + ResourceAttrs.setImplicitOrderID(OrderID); else addImplicitBindingAttrToDecl(SemaRef, BufDecl, BufDecl->isCBuffer() ? RegisterType::CBuffer @@ -1241,6 +1240,20 @@ static CXXMethodDecl *lookupMethod(Sema &S, CXXRecordDecl *RecordDecl, } // end anonymous namespace +static bool hasCounterHandle(const CXXRecordDecl *RD) { + if (RD->field_empty()) + return false; + auto It = std::next(RD->field_begin()); + if (It == RD->field_end()) + return false; + const FieldDecl *SecondField = *It; + if (const auto *ResTy = + SecondField->getType()->getAs<HLSLAttributedResourceType>()) { + return ResTy->getAttrs().IsCounter; + } + return false; +} + bool SemaHLSL::handleRootSignatureElements( ArrayRef<hlsl::RootSignatureElement> Elements) { // Define some common error handling functions @@ -1590,10 +1603,6 @@ void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) { } void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) { - // The vk::binding attribute only applies to SPIR-V. - if (!getASTContext().getTargetInfo().getTriple().isSPIRV()) - return; - uint32_t Binding = 0; if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding)) return; @@ -2978,6 +2987,25 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ResourceTy); break; } + case Builtin::BI__builtin_hlsl_resource_counterhandlefromimplicitbinding: { + ASTContext &AST = SemaRef.getASTContext(); + if (SemaRef.checkArgCount(TheCall, 3) || + CheckResourceHandle(&SemaRef, TheCall, 0) || + CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) || + CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy)) + return true; + + QualType MainHandleTy = TheCall->getArg(0)->getType(); + auto *MainResType = MainHandleTy->getAs<HLSLAttributedResourceType>(); + auto MainAttrs = MainResType->getAttrs(); + assert(!MainAttrs.IsCounter && "cannot create a counter from a counter"); + MainAttrs.IsCounter = true; + QualType CounterHandleTy = AST.getHLSLAttributedResourceType( + MainResType->getWrappedType(), MainResType->getContainedType(), + MainAttrs); + TheCall->setType(CounterHandleTy); + break; + } case Builtin::BI__builtin_hlsl_and: case Builtin::BI__builtin_hlsl_or: { if (SemaRef.checkArgCount(TheCall, 2)) @@ -3780,16 +3808,28 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { // If the resource array does not have an explicit binding attribute, // create an implicit one. It will be used to transfer implicit binding // order_ID to codegen. - if (!VD->hasAttr<HLSLVkBindingAttr>()) { - HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>(); - if (!RBA || !RBA->hasRegisterSlot()) { + ResourceBindingAttrs Binding(VD); + if (!Binding.isExplicit()) { + uint32_t OrderID = getNextImplicitBindingOrderID(); + if (Binding.hasBinding()) + Binding.setImplicitOrderID(OrderID); + else { + addImplicitBindingAttrToDecl( + SemaRef, VD, getRegisterType(getResourceArrayHandleType(VD)), + OrderID); + // Re-create the binding object to pick up the new attribute. + Binding = ResourceBindingAttrs(VD); + } + } + + // Get to the base type of a potentially multi-dimensional array. + QualType Ty = getASTContext().getBaseElementType(VD->getType()); + + const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl(); + if (hasCounterHandle(RD)) { + if (!Binding.hasCounterImplicitOrderID()) { uint32_t OrderID = getNextImplicitBindingOrderID(); - if (RBA) - RBA->setImplicitBindingOrderID(OrderID); - else - addImplicitBindingAttrToDecl( - SemaRef, VD, getRegisterType(getResourceArrayHandleType(VD)), - OrderID); + Binding.setCounterImplicitOrderID(OrderID); } } } @@ -3815,19 +3855,31 @@ bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { CXXMethodDecl *CreateMethod = nullptr; llvm::SmallVector<Expr *> Args; + bool HasCounter = hasCounterHandle(ResourceDecl); + const char *CreateMethodName; + if (Binding.isExplicit()) + CreateMethodName = HasCounter ? "__createFromBindingWithImplicitCounter" + : "__createFromBinding"; + else + CreateMethodName = HasCounter + ? "__createFromImplicitBindingWithImplicitCounter" + : "__createFromImplicitBinding"; + + CreateMethod = + lookupMethod(SemaRef, ResourceDecl, CreateMethodName, VD->getLocation()); + + if (!CreateMethod) + // This can happen if someone creates a struct that looks like an HLSL + // resource record but does not have the required static create method. + // No binding will be generated for it. + return false; + if (Binding.isExplicit()) { - // The resource has explicit binding. - CreateMethod = lookupMethod(SemaRef, ResourceDecl, "__createFromBinding", - VD->getLocation()); IntegerLiteral *RegSlot = IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSlot()), AST.UnsignedIntTy, SourceLocation()); Args.push_back(RegSlot); } else { - // The resource has implicit binding. - CreateMethod = - lookupMethod(SemaRef, ResourceDecl, "__createFromImplicitBinding", - VD->getLocation()); uint32_t OrderID = (Binding.hasImplicitOrderID()) ? Binding.getImplicitOrderID() : getNextImplicitBindingOrderID(); @@ -3837,12 +3889,6 @@ bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { Args.push_back(OrderId); } - if (!CreateMethod) - // This can happen if someone creates a struct that looks like an HLSL - // resource record but does not have the required static create method. - // No binding will be generated for it. - return false; - IntegerLiteral *Space = IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, Binding.getSpace()), AST.UnsignedIntTy, SourceLocation()); @@ -3866,6 +3912,15 @@ bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { Name, nullptr, VK_PRValue, FPOptionsOverride()); Args.push_back(NameCast); + if (HasCounter) { + // Will this be in the correct order? + uint32_t CounterOrderID = getNextImplicitBindingOrderID(); + IntegerLiteral *CounterId = + IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, CounterOrderID), + AST.UnsignedIntTy, SourceLocation()); + Args.push_back(CounterId); + } + // Make sure the create method template is instantiated and emitted. if (!CreateMethod->isDefined() && CreateMethod->isTemplateInstantiation()) SemaRef.InstantiateFunctionDefinition(VD->getLocation(), CreateMethod, @@ -3906,20 +3961,24 @@ bool SemaHLSL::initGlobalResourceArrayDecl(VarDecl *VD) { ASTContext &AST = SemaRef.getASTContext(); QualType ResElementTy = AST.getBaseElementType(VD->getType()); CXXRecordDecl *ResourceDecl = ResElementTy->getAsCXXRecordDecl(); - - HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>(); - HLSLVkBindingAttr *VkBinding = VD->getAttr<HLSLVkBindingAttr>(); CXXMethodDecl *CreateMethod = nullptr; - if (VkBinding || (RBA && RBA->hasRegisterSlot())) + bool HasCounter = hasCounterHandle(ResourceDecl); + ResourceBindingAttrs ResourceAttrs(VD); + if (ResourceAttrs.isExplicit()) // Resource has explicit binding. - CreateMethod = lookupMethod(SemaRef, ResourceDecl, "__createFromBinding", - VD->getLocation()); - else - // Resource has implicit binding. CreateMethod = - lookupMethod(SemaRef, ResourceDecl, "__createFromImplicitBinding", + lookupMethod(SemaRef, ResourceDecl, + HasCounter ? "__createFromBindingWithImplicitCounter" + : "__createFromBinding", VD->getLocation()); + else + // Resource has implicit binding. + CreateMethod = lookupMethod( + SemaRef, ResourceDecl, + HasCounter ? "__createFromImplicitBindingWithImplicitCounter" + : "__createFromImplicitBinding", + VD->getLocation()); if (!CreateMethod) return false; diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp index 8471f02..4824b5a 100644 --- a/clang/lib/Sema/SemaOpenACC.cpp +++ b/clang/lib/Sema/SemaOpenACC.cpp @@ -2946,5 +2946,5 @@ OpenACCReductionRecipe SemaOpenACC::CreateReductionInitRecipe( AllocaDecl->setInit(Init.get()); AllocaDecl->setInitStyle(VarDecl::CallInit); } - return OpenACCReductionRecipe(AllocaDecl); + return OpenACCReductionRecipe(AllocaDecl, {}); } diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp index 5657dfe..8339bb1 100644 --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -1087,14 +1087,14 @@ static bool shouldAddReversedEqEq(Sema &S, SourceLocation OpLoc, } bool OverloadCandidateSet::OperatorRewriteInfo::allowsReversed( - OverloadedOperatorKind Op) { + OverloadedOperatorKind Op) const { if (!AllowRewrittenCandidates) return false; return Op == OO_EqualEqual || Op == OO_Spaceship; } bool OverloadCandidateSet::OperatorRewriteInfo::shouldAddReversed( - Sema &S, ArrayRef<Expr *> OriginalArgs, FunctionDecl *FD) { + Sema &S, ArrayRef<Expr *> OriginalArgs, FunctionDecl *FD) const { auto Op = FD->getOverloadedOperator(); if (!allowsReversed(Op)) return false; @@ -1892,14 +1892,7 @@ bool Sema::TryFunctionConversion(QualType FromType, QualType ToType, return Changed; } -bool Sema::IsFunctionConversion(QualType FromType, QualType ToType, - bool *DiscardingCFIUncheckedCallee, - bool *AddingCFIUncheckedCallee) const { - if (DiscardingCFIUncheckedCallee) - *DiscardingCFIUncheckedCallee = false; - if (AddingCFIUncheckedCallee) - *AddingCFIUncheckedCallee = false; - +bool Sema::IsFunctionConversion(QualType FromType, QualType ToType) const { if (Context.hasSameUnqualifiedType(FromType, ToType)) return false; @@ -1958,25 +1951,14 @@ bool Sema::IsFunctionConversion(QualType FromType, QualType ToType, const auto *ToFPT = dyn_cast<FunctionProtoType>(ToFn); if (FromFPT && ToFPT) { - if (FromFPT->hasCFIUncheckedCallee() && !ToFPT->hasCFIUncheckedCallee()) { - QualType NewTy = Context.getFunctionType( - FromFPT->getReturnType(), FromFPT->getParamTypes(), - FromFPT->getExtProtoInfo().withCFIUncheckedCallee(false)); - FromFPT = cast<FunctionProtoType>(NewTy.getTypePtr()); - FromFn = FromFPT; - Changed = true; - if (DiscardingCFIUncheckedCallee) - *DiscardingCFIUncheckedCallee = true; - } else if (!FromFPT->hasCFIUncheckedCallee() && - ToFPT->hasCFIUncheckedCallee()) { + if (FromFPT->hasCFIUncheckedCallee() != ToFPT->hasCFIUncheckedCallee()) { QualType NewTy = Context.getFunctionType( FromFPT->getReturnType(), FromFPT->getParamTypes(), - FromFPT->getExtProtoInfo().withCFIUncheckedCallee(true)); + FromFPT->getExtProtoInfo().withCFIUncheckedCallee( + ToFPT->hasCFIUncheckedCallee())); FromFPT = cast<FunctionProtoType>(NewTy.getTypePtr()); FromFn = FromFPT; Changed = true; - if (AddingCFIUncheckedCallee) - *AddingCFIUncheckedCallee = true; } } @@ -2007,11 +1989,7 @@ bool Sema::IsFunctionConversion(QualType FromType, QualType ToType, Changed = true; } - // For C, when called from checkPointerTypesForAssignment, - // we need to not alter FromFn, or else even an innocuous cast - // like dropping effects will fail. In C++ however we do want to - // alter FromFn (because of the way PerformImplicitConversion works). - if (Context.hasAnyFunctionEffects() && getLangOpts().CPlusPlus) { + if (Context.hasAnyFunctionEffects()) { FromFPT = cast<FunctionProtoType>(FromFn); // in case FromFn changed above // Transparently add/drop effects; here we are concerned with diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp index 419f3e1..3a6ff99 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -318,7 +318,7 @@ TemplateNameKind Sema::isTemplateName(Scope *S, } } - if (isPackProducingBuiltinTemplateName(Template) && + if (isPackProducingBuiltinTemplateName(Template) && S && S->getTemplateParamParent() == nullptr) Diag(Name.getBeginLoc(), diag::err_builtin_pack_outside_template) << TName; // Recover by returning the template, even though we would never be able to diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 51b55b8..940324b 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -16364,16 +16364,21 @@ ExprResult TreeTransform<Derived>::TransformSubstNonTypeTemplateParmExpr( AssociatedDecl == E->getAssociatedDecl()) return E; + auto getParamAndType = [Index = E->getIndex()](Decl *AssociatedDecl) + -> std::tuple<NonTypeTemplateParmDecl *, QualType> { + auto [PDecl, Arg] = getReplacedTemplateParameter(AssociatedDecl, Index); + auto *Param = cast<NonTypeTemplateParmDecl>(PDecl); + return {Param, Arg.isNull() ? Param->getType() + : Arg.getNonTypeTemplateArgumentType()}; + }; + // If the replacement expression did not change, and the parameter type // did not change, we can skip the semantic action because it would // produce the same result anyway. - auto *Param = cast<NonTypeTemplateParmDecl>( - getReplacedTemplateParameterList(AssociatedDecl) - ->asArray()[E->getIndex()]); - if (QualType ParamType = Param->getType(); - !SemaRef.Context.hasSameType(ParamType, E->getParameter()->getType()) || + if (auto [Param, ParamType] = getParamAndType(AssociatedDecl); + !SemaRef.Context.hasSameType( + ParamType, std::get<1>(getParamAndType(E->getAssociatedDecl()))) || Replacement.get() != OrigReplacement) { - // When transforming the replacement expression previously, all Sema // specific annotations, such as implicit casts, are discarded. Calling the // corresponding sema action is necessary to recover those. Otherwise, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 6acf79a..868f0cc 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -13009,9 +13009,22 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() { llvm::SmallVector<OpenACCReductionRecipe> RecipeList; for (unsigned I = 0; I < VarList.size(); ++I) { - static_assert(sizeof(OpenACCReductionRecipe) == sizeof(int *)); VarDecl *Recipe = readDeclAs<VarDecl>(); - RecipeList.push_back({Recipe}); + + static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) == + 3 * sizeof(int *)); + + llvm::SmallVector<OpenACCReductionRecipe::CombinerRecipe> Combiners; + unsigned NumCombiners = readInt(); + for (unsigned I = 0; I < NumCombiners; ++I) { + VarDecl *LHS = readDeclAs<VarDecl>(); + VarDecl *RHS = readDeclAs<VarDecl>(); + Expr *Op = readExpr(); + + Combiners.push_back({LHS, RHS, Op}); + } + + RecipeList.push_back({Recipe, Combiners}); } return OpenACCReductionClause::Create(getContext(), BeginLoc, LParenLoc, Op, diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index 09b1e58..82ccde8 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -8925,8 +8925,17 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) { writeOpenACCVarList(RC); for (const OpenACCReductionRecipe &R : RC->getRecipes()) { - static_assert(sizeof(OpenACCReductionRecipe) == 1 * sizeof(int *)); AddDeclRef(R.AllocaDecl); + + static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) == + 3 * sizeof(int *)); + writeUInt32(R.CombinerRecipes.size()); + + for (auto &CombinerRecipe : R.CombinerRecipes) { + AddDeclRef(CombinerRecipe.LHS); + AddDeclRef(CombinerRecipe.RHS); + AddStmt(CombinerRecipe.Op); + } } return; } diff --git a/clang/lib/StaticAnalyzer/Checkers/MallocChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/MallocChecker.cpp index 83d79b43..70baab5 100644 --- a/clang/lib/StaticAnalyzer/Checkers/MallocChecker.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/MallocChecker.cpp @@ -3812,6 +3812,15 @@ bool MallocChecker::mayFreeAnyEscapedMemoryOrIsModeledExplicitly( return true; } + // Protobuf function declared in `generated_message_util.h` that takes + // ownership of the second argument. As the first and third arguments are + // allocation arenas and won't be tracked by this checker, there is no reason + // to set `EscapingSymbol`. (Also, this is an implementation detail of + // Protobuf, so it's better to be a bit more permissive.) + if (FName == "GetOwnedMessageInternal") { + return true; + } + // Handle cases where we know a buffer's /address/ can escape. // Note that the above checks handle some special cases where we know that // even though the address escapes, it's still our responsibility to free the diff --git a/clang/lib/StaticAnalyzer/Checkers/VAListChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/VAListChecker.cpp index 79fd0bd..503fa5d 100644 --- a/clang/lib/StaticAnalyzer/Checkers/VAListChecker.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/VAListChecker.cpp @@ -149,7 +149,7 @@ void VAListChecker::checkPreCall(const CallEvent &Call, else if (VaEnd.matches(Call)) checkVAListEndCall(Call, C); else { - for (auto FuncInfo : VAListAccepters) { + for (const auto &FuncInfo : VAListAccepters) { if (!FuncInfo.Func.matches(Call)) continue; const MemRegion *VAList = diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.cpp index 00a1b8b..66cfccb 100644 --- a/clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/ASTUtils.cpp @@ -31,9 +31,9 @@ bool tryToFindPtrOrigin( if (auto *DRE = dyn_cast<DeclRefExpr>(E)) { if (auto *VD = dyn_cast_or_null<VarDecl>(DRE->getDecl())) { auto QT = VD->getType(); - if (VD->hasGlobalStorage() && QT.isConstQualified()) { + auto IsImmortal = safeGetName(VD) == "NSApp"; + if (VD->hasGlobalStorage() && (IsImmortal || QT.isConstQualified())) return callback(E, true); - } } } if (auto *tempExpr = dyn_cast<MaterializeTemporaryExpr>(E)) { @@ -208,6 +208,8 @@ bool tryToFindPtrOrigin( continue; } if (auto *BoxedExpr = dyn_cast<ObjCBoxedExpr>(E)) { + if (StopAtFirstRefCountedObj) + return callback(BoxedExpr, true); E = BoxedExpr->getSubExpr(); continue; } diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefMemberChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefMemberChecker.cpp index 15a0c5a..ace639c 100644 --- a/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefMemberChecker.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefMemberChecker.cpp @@ -232,7 +232,7 @@ public: bool ignoreARC = !PD->isReadOnly() && PD->getSetterKind() == ObjCPropertyDecl::Assign; auto IsUnsafePtr = isUnsafePtr(QT, ignoreARC); - return {IsUnsafePtr && *IsUnsafePtr, PropType}; + return {IsUnsafePtr && *IsUnsafePtr && !PD->isRetaining(), PropType}; } bool shouldSkipDecl(const RecordDecl *RD) const { |