diff options
Diffstat (limited to 'clang/lib')
27 files changed, 1105 insertions, 79 deletions
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index 1b5d16b..84deaf5 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -86,6 +86,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MD5.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/SipHash.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/AArch64TargetParser.h" #include "llvm/TargetParser/Triple.h" @@ -3128,6 +3129,17 @@ QualType ASTContext::removeAddrSpaceQualType(QualType T) const { return QualType(TypeNode, Quals.getFastQualifiers()); } +uint16_t +ASTContext::getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD) { + assert(RD->isPolymorphic() && + "Attempted to get vtable pointer discriminator on a monomorphic type"); + std::unique_ptr<MangleContext> MC(createMangleContext()); + SmallString<256> Str; + llvm::raw_svector_ostream Out(Str); + MC->mangleCXXVTable(RD, Out); + return llvm::getPointerAuthStableSipHash(Str); +} + QualType ASTContext::getObjCGCQualType(QualType T, Qualifiers::GC GCAttr) const { QualType CanT = getCanonicalType(T); @@ -13894,3 +13906,74 @@ StringRef ASTContext::getCUIDHash() const { CUIDHash = llvm::utohexstr(llvm::MD5Hash(LangOpts.CUID), /*LowerCase=*/true); return CUIDHash; } + +const CXXRecordDecl * +ASTContext::baseForVTableAuthentication(const CXXRecordDecl *ThisClass) { + assert(ThisClass); + assert(ThisClass->isPolymorphic()); + const CXXRecordDecl *PrimaryBase = ThisClass; + while (1) { + assert(PrimaryBase); + assert(PrimaryBase->isPolymorphic()); + auto &Layout = getASTRecordLayout(PrimaryBase); + auto Base = Layout.getPrimaryBase(); + if (!Base || Base == PrimaryBase || !Base->isPolymorphic()) + break; + PrimaryBase = Base; + } + return PrimaryBase; +} + +bool ASTContext::useAbbreviatedThunkName(GlobalDecl VirtualMethodDecl, + StringRef MangledName) { + auto *Method = cast<CXXMethodDecl>(VirtualMethodDecl.getDecl()); + assert(Method->isVirtual()); + bool DefaultIncludesPointerAuth = + LangOpts.PointerAuthCalls || LangOpts.PointerAuthIntrinsics; + + if (!DefaultIncludesPointerAuth) + return true; + + auto Existing = ThunksToBeAbbreviated.find(VirtualMethodDecl); + if (Existing != ThunksToBeAbbreviated.end()) + return Existing->second.contains(MangledName.str()); + + std::unique_ptr<MangleContext> Mangler(createMangleContext()); + llvm::StringMap<llvm::SmallVector<std::string, 2>> Thunks; + auto VtableContext = getVTableContext(); + if (const auto *ThunkInfos = VtableContext->getThunkInfo(VirtualMethodDecl)) { + auto *Destructor = dyn_cast<CXXDestructorDecl>(Method); + for (const auto &Thunk : *ThunkInfos) { + SmallString<256> ElidedName; + llvm::raw_svector_ostream ElidedNameStream(ElidedName); + if (Destructor) + Mangler->mangleCXXDtorThunk(Destructor, VirtualMethodDecl.getDtorType(), + Thunk, /* elideOverrideInfo */ true, + ElidedNameStream); + else + Mangler->mangleThunk(Method, Thunk, /* elideOverrideInfo */ true, + ElidedNameStream); + SmallString<256> MangledName; + llvm::raw_svector_ostream mangledNameStream(MangledName); + if (Destructor) + Mangler->mangleCXXDtorThunk(Destructor, VirtualMethodDecl.getDtorType(), + Thunk, /* elideOverrideInfo */ false, + mangledNameStream); + else + Mangler->mangleThunk(Method, Thunk, /* elideOverrideInfo */ false, + mangledNameStream); + + if (Thunks.find(ElidedName) == Thunks.end()) + Thunks[ElidedName] = {}; + Thunks[ElidedName].push_back(std::string(MangledName)); + } + } + llvm::StringSet<> SimplifiedThunkNames; + for (auto &ThunkList : Thunks) { + llvm::sort(ThunkList.second); + SimplifiedThunkNames.insert(ThunkList.second[0]); + } + bool Result = SimplifiedThunkNames.contains(MangledName); + ThunksToBeAbbreviated[VirtualMethodDecl] = std::move(SimplifiedThunkNames); + return Result; +} diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 2cac03b..5444dcf 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -99,11 +99,10 @@ public: } void mangleCXXName(GlobalDecl GD, raw_ostream &) override; - void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, + void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, bool, raw_ostream &) override; void mangleCXXDtorThunk(const CXXDestructorDecl *DD, CXXDtorType Type, - const ThisAdjustment &ThisAdjustment, - raw_ostream &) override; + const ThunkInfo &Thunk, bool, raw_ostream &) override; void mangleReferenceTemporary(const VarDecl *D, unsigned ManglingNumber, raw_ostream &) override; void mangleCXXVTable(const CXXRecordDecl *RD, raw_ostream &) override; @@ -468,6 +467,7 @@ public: void mangleNameOrStandardSubstitution(const NamedDecl *ND); void mangleLambdaSig(const CXXRecordDecl *Lambda); void mangleModuleNamePrefix(StringRef Name, bool IsPartition = false); + void mangleVendorQualifier(StringRef Name); private: @@ -559,7 +559,6 @@ private: StringRef Prefix = ""); void mangleOperatorName(DeclarationName Name, unsigned Arity); void mangleOperatorName(OverloadedOperatorKind OO, unsigned Arity); - void mangleVendorQualifier(StringRef qualifier); void mangleQualifiers(Qualifiers Quals, const DependentAddressSpaceType *DAST = nullptr); void mangleRefQualifier(RefQualifierKind RefQualifier); @@ -7044,8 +7043,78 @@ void ItaniumMangleContextImpl::mangleCXXDtorComdat(const CXXDestructorDecl *D, Mangler.mangle(GlobalDecl(D, Dtor_Comdat)); } +/// Mangles the pointer authentication override attribute for classes +/// that have explicit overrides for the vtable authentication schema. +/// +/// The override is mangled as a parameterized vendor extension as follows +/// +/// <type> ::= U "__vtptrauth" I +/// <key> +/// <addressDiscriminated> +/// <extraDiscriminator> +/// E +/// +/// The extra discriminator encodes the explicit value derived from the +/// override schema, e.g. if the override has specified type based +/// discrimination the encoded value will be the discriminator derived from the +/// type name. +static void mangleOverrideDiscrimination(CXXNameMangler &Mangler, + ASTContext &Context, + const ThunkInfo &Thunk) { + auto &LangOpts = Context.getLangOpts(); + const CXXRecordDecl *ThisRD = Thunk.ThisType->getPointeeCXXRecordDecl(); + const CXXRecordDecl *PtrauthClassRD = + Context.baseForVTableAuthentication(ThisRD); + unsigned TypedDiscriminator = + Context.getPointerAuthVTablePointerDiscriminator(ThisRD); + Mangler.mangleVendorQualifier("__vtptrauth"); + auto &ManglerStream = Mangler.getStream(); + ManglerStream << "I"; + if (const auto *ExplicitAuth = + PtrauthClassRD->getAttr<VTablePointerAuthenticationAttr>()) { + ManglerStream << "Lj" << ExplicitAuth->getKey(); + + if (ExplicitAuth->getAddressDiscrimination() == + VTablePointerAuthenticationAttr::DefaultAddressDiscrimination) + ManglerStream << "Lb" << LangOpts.PointerAuthVTPtrAddressDiscrimination; + else + ManglerStream << "Lb" + << (ExplicitAuth->getAddressDiscrimination() == + VTablePointerAuthenticationAttr::AddressDiscrimination); + + switch (ExplicitAuth->getExtraDiscrimination()) { + case VTablePointerAuthenticationAttr::DefaultExtraDiscrimination: { + if (LangOpts.PointerAuthVTPtrTypeDiscrimination) + ManglerStream << "Lj" << TypedDiscriminator; + else + ManglerStream << "Lj" << 0; + break; + } + case VTablePointerAuthenticationAttr::TypeDiscrimination: + ManglerStream << "Lj" << TypedDiscriminator; + break; + case VTablePointerAuthenticationAttr::CustomDiscrimination: + ManglerStream << "Lj" << ExplicitAuth->getCustomDiscriminationValue(); + break; + case VTablePointerAuthenticationAttr::NoExtraDiscrimination: + ManglerStream << "Lj" << 0; + break; + } + } else { + ManglerStream << "Lj" + << (unsigned)VTablePointerAuthenticationAttr::DefaultKey; + ManglerStream << "Lb" << LangOpts.PointerAuthVTPtrAddressDiscrimination; + if (LangOpts.PointerAuthVTPtrTypeDiscrimination) + ManglerStream << "Lj" << TypedDiscriminator; + else + ManglerStream << "Lj" << 0; + } + ManglerStream << "E"; +} + void ItaniumMangleContextImpl::mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, + bool ElideOverrideInfo, raw_ostream &Out) { // <special-name> ::= T <call-offset> <base encoding> // # base is the nominal target function of thunk @@ -7071,21 +7140,28 @@ void ItaniumMangleContextImpl::mangleThunk(const CXXMethodDecl *MD, Thunk.Return.Virtual.Itanium.VBaseOffsetOffset); Mangler.mangleFunctionEncoding(MD); + if (!ElideOverrideInfo) + mangleOverrideDiscrimination(Mangler, getASTContext(), Thunk); } -void ItaniumMangleContextImpl::mangleCXXDtorThunk( - const CXXDestructorDecl *DD, CXXDtorType Type, - const ThisAdjustment &ThisAdjustment, raw_ostream &Out) { +void ItaniumMangleContextImpl::mangleCXXDtorThunk(const CXXDestructorDecl *DD, + CXXDtorType Type, + const ThunkInfo &Thunk, + bool ElideOverrideInfo, + raw_ostream &Out) { // <special-name> ::= T <call-offset> <base encoding> // # base is the nominal target function of thunk CXXNameMangler Mangler(*this, Out, DD, Type); Mangler.getStream() << "_ZT"; + auto &ThisAdjustment = Thunk.This; // Mangle the 'this' pointer adjustment. Mangler.mangleCallOffset(ThisAdjustment.NonVirtual, ThisAdjustment.Virtual.Itanium.VCallOffsetOffset); Mangler.mangleFunctionEncoding(GlobalDecl(DD, Type)); + if (!ElideOverrideInfo) + mangleOverrideDiscrimination(Mangler, getASTContext(), Thunk); } /// Returns the mangled name for a guard variable for the passed in VarDecl. diff --git a/clang/lib/AST/Mangle.cpp b/clang/lib/AST/Mangle.cpp index 4fbf0e3..75f6e21 100644 --- a/clang/lib/AST/Mangle.cpp +++ b/clang/lib/AST/Mangle.cpp @@ -513,10 +513,20 @@ public: } } else if (const auto *MD = dyn_cast_or_null<CXXMethodDecl>(ND)) { Manglings.emplace_back(getName(ND)); - if (MD->isVirtual()) - if (const auto *TIV = Ctx.getVTableContext()->getThunkInfo(MD)) - for (const auto &T : *TIV) - Manglings.emplace_back(getMangledThunk(MD, T)); + if (MD->isVirtual()) { + if (const auto *TIV = Ctx.getVTableContext()->getThunkInfo(MD)) { + for (const auto &T : *TIV) { + std::string ThunkName; + std::string ContextualizedName = + getMangledThunk(MD, T, /* ElideOverrideInfo */ false); + if (Ctx.useAbbreviatedThunkName(MD, ContextualizedName)) + ThunkName = getMangledThunk(MD, T, /* ElideOverrideInfo */ true); + else + ThunkName = ContextualizedName; + Manglings.emplace_back(ThunkName); + } + } + } } return Manglings; @@ -569,11 +579,12 @@ private: return BOS.str(); } - std::string getMangledThunk(const CXXMethodDecl *MD, const ThunkInfo &T) { + std::string getMangledThunk(const CXXMethodDecl *MD, const ThunkInfo &T, + bool ElideOverrideInfo) { std::string FrontendBuf; llvm::raw_string_ostream FOS(FrontendBuf); - MC->mangleThunk(MD, T, FOS); + MC->mangleThunk(MD, T, ElideOverrideInfo, FOS); std::string BackendBuf; llvm::raw_string_ostream BOS(BackendBuf); diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index 8cbaad6..7f1e9ab 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -159,9 +159,9 @@ public: const MethodVFTableLocation &ML, raw_ostream &Out) override; void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, - raw_ostream &) override; + bool ElideOverrideInfo, raw_ostream &) override; void mangleCXXDtorThunk(const CXXDestructorDecl *DD, CXXDtorType Type, - const ThisAdjustment &ThisAdjustment, + const ThunkInfo &Thunk, bool ElideOverrideInfo, raw_ostream &) override; void mangleCXXVFTable(const CXXRecordDecl *Derived, ArrayRef<const CXXRecordDecl *> BasePath, @@ -169,6 +169,8 @@ public: void mangleCXXVBTable(const CXXRecordDecl *Derived, ArrayRef<const CXXRecordDecl *> BasePath, raw_ostream &Out) override; + + void mangleCXXVTable(const CXXRecordDecl *, raw_ostream &) override; void mangleCXXVirtualDisplacementMap(const CXXRecordDecl *SrcRD, const CXXRecordDecl *DstRD, raw_ostream &Out) override; @@ -3747,6 +3749,7 @@ void MicrosoftMangleContextImpl::mangleVirtualMemPtrThunk( void MicrosoftMangleContextImpl::mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, + bool /*ElideOverrideInfo*/, raw_ostream &Out) { msvc_hashing_ostream MHO(Out); MicrosoftCXXNameMangler Mangler(*this, MHO); @@ -3768,9 +3771,11 @@ void MicrosoftMangleContextImpl::mangleThunk(const CXXMethodDecl *MD, DeclForFPT->getType()->castAs<FunctionProtoType>(), MD); } -void MicrosoftMangleContextImpl::mangleCXXDtorThunk( - const CXXDestructorDecl *DD, CXXDtorType Type, - const ThisAdjustment &Adjustment, raw_ostream &Out) { +void MicrosoftMangleContextImpl::mangleCXXDtorThunk(const CXXDestructorDecl *DD, + CXXDtorType Type, + const ThunkInfo &Thunk, + bool /*ElideOverrideInfo*/, + raw_ostream &Out) { // FIXME: Actually, the dtor thunk should be emitted for vector deleting // dtors rather than scalar deleting dtors. Just use the vector deleting dtor // mangling manually until we support both deleting dtor types. @@ -3779,6 +3784,7 @@ void MicrosoftMangleContextImpl::mangleCXXDtorThunk( MicrosoftCXXNameMangler Mangler(*this, MHO, DD, Type); Mangler.getStream() << "??_E"; Mangler.mangleName(DD->getParent()); + auto &Adjustment = Thunk.This; mangleThunkThisAdjustment(DD->getAccess(), Adjustment, Mangler, MHO); Mangler.mangleFunctionType(DD->getType()->castAs<FunctionProtoType>(), DD); } @@ -3803,6 +3809,12 @@ void MicrosoftMangleContextImpl::mangleCXXVFTable( Mangler.getStream() << '@'; } +void MicrosoftMangleContextImpl::mangleCXXVTable(const CXXRecordDecl *Derived, + raw_ostream &Out) { + // TODO: Determine appropriate mangling for MSABI + mangleCXXVFTable(Derived, {}, Out); +} + void MicrosoftMangleContextImpl::mangleCXXVBTable( const CXXRecordDecl *Derived, ArrayRef<const CXXRecordDecl *> BasePath, raw_ostream &Out) { diff --git a/clang/lib/AST/VTableBuilder.cpp b/clang/lib/AST/VTableBuilder.cpp index a956ca5..e941c3b 100644 --- a/clang/lib/AST/VTableBuilder.cpp +++ b/clang/lib/AST/VTableBuilder.cpp @@ -1147,11 +1147,41 @@ void ItaniumVTableBuilder::ComputeThisAdjustments() { continue; // Add it. - VTableThunks[VTableIndex].This = ThisAdjustment; + auto SetThisAdjustmentThunk = [&](uint64_t Idx) { + // If a this pointer adjustment is required, record the method that + // created the vtable entry. MD is not necessarily the method that + // created the entry since derived classes overwrite base class + // information in MethodInfoMap, hence findOriginalMethodInMap is called + // here. + // + // For example, in the following class hierarchy, if MD = D1::m and + // Overrider = D2:m, the original method that created the entry is B0:m, + // which is what findOriginalMethodInMap(MD) returns: + // + // struct B0 { int a; virtual void m(); }; + // struct D0 : B0 { int a; void m() override; }; + // struct D1 : B0 { int a; void m() override; }; + // struct D2 : D0, D1 { int a; void m() override; }; + // + // We need to record the method because we cannot + // call findOriginalMethod to find the method that created the entry if + // the method in the entry requires adjustment. + // + // Do not set ThunkInfo::Method if Idx is already in VTableThunks. This + // can happen when covariant return adjustment is required too. + if (!VTableThunks.count(Idx)) { + const CXXMethodDecl *Method = VTables.findOriginalMethodInMap(MD); + VTableThunks[Idx].Method = Method; + VTableThunks[Idx].ThisType = Method->getThisType().getTypePtr(); + } + VTableThunks[Idx].This = ThisAdjustment; + }; + + SetThisAdjustmentThunk(VTableIndex); if (isa<CXXDestructorDecl>(MD)) { // Add an adjustment for the deleting destructor as well. - VTableThunks[VTableIndex + 1].This = ThisAdjustment; + SetThisAdjustmentThunk(VTableIndex + 1); } } @@ -1509,6 +1539,8 @@ void ItaniumVTableBuilder::AddMethods( FindNearestOverriddenMethod(MD, PrimaryBases)) { if (ComputeReturnAdjustmentBaseOffset(Context, MD, OverriddenMD).isEmpty()) { + VTables.setOriginalMethod(MD, OverriddenMD); + // Replace the method info of the overridden method with our own // method. assert(MethodInfoMap.count(OverriddenMD) && @@ -1547,7 +1579,8 @@ void ItaniumVTableBuilder::AddMethods( // This is a virtual thunk for the most derived class, add it. AddThunk(Overrider.Method, - ThunkInfo(ThisAdjustment, ReturnAdjustment)); + ThunkInfo(ThisAdjustment, ReturnAdjustment, + OverriddenMD->getThisType().getTypePtr())); } } @@ -1615,6 +1648,15 @@ void ItaniumVTableBuilder::AddMethods( ReturnAdjustment ReturnAdjustment = ComputeReturnAdjustment(ReturnAdjustmentOffset); + // If a return adjustment is required, record the method that created the + // vtable entry. We need to record the method because we cannot call + // findOriginalMethod to find the method that created the entry if the + // method in the entry requires adjustment. + if (!ReturnAdjustment.isEmpty()) { + VTableThunks[Components.size()].Method = MD; + VTableThunks[Components.size()].ThisType = MD->getThisType().getTypePtr(); + } + AddMethod(Overrider.Method, ReturnAdjustment); } } @@ -1890,11 +1932,31 @@ void ItaniumVTableBuilder::LayoutVTablesForVirtualBases( } } +static void printThunkMethod(const ThunkInfo &Info, raw_ostream &Out) { + if (!Info.Method) + return; + std::string Str = PredefinedExpr::ComputeName( + PredefinedIdentKind::PrettyFunctionNoVirtual, Info.Method); + Out << " method: " << Str; +} + /// dumpLayout - Dump the vtable layout. void ItaniumVTableBuilder::dumpLayout(raw_ostream &Out) { // FIXME: write more tests that actually use the dumpLayout output to prevent // ItaniumVTableBuilder regressions. + Out << "Original map\n"; + + for (const auto &P : VTables.getOriginalMethodMap()) { + std::string Str0 = + PredefinedExpr::ComputeName(PredefinedIdentKind::PrettyFunctionNoVirtual, + P.first); + std::string Str1 = + PredefinedExpr::ComputeName(PredefinedIdentKind::PrettyFunctionNoVirtual, + P.second); + Out << " " << Str0 << " -> " << Str1 << "\n"; + } + if (isBuildingConstructorVTable()) { Out << "Construction vtable for ('"; MostDerivedClass->printQualifiedName(Out); @@ -1978,6 +2040,7 @@ void ItaniumVTableBuilder::dumpLayout(raw_ostream &Out) { } Out << ']'; + printThunkMethod(Thunk, Out); } // If this function pointer has a 'this' pointer adjustment, dump it. @@ -1991,6 +2054,7 @@ void ItaniumVTableBuilder::dumpLayout(raw_ostream &Out) { } Out << ']'; + printThunkMethod(Thunk, Out); } } @@ -2027,6 +2091,7 @@ void ItaniumVTableBuilder::dumpLayout(raw_ostream &Out) { Out << ']'; } + printThunkMethod(Thunk, Out); } break; @@ -2125,7 +2190,6 @@ void ItaniumVTableBuilder::dumpLayout(raw_ostream &Out) { ThunkInfoVectorTy ThunksVector = Thunks[MD]; llvm::sort(ThunksVector, [](const ThunkInfo &LHS, const ThunkInfo &RHS) { - assert(LHS.Method == nullptr && RHS.Method == nullptr); return std::tie(LHS.This, LHS.Return) < std::tie(RHS.This, RHS.Return); }); @@ -2314,6 +2378,35 @@ ItaniumVTableContext::getVirtualBaseOffsetOffset(const CXXRecordDecl *RD, return I->second; } +GlobalDecl ItaniumVTableContext::findOriginalMethod(GlobalDecl GD) { + const auto *MD = cast<CXXMethodDecl>(GD.getDecl()); + computeVTableRelatedInformation(MD->getParent()); + const CXXMethodDecl *OriginalMD = findOriginalMethodInMap(MD); + + if (const auto *DD = dyn_cast<CXXDestructorDecl>(OriginalMD)) + return GlobalDecl(DD, GD.getDtorType()); + return OriginalMD; +} + +const CXXMethodDecl * +ItaniumVTableContext::findOriginalMethodInMap(const CXXMethodDecl *MD) const { + // Traverse the chain of virtual methods until we find the method that added + // the v-table slot. + while (true) { + auto I = OriginalMethodMap.find(MD); + + // MD doesn't exist in OriginalMethodMap, so it must be the method we are + // looking for. + if (I == OriginalMethodMap.end()) + break; + + // Set MD to the overridden method. + MD = I->second; + } + + return MD; +} + static std::unique_ptr<VTableLayout> CreateVTableLayout(const ItaniumVTableBuilder &Builder) { SmallVector<VTableLayout::VTableThunkTy, 1> @@ -3094,9 +3187,9 @@ void VFTableBuilder::AddMethods(BaseSubobject Base, unsigned BaseDepth, ReturnAdjustmentOffset.VirtualBase); } } - + auto ThisType = (OverriddenMD ? OverriddenMD : MD)->getThisType().getTypePtr(); AddMethod(FinalOverriderMD, - ThunkInfo(ThisAdjustmentOffset, ReturnAdjustment, + ThunkInfo(ThisAdjustmentOffset, ReturnAdjustment, ThisType, ForceReturnAdjustmentMangling ? MD : nullptr)); } } diff --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp index e95a735f..23ebbee 100644 --- a/clang/lib/CodeGen/CGCXX.cpp +++ b/clang/lib/CodeGen/CGCXX.cpp @@ -263,7 +263,16 @@ static CGCallee BuildAppleKextVirtualCall(CodeGenFunction &CGF, CGF.Builder.CreateConstInBoundsGEP1_64(Ty, VTable, VTableIndex, "vfnkxt"); llvm::Value *VFunc = CGF.Builder.CreateAlignedLoad( Ty, VFuncPtr, llvm::Align(CGF.PointerAlignInBytes)); - CGCallee Callee(GD, VFunc); + + CGPointerAuthInfo PointerAuth; + if (auto &Schema = + CGM.getCodeGenOpts().PointerAuth.CXXVirtualFunctionPointers) { + GlobalDecl OrigMD = + CGM.getItaniumVTableContext().findOriginalMethod(GD.getCanonicalDecl()); + PointerAuth = CGF.EmitPointerAuthInfo(Schema, VFuncPtr, OrigMD, QualType()); + } + + CGCallee Callee(GD, VFunc, PointerAuth); return Callee; } diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h index 104a20d..7dcc539 100644 --- a/clang/lib/CodeGen/CGCXXABI.h +++ b/clang/lib/CodeGen/CGCXXABI.h @@ -504,13 +504,15 @@ public: virtual void setThunkLinkage(llvm::Function *Thunk, bool ForVTable, GlobalDecl GD, bool ReturnAdjustment) = 0; - virtual llvm::Value *performThisAdjustment(CodeGenFunction &CGF, - Address This, - const ThisAdjustment &TA) = 0; + virtual llvm::Value * + performThisAdjustment(CodeGenFunction &CGF, Address This, + const CXXRecordDecl *UnadjustedClass, + const ThunkInfo &TI) = 0; - virtual llvm::Value *performReturnAdjustment(CodeGenFunction &CGF, - Address Ret, - const ReturnAdjustment &RA) = 0; + virtual llvm::Value * + performReturnAdjustment(CodeGenFunction &CGF, Address Ret, + const CXXRecordDecl *UnadjustedClass, + const ReturnAdjustment &RA) = 0; virtual void EmitReturnFromThunk(CodeGenFunction &CGF, RValue RV, QualType ResultType); diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 5a032bd..0a595bb 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2588,6 +2588,11 @@ void CodeGenFunction::InitializeVTablePointer(const VPtr &Vptr) { // the same addr space. Note that this might not be LLVM address space 0. VTableField = VTableField.withElementType(PtrTy); + if (auto AuthenticationInfo = CGM.getVTablePointerAuthInfo( + this, Vptr.Base.getBase(), VTableField.emitRawPointer(*this))) + VTableAddressPoint = + EmitPointerAuthSign(*AuthenticationInfo, VTableAddressPoint); + llvm::StoreInst *Store = Builder.CreateStore(VTableAddressPoint, VTableField); TBAAAccessInfo TBAAInfo = CGM.getTBAAVTablePtrAccessInfo(PtrTy); CGM.DecorateInstructionWithTBAA(Store, TBAAInfo); @@ -2681,12 +2686,35 @@ void CodeGenFunction::InitializeVTablePointers(const CXXRecordDecl *RD) { llvm::Value *CodeGenFunction::GetVTablePtr(Address This, llvm::Type *VTableTy, - const CXXRecordDecl *RD) { + const CXXRecordDecl *RD, + VTableAuthMode AuthMode) { Address VTablePtrSrc = This.withElementType(VTableTy); llvm::Instruction *VTable = Builder.CreateLoad(VTablePtrSrc, "vtable"); TBAAAccessInfo TBAAInfo = CGM.getTBAAVTablePtrAccessInfo(VTableTy); CGM.DecorateInstructionWithTBAA(VTable, TBAAInfo); + if (auto AuthenticationInfo = + CGM.getVTablePointerAuthInfo(this, RD, This.emitRawPointer(*this))) { + if (AuthMode != VTableAuthMode::UnsafeUbsanStrip) { + VTable = cast<llvm::Instruction>( + EmitPointerAuthAuth(*AuthenticationInfo, VTable)); + if (AuthMode == VTableAuthMode::MustTrap) { + // This is clearly suboptimal but until we have an ability + // to rely on the authentication intrinsic trapping and force + // an authentication to occur we don't really have a choice. + VTable = + cast<llvm::Instruction>(Builder.CreateBitCast(VTable, Int8PtrTy)); + Builder.CreateLoad(RawAddress(VTable, Int8Ty, CGM.getPointerAlign()), + /* IsVolatile */ true); + } + } else { + VTable = cast<llvm::Instruction>(EmitPointerAuthAuth( + CGPointerAuthInfo(0, PointerAuthenticationMode::Strip, false, false, + nullptr), + VTable)); + } + } + if (CGM.getCodeGenOpts().OptimizationLevel > 0 && CGM.getCodeGenOpts().StrictVTablePointers) CGM.DecorateInstructionWithInvariantGroup(VTable, RD); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 534f46d..23e5dee 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -830,8 +830,14 @@ void CodeGenFunction::EmitTypeCheck(TypeCheckKind TCK, SourceLocation Loc, // Load the vptr, and mix it with TypeHash. llvm::Value *TypeHash = llvm::ConstantInt::get(Int64Ty, xxh3_64bits(Out.str())); + + llvm::Type *VPtrTy = llvm::PointerType::get(IntPtrTy, 0); Address VPtrAddr(Ptr, IntPtrTy, getPointerAlign()); - llvm::Value *VPtrVal = Builder.CreateLoad(VPtrAddr); + llvm::Value *VPtrVal = GetVTablePtr(VPtrAddr, VPtrTy, + Ty->getAsCXXRecordDecl(), + VTableAuthMode::UnsafeUbsanStrip); + VPtrVal = Builder.CreateBitOrPointerCast(VPtrVal, IntPtrTy); + llvm::Value *Hash = emitHashMix(Builder, TypeHash, Builder.CreateZExt(VPtrVal, Int64Ty)); Hash = Builder.CreateTrunc(Hash, IntPtrTy); diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp index 0f3e297..1fec587 100644 --- a/clang/lib/CodeGen/CGExprConstant.cpp +++ b/clang/lib/CodeGen/CGExprConstant.cpp @@ -803,6 +803,13 @@ bool ConstStructBuilder::Build(const APValue &Val, const RecordDecl *RD, llvm::Constant *VTableAddressPoint = CGM.getCXXABI().getVTableAddressPoint(BaseSubobject(CD, Offset), VTableClass); + if (auto Authentication = + CGM.getVTablePointerAuthentication(VTableClass)) { + VTableAddressPoint = Emitter.tryEmitConstantSignedPointer( + VTableAddressPoint, *Authentication); + if (!VTableAddressPoint) + return false; + } if (!AppendBytes(Offset, VTableAddressPoint)) return false; } @@ -1647,7 +1654,7 @@ namespace { // messing around with llvm::Constant structures, which never itself // does anything that should be visible in compiler output. for (auto &entry : Locations) { - assert(entry.first->getParent() == nullptr && "not a placeholder!"); + assert(entry.first->getName() == "" && "not a placeholder!"); entry.first->replaceAllUsesWith(entry.second); entry.first->eraseFromParent(); } @@ -1811,6 +1818,43 @@ llvm::Constant *ConstantEmitter::tryEmitPrivateForMemory(const APValue &value, return (C ? emitForMemory(C, destType) : nullptr); } +/// Try to emit a constant signed pointer, given a raw pointer and the +/// destination ptrauth qualifier. +/// +/// This can fail if the qualifier needs address discrimination and the +/// emitter is in an abstract mode. +llvm::Constant * +ConstantEmitter::tryEmitConstantSignedPointer(llvm::Constant *UnsignedPointer, + PointerAuthQualifier Schema) { + assert(Schema && "applying trivial ptrauth schema"); + + if (Schema.hasKeyNone()) + return UnsignedPointer; + + unsigned Key = Schema.getKey(); + + // Create an address placeholder if we're using address discrimination. + llvm::GlobalValue *StorageAddress = nullptr; + if (Schema.isAddressDiscriminated()) { + // We can't do this if the emitter is in an abstract state. + if (isAbstract()) + return nullptr; + + StorageAddress = getCurrentAddrPrivate(); + } + + llvm::ConstantInt *Discriminator = + llvm::ConstantInt::get(CGM.IntPtrTy, Schema.getExtraDiscriminator()); + + llvm::Constant *SignedPointer = CGM.getConstantSignedPointer( + UnsignedPointer, Key, StorageAddress, Discriminator); + + if (Schema.isAddressDiscriminated()) + registerCurrentAddrPrivate(SignedPointer, StorageAddress); + + return SignedPointer; +} + llvm::Constant *ConstantEmitter::emitForMemory(CodeGenModule &CGM, llvm::Constant *C, QualType destType) { diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp index f0819b0..673f6e6 100644 --- a/clang/lib/CodeGen/CGPointerAuth.cpp +++ b/clang/lib/CodeGen/CGPointerAuth.cpp @@ -11,12 +11,56 @@ // //===----------------------------------------------------------------------===// +#include "CodeGenFunction.h" #include "CodeGenModule.h" #include "clang/CodeGen/CodeGenABITypes.h" +#include "clang/CodeGen/ConstantInitBuilder.h" +#include "llvm/Support/SipHash.h" using namespace clang; using namespace CodeGen; +/// Given a pointer-authentication schema, return a concrete "other" +/// discriminator for it. +llvm::ConstantInt *CodeGenModule::getPointerAuthOtherDiscriminator( + const PointerAuthSchema &Schema, GlobalDecl Decl, QualType Type) { + switch (Schema.getOtherDiscrimination()) { + case PointerAuthSchema::Discrimination::None: + return nullptr; + + case PointerAuthSchema::Discrimination::Type: + llvm_unreachable("type discrimination not implemented yet"); + + case PointerAuthSchema::Discrimination::Decl: + assert(Decl.getDecl() && + "declaration not provided for decl-discriminated schema"); + return llvm::ConstantInt::get(IntPtrTy, + getPointerAuthDeclDiscriminator(Decl)); + + case PointerAuthSchema::Discrimination::Constant: + return llvm::ConstantInt::get(IntPtrTy, Schema.getConstantDiscrimination()); + } + llvm_unreachable("bad discrimination kind"); +} + +uint16_t CodeGen::getPointerAuthDeclDiscriminator(CodeGenModule &CGM, + GlobalDecl Declaration) { + return CGM.getPointerAuthDeclDiscriminator(Declaration); +} + +/// Return the "other" decl-specific discriminator for the given decl. +uint16_t +CodeGenModule::getPointerAuthDeclDiscriminator(GlobalDecl Declaration) { + uint16_t &EntityHash = PtrAuthDiscriminatorHashes[Declaration]; + + if (EntityHash == 0) { + StringRef Name = getMangledName(Declaration); + EntityHash = llvm::getPointerAuthStableSipHash(Name); + } + + return EntityHash; +} + /// Return the abstract pointer authentication schema for a pointer to the given /// function type. CGPointerAuthInfo CodeGenModule::getFunctionPointerAuthInfo(QualType T) { @@ -35,6 +79,41 @@ CGPointerAuthInfo CodeGenModule::getFunctionPointerAuthInfo(QualType T) { /*Discriminator=*/nullptr); } +llvm::Value * +CodeGenFunction::EmitPointerAuthBlendDiscriminator(llvm::Value *StorageAddress, + llvm::Value *Discriminator) { + StorageAddress = Builder.CreatePtrToInt(StorageAddress, IntPtrTy); + auto Intrinsic = CGM.getIntrinsic(llvm::Intrinsic::ptrauth_blend); + return Builder.CreateCall(Intrinsic, {StorageAddress, Discriminator}); +} + +/// Emit the concrete pointer authentication informaton for the +/// given authentication schema. +CGPointerAuthInfo CodeGenFunction::EmitPointerAuthInfo( + const PointerAuthSchema &Schema, llvm::Value *StorageAddress, + GlobalDecl SchemaDecl, QualType SchemaType) { + if (!Schema) + return CGPointerAuthInfo(); + + llvm::Value *Discriminator = + CGM.getPointerAuthOtherDiscriminator(Schema, SchemaDecl, SchemaType); + + if (Schema.isAddressDiscriminated()) { + assert(StorageAddress && + "address not provided for address-discriminated schema"); + + if (Discriminator) + Discriminator = + EmitPointerAuthBlendDiscriminator(StorageAddress, Discriminator); + else + Discriminator = Builder.CreatePtrToInt(StorageAddress, IntPtrTy); + } + + return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(), + Schema.isIsaPointer(), + Schema.authenticatesNullValues(), Discriminator); +} + llvm::Constant * CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, llvm::Constant *StorageAddress, @@ -60,6 +139,29 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, IntegerDiscriminator, AddressDiscriminator); } +/// Does a given PointerAuthScheme require us to sign a value +bool CodeGenModule::shouldSignPointer(const PointerAuthSchema &Schema) { + auto AuthenticationMode = Schema.getAuthenticationMode(); + return AuthenticationMode == PointerAuthenticationMode::SignAndStrip || + AuthenticationMode == PointerAuthenticationMode::SignAndAuth; +} + +/// Sign a constant pointer using the given scheme, producing a constant +/// with the same IR type. +llvm::Constant *CodeGenModule::getConstantSignedPointer( + llvm::Constant *Pointer, const PointerAuthSchema &Schema, + llvm::Constant *StorageAddress, GlobalDecl SchemaDecl, + QualType SchemaType) { + assert(shouldSignPointer(Schema)); + llvm::ConstantInt *OtherDiscriminator = + getPointerAuthOtherDiscriminator(Schema, SchemaDecl, SchemaType); + + return getConstantSignedPointer(Pointer, Schema.getKey(), StorageAddress, + OtherDiscriminator); +} + +/// If applicable, sign a given constant function pointer with the ABI rules for +/// functionType. llvm::Constant *CodeGenModule::getFunctionPointer(llvm::Constant *Pointer, QualType FunctionType) { assert(FunctionType->isFunctionType() || @@ -80,3 +182,113 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD, QualType FuncType = FD->getType(); return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType); } + +std::optional<PointerAuthQualifier> +CodeGenModule::computeVTPointerAuthentication(const CXXRecordDecl *ThisClass) { + auto DefaultAuthentication = getCodeGenOpts().PointerAuth.CXXVTablePointers; + if (!DefaultAuthentication) + return std::nullopt; + const CXXRecordDecl *PrimaryBase = + Context.baseForVTableAuthentication(ThisClass); + + unsigned Key = DefaultAuthentication.getKey(); + bool AddressDiscriminated = DefaultAuthentication.isAddressDiscriminated(); + auto DefaultDiscrimination = DefaultAuthentication.getOtherDiscrimination(); + unsigned TypeBasedDiscriminator = + Context.getPointerAuthVTablePointerDiscriminator(PrimaryBase); + unsigned Discriminator; + if (DefaultDiscrimination == PointerAuthSchema::Discrimination::Type) { + Discriminator = TypeBasedDiscriminator; + } else if (DefaultDiscrimination == + PointerAuthSchema::Discrimination::Constant) { + Discriminator = DefaultAuthentication.getConstantDiscrimination(); + } else { + assert(DefaultDiscrimination == PointerAuthSchema::Discrimination::None); + Discriminator = 0; + } + if (auto ExplicitAuthentication = + PrimaryBase->getAttr<VTablePointerAuthenticationAttr>()) { + auto ExplicitAddressDiscrimination = + ExplicitAuthentication->getAddressDiscrimination(); + auto ExplicitDiscriminator = + ExplicitAuthentication->getExtraDiscrimination(); + + unsigned ExplicitKey = ExplicitAuthentication->getKey(); + if (ExplicitKey == VTablePointerAuthenticationAttr::NoKey) + return std::nullopt; + + if (ExplicitKey != VTablePointerAuthenticationAttr::DefaultKey) { + if (ExplicitKey == VTablePointerAuthenticationAttr::ProcessIndependent) + Key = (unsigned)PointerAuthSchema::ARM8_3Key::ASDA; + else { + assert(ExplicitKey == + VTablePointerAuthenticationAttr::ProcessDependent); + Key = (unsigned)PointerAuthSchema::ARM8_3Key::ASDB; + } + } + + if (ExplicitAddressDiscrimination != + VTablePointerAuthenticationAttr::DefaultAddressDiscrimination) + AddressDiscriminated = + ExplicitAddressDiscrimination == + VTablePointerAuthenticationAttr::AddressDiscrimination; + + if (ExplicitDiscriminator == + VTablePointerAuthenticationAttr::TypeDiscrimination) + Discriminator = TypeBasedDiscriminator; + else if (ExplicitDiscriminator == + VTablePointerAuthenticationAttr::CustomDiscrimination) + Discriminator = ExplicitAuthentication->getCustomDiscriminationValue(); + else if (ExplicitDiscriminator == + VTablePointerAuthenticationAttr::NoExtraDiscrimination) + Discriminator = 0; + } + return PointerAuthQualifier::Create(Key, AddressDiscriminated, Discriminator, + PointerAuthenticationMode::SignAndAuth, + /* IsIsaPointer */ false, + /* AuthenticatesNullValues */ false); +} + +std::optional<PointerAuthQualifier> +CodeGenModule::getVTablePointerAuthentication(const CXXRecordDecl *Record) { + if (!Record->getDefinition() || !Record->isPolymorphic()) + return std::nullopt; + + auto Existing = VTablePtrAuthInfos.find(Record); + std::optional<PointerAuthQualifier> Authentication; + if (Existing != VTablePtrAuthInfos.end()) { + Authentication = Existing->getSecond(); + } else { + Authentication = computeVTPointerAuthentication(Record); + VTablePtrAuthInfos.insert(std::make_pair(Record, Authentication)); + } + return Authentication; +} + +std::optional<CGPointerAuthInfo> +CodeGenModule::getVTablePointerAuthInfo(CodeGenFunction *CGF, + const CXXRecordDecl *Record, + llvm::Value *StorageAddress) { + auto Authentication = getVTablePointerAuthentication(Record); + if (!Authentication) + return std::nullopt; + + llvm::Value *Discriminator = nullptr; + if (auto ExtraDiscriminator = Authentication->getExtraDiscriminator()) + Discriminator = llvm::ConstantInt::get(IntPtrTy, ExtraDiscriminator); + + if (Authentication->isAddressDiscriminated()) { + assert(StorageAddress && + "address not provided for address-discriminated schema"); + if (Discriminator) + Discriminator = + CGF->EmitPointerAuthBlendDiscriminator(StorageAddress, Discriminator); + else + Discriminator = CGF->Builder.CreatePtrToInt(StorageAddress, IntPtrTy); + } + + return CGPointerAuthInfo(Authentication->getKey(), + PointerAuthenticationMode::SignAndAuth, + /* IsIsaPointer */ false, + /* AuthenticatesNullValues */ false, Discriminator); +} diff --git a/clang/lib/CodeGen/CGVTT.cpp b/clang/lib/CodeGen/CGVTT.cpp index 4cebb75..20bd2c2f 100644 --- a/clang/lib/CodeGen/CGVTT.cpp +++ b/clang/lib/CodeGen/CGVTT.cpp @@ -90,6 +90,11 @@ CodeGenVTables::EmitVTTDefinition(llvm::GlobalVariable *VTT, llvm::Constant *Init = llvm::ConstantExpr::getGetElementPtr( VTable->getValueType(), VTable, Idxs, /*InBounds=*/true, InRange); + if (const auto &Schema = + CGM.getCodeGenOpts().PointerAuth.CXXVTTVTablePointers) + Init = CGM.getConstantSignedPointer(Init, Schema, nullptr, GlobalDecl(), + QualType()); + VTTComponents.push_back(Init); } diff --git a/clang/lib/CodeGen/CGVTables.cpp b/clang/lib/CodeGen/CGVTables.cpp index 55c3032..3e88cd7 100644 --- a/clang/lib/CodeGen/CGVTables.cpp +++ b/clang/lib/CodeGen/CGVTables.cpp @@ -95,7 +95,7 @@ static RValue PerformReturnAdjustment(CodeGenFunction &CGF, CGF, Address(ReturnValue, CGF.ConvertTypeForMem(ResultType->getPointeeType()), ClassAlign), - Thunk.Return); + ClassDecl, Thunk.Return); if (NullCheckValue) { CGF.Builder.CreateBr(AdjustEnd); @@ -219,8 +219,10 @@ CodeGenFunction::GenerateVarArgsThunk(llvm::Function *Fn, "Store of this should be in entry block?"); // Adjust "this", if necessary. Builder.SetInsertPoint(&*ThisStore); - llvm::Value *AdjustedThisPtr = - CGM.getCXXABI().performThisAdjustment(*this, ThisPtr, Thunk.This); + + const CXXRecordDecl *ThisValueClass = Thunk.ThisType->getPointeeCXXRecordDecl(); + llvm::Value *AdjustedThisPtr = CGM.getCXXABI().performThisAdjustment( + *this, ThisPtr, ThisValueClass, Thunk); AdjustedThisPtr = Builder.CreateBitCast(AdjustedThisPtr, ThisStore->getOperand(0)->getType()); ThisStore->setOperand(0, AdjustedThisPtr); @@ -307,10 +309,15 @@ void CodeGenFunction::EmitCallAndReturnForThunk(llvm::FunctionCallee Callee, const CXXMethodDecl *MD = cast<CXXMethodDecl>(CurGD.getDecl()); // Adjust the 'this' pointer if necessary + const CXXRecordDecl *ThisValueClass = + MD->getThisType()->getPointeeCXXRecordDecl(); + if (Thunk) + ThisValueClass = Thunk->ThisType->getPointeeCXXRecordDecl(); + llvm::Value *AdjustedThisPtr = - Thunk ? CGM.getCXXABI().performThisAdjustment( - *this, LoadCXXThisAddress(), Thunk->This) - : LoadCXXThis(); + Thunk ? CGM.getCXXABI().performThisAdjustment(*this, LoadCXXThisAddress(), + ThisValueClass, *Thunk) + : LoadCXXThis(); // If perfect forwarding is required a variadic method, a method using // inalloca, or an unprototyped thunk, use musttail. Emit an error if this @@ -504,10 +511,22 @@ llvm::Constant *CodeGenVTables::maybeEmitThunk(GlobalDecl GD, SmallString<256> Name; MangleContext &MCtx = CGM.getCXXABI().getMangleContext(); llvm::raw_svector_ostream Out(Name); - if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(MD)) - MCtx.mangleCXXDtorThunk(DD, GD.getDtorType(), TI.This, Out); - else - MCtx.mangleThunk(MD, TI, Out); + + if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(MD)) { + MCtx.mangleCXXDtorThunk(DD, GD.getDtorType(), TI, + /* elideOverrideInfo */ false, Out); + } else + MCtx.mangleThunk(MD, TI, /* elideOverrideInfo */ false, Out); + + if (CGM.getContext().useAbbreviatedThunkName(GD, Name.str())) { + Name = ""; + if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(MD)) + MCtx.mangleCXXDtorThunk(DD, GD.getDtorType(), TI, + /* elideOverrideInfo */ true, Out); + else + MCtx.mangleThunk(MD, TI, /* elideOverrideInfo */ true, Out); + } + llvm::Type *ThunkVTableTy = CGM.getTypes().GetFunctionTypeForVTable(GD); llvm::Constant *Thunk = CGM.GetAddrOfThunk(Name, ThunkVTableTy, GD); @@ -819,11 +838,17 @@ void CodeGenVTables::addVTableComponent(ConstantArrayBuilder &builder, nextVTableThunkIndex++; fnPtr = maybeEmitThunk(GD, thunkInfo, /*ForVTable=*/true); + if (CGM.getCodeGenOpts().PointerAuth.CXXVirtualFunctionPointers) { + assert(thunkInfo.Method && "Method not set"); + GD = GD.getWithDecl(thunkInfo.Method); + } // Otherwise we can use the method definition directly. } else { llvm::Type *fnTy = CGM.getTypes().GetFunctionTypeForVTable(GD); fnPtr = CGM.GetAddrOfFunction(GD, fnTy, /*ForVTable=*/true); + if (CGM.getCodeGenOpts().PointerAuth.CXXVirtualFunctionPointers) + GD = getItaniumVTableContext().findOriginalMethod(GD); } if (useRelativeLayout()) { @@ -841,6 +866,9 @@ void CodeGenVTables::addVTableComponent(ConstantArrayBuilder &builder, if (FnAS != GVAS) fnPtr = llvm::ConstantExpr::getAddrSpaceCast(fnPtr, CGM.GlobalsInt8PtrTy); + if (const auto &Schema = + CGM.getCodeGenOpts().PointerAuth.CXXVirtualFunctionPointers) + return builder.addSignedPointer(fnPtr, Schema, GD, QualType()); return builder.add(fnPtr); } } diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index 650c566..26deeca 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -3063,3 +3063,66 @@ void CodeGenFunction::EmitPointerAuthOperandBundle( llvm::Value *Args[] = {Key, Discriminator}; Bundles.emplace_back("ptrauth", Args); } + +static llvm::Value *EmitPointerAuthCommon(CodeGenFunction &CGF, + const CGPointerAuthInfo &PointerAuth, + llvm::Value *Pointer, + unsigned IntrinsicID) { + if (!PointerAuth) + return Pointer; + + auto Key = CGF.Builder.getInt32(PointerAuth.getKey()); + + llvm::Value *Discriminator = PointerAuth.getDiscriminator(); + if (!Discriminator) { + Discriminator = CGF.Builder.getSize(0); + } + + // Convert the pointer to intptr_t before signing it. + auto OrigType = Pointer->getType(); + Pointer = CGF.Builder.CreatePtrToInt(Pointer, CGF.IntPtrTy); + + // call i64 @llvm.ptrauth.sign.i64(i64 %pointer, i32 %key, i64 %discriminator) + auto Intrinsic = CGF.CGM.getIntrinsic(IntrinsicID); + Pointer = CGF.EmitRuntimeCall(Intrinsic, {Pointer, Key, Discriminator}); + + // Convert back to the original type. + Pointer = CGF.Builder.CreateIntToPtr(Pointer, OrigType); + return Pointer; +} + +llvm::Value * +CodeGenFunction::EmitPointerAuthSign(const CGPointerAuthInfo &PointerAuth, + llvm::Value *Pointer) { + if (!PointerAuth.shouldSign()) + return Pointer; + return EmitPointerAuthCommon(*this, PointerAuth, Pointer, + llvm::Intrinsic::ptrauth_sign); +} + +static llvm::Value *EmitStrip(CodeGenFunction &CGF, + const CGPointerAuthInfo &PointerAuth, + llvm::Value *Pointer) { + auto StripIntrinsic = CGF.CGM.getIntrinsic(llvm::Intrinsic::ptrauth_strip); + + auto Key = CGF.Builder.getInt32(PointerAuth.getKey()); + // Convert the pointer to intptr_t before signing it. + auto OrigType = Pointer->getType(); + Pointer = CGF.EmitRuntimeCall( + StripIntrinsic, {CGF.Builder.CreatePtrToInt(Pointer, CGF.IntPtrTy), Key}); + return CGF.Builder.CreateIntToPtr(Pointer, OrigType); +} + +llvm::Value * +CodeGenFunction::EmitPointerAuthAuth(const CGPointerAuthInfo &PointerAuth, + llvm::Value *Pointer) { + if (PointerAuth.shouldStrip()) { + return EmitStrip(*this, PointerAuth, Pointer); + } + if (!PointerAuth.shouldAuth()) { + return Pointer; + } + + return EmitPointerAuthCommon(*this, PointerAuth, Pointer, + llvm::Intrinsic::ptrauth_auth); +} diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index a9c497b..13f12b5 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -2453,10 +2453,20 @@ public: void InitializeVTablePointers(const CXXRecordDecl *ClassDecl); + // VTableTrapMode - whether we guarantee that loading the + // vtable is guaranteed to trap on authentication failure, + // even if the resulting vtable pointer is unused. + enum class VTableAuthMode { + Authenticate, + MustTrap, + UnsafeUbsanStrip // Should only be used for Vptr UBSan check + }; /// GetVTablePtr - Return the Value of the vtable pointer member pointed /// to by This. - llvm::Value *GetVTablePtr(Address This, llvm::Type *VTableTy, - const CXXRecordDecl *VTableClass); + llvm::Value * + GetVTablePtr(Address This, llvm::Type *VTableTy, + const CXXRecordDecl *VTableClass, + VTableAuthMode AuthMode = VTableAuthMode::Authenticate); enum CFITypeCheckKind { CFITCK_VCall, @@ -4417,6 +4427,19 @@ public: bool isPointerKnownNonNull(const Expr *E); + /// Create the discriminator from the storage address and the entity hash. + llvm::Value *EmitPointerAuthBlendDiscriminator(llvm::Value *StorageAddress, + llvm::Value *Discriminator); + CGPointerAuthInfo EmitPointerAuthInfo(const PointerAuthSchema &Schema, + llvm::Value *StorageAddress, + GlobalDecl SchemaDecl, + QualType SchemaType); + llvm::Value *EmitPointerAuthSign(QualType PointeeType, llvm::Value *Pointer); + llvm::Value *EmitPointerAuthSign(const CGPointerAuthInfo &Info, + llvm::Value *Pointer); + llvm::Value *EmitPointerAuthAuth(const CGPointerAuthInfo &Info, + llvm::Value *Pointer); + void EmitPointerAuthOperandBundle( const CGPointerAuthInfo &Info, SmallVectorImpl<llvm::OperandBundleDef> &Bundles); diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h index 99133047..22b2b31 100644 --- a/clang/lib/CodeGen/CodeGenModule.h +++ b/clang/lib/CodeGen/CodeGenModule.h @@ -609,6 +609,13 @@ private: std::pair<std::unique_ptr<CodeGenFunction>, const TopLevelStmtDecl *> GlobalTopLevelStmtBlockInFlight; + llvm::DenseMap<GlobalDecl, uint16_t> PtrAuthDiscriminatorHashes; + + llvm::DenseMap<const CXXRecordDecl *, std::optional<PointerAuthQualifier>> + VTablePtrAuthInfos; + std::optional<PointerAuthQualifier> + computeVTPointerAuthentication(const CXXRecordDecl *ThisClass); + public: CodeGenModule(ASTContext &C, IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS, const HeaderSearchOptions &headersearchopts, @@ -957,11 +964,33 @@ public: CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T); + bool shouldSignPointer(const PointerAuthSchema &Schema); + llvm::Constant *getConstantSignedPointer(llvm::Constant *Pointer, + const PointerAuthSchema &Schema, + llvm::Constant *StorageAddress, + GlobalDecl SchemaDecl, + QualType SchemaType); + llvm::Constant * getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, llvm::Constant *StorageAddress, llvm::ConstantInt *OtherDiscriminator); + llvm::ConstantInt * + getPointerAuthOtherDiscriminator(const PointerAuthSchema &Schema, + GlobalDecl SchemaDecl, QualType SchemaType); + + uint16_t getPointerAuthDeclDiscriminator(GlobalDecl GD); + std::optional<CGPointerAuthInfo> + getVTablePointerAuthInfo(CodeGenFunction *Context, + const CXXRecordDecl *Record, + llvm::Value *StorageAddress); + + std::optional<PointerAuthQualifier> + getVTablePointerAuthentication(const CXXRecordDecl *thisClass); + + CGPointerAuthInfo EmitPointerAuthInfo(const RecordDecl *RD); + // Return whether RTTI information should be emitted for this target. bool shouldEmitRTTI(bool ForEH = false) { return (ForEH || getLangOpts().RTTI) && !getLangOpts().CUDAIsDevice && diff --git a/clang/lib/CodeGen/ConstantEmitter.h b/clang/lib/CodeGen/ConstantEmitter.h index a55da0d..eff0a8d 100644 --- a/clang/lib/CodeGen/ConstantEmitter.h +++ b/clang/lib/CodeGen/ConstantEmitter.h @@ -113,6 +113,9 @@ public: llvm::Constant *tryEmitAbstract(const APValue &value, QualType T); llvm::Constant *tryEmitAbstractForMemory(const APValue &value, QualType T); + llvm::Constant *tryEmitConstantSignedPointer(llvm::Constant *Ptr, + PointerAuthQualifier Auth); + llvm::Constant *tryEmitConstantExpr(const ConstantExpr *CE); llvm::Constant *emitNullForMemory(QualType T) { diff --git a/clang/lib/CodeGen/ConstantInitBuilder.cpp b/clang/lib/CodeGen/ConstantInitBuilder.cpp index 3cf69f3..549d5dd 100644 --- a/clang/lib/CodeGen/ConstantInitBuilder.cpp +++ b/clang/lib/CodeGen/ConstantInitBuilder.cpp @@ -296,3 +296,21 @@ ConstantAggregateBuilderBase::finishStruct(llvm::StructType *ty) { buffer.erase(buffer.begin() + Begin, buffer.end()); return constant; } + +/// Sign the given pointer and add it to the constant initializer +/// currently being built. +void ConstantAggregateBuilderBase::addSignedPointer( + llvm::Constant *Pointer, const PointerAuthSchema &Schema, + GlobalDecl CalleeDecl, QualType CalleeType) { + if (!Schema || !Builder.CGM.shouldSignPointer(Schema)) + return add(Pointer); + + llvm::Constant *StorageAddress = nullptr; + if (Schema.isAddressDiscriminated()) { + StorageAddress = getAddrOfCurrentPosition(Pointer->getType()); + } + + llvm::Constant *SignedPointer = Builder.CGM.getConstantSignedPointer( + Pointer, Schema, StorageAddress, CalleeDecl, CalleeType); + add(SignedPointer); +} diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 01a735c..63e36e1 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -23,6 +23,7 @@ #include "CGVTables.h" #include "CodeGenFunction.h" #include "CodeGenModule.h" +#include "ConstantEmitter.h" #include "TargetInfo.h" #include "clang/AST/Attr.h" #include "clang/AST/Mangle.h" @@ -336,9 +337,11 @@ public: bool exportThunk() override { return true; } llvm::Value *performThisAdjustment(CodeGenFunction &CGF, Address This, - const ThisAdjustment &TA) override; + const CXXRecordDecl *UnadjustedThisClass, + const ThunkInfo &TI) override; llvm::Value *performReturnAdjustment(CodeGenFunction &CGF, Address Ret, + const CXXRecordDecl *UnadjustedRetClass, const ReturnAdjustment &RA) override; size_t getSrcArgforCopyCtor(const CXXConstructorDecl *, @@ -1477,10 +1480,22 @@ llvm::Value *ItaniumCXXABI::emitDynamicCastCall( computeOffsetHint(CGF.getContext(), SrcDecl, DestDecl).getQuantity()); // Emit the call to __dynamic_cast. - llvm::Value *Args[] = {ThisAddr.emitRawPointer(CGF), SrcRTTI, DestRTTI, - OffsetHint}; - llvm::Value *Value = - CGF.EmitNounwindRuntimeCall(getItaniumDynamicCastFn(CGF), Args); + llvm::Value *Value = ThisAddr.emitRawPointer(CGF); + if (CGM.getCodeGenOpts().PointerAuth.CXXVTablePointers) { + // We perform a no-op load of the vtable pointer here to force an + // authentication. In environments that do not support pointer + // authentication this is a an actual no-op that will be elided. When + // pointer authentication is supported and enforced on vtable pointers this + // load can trap. + llvm::Value *Vtable = + CGF.GetVTablePtr(ThisAddr, CGM.Int8PtrTy, SrcDecl, + CodeGenFunction::VTableAuthMode::MustTrap); + assert(Vtable); + (void)Vtable; + } + + llvm::Value *args[] = {Value, SrcRTTI, DestRTTI, OffsetHint}; + Value = CGF.EmitNounwindRuntimeCall(getItaniumDynamicCastFn(CGF), args); /// C++ [expr.dynamic.cast]p9: /// A failed cast to reference type throws std::bad_cast @@ -1955,8 +1970,18 @@ llvm::Value *ItaniumCXXABI::getVTableAddressPointInStructorWithVTT( VirtualPointerIndex); // And load the address point from the VTT. - return CGF.Builder.CreateAlignedLoad(CGF.GlobalsVoidPtrTy, VTT, - CGF.getPointerAlign()); + llvm::Value *AP = + CGF.Builder.CreateAlignedLoad(CGF.GlobalsVoidPtrTy, VTT, + CGF.getPointerAlign()); + + if (auto &Schema = CGF.CGM.getCodeGenOpts().PointerAuth.CXXVTTVTablePointers) { + CGPointerAuthInfo PointerAuth = CGF.EmitPointerAuthInfo(Schema, VTT, + GlobalDecl(), + QualType()); + AP = CGF.EmitPointerAuthAuth(PointerAuth, AP); + } + + return AP; } llvm::GlobalVariable *ItaniumCXXABI::getAddrOfVTable(const CXXRecordDecl *RD, @@ -2008,8 +2033,9 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *VTable = CGF.GetVTablePtr(This, PtrTy, MethodDecl->getParent()); uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD); - llvm::Value *VFunc; - if (CGF.ShouldEmitVTableTypeCheckedLoad(MethodDecl->getParent())) { + llvm::Value *VFunc, *VTableSlotPtr = nullptr; + auto &Schema = CGM.getCodeGenOpts().PointerAuth.CXXVirtualFunctionPointers; + if (!Schema && CGF.ShouldEmitVTableTypeCheckedLoad(MethodDecl->getParent())) { VFunc = CGF.EmitVTableTypeCheckedLoad( MethodDecl->getParent(), VTable, PtrTy, VTableIndex * @@ -2024,7 +2050,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, CGM.getIntrinsic(llvm::Intrinsic::load_relative, {CGM.Int32Ty}), {VTable, llvm::ConstantInt::get(CGM.Int32Ty, 4 * VTableIndex)}); } else { - llvm::Value *VTableSlotPtr = CGF.Builder.CreateConstInBoundsGEP1_64( + VTableSlotPtr = CGF.Builder.CreateConstInBoundsGEP1_64( PtrTy, VTable, VTableIndex, "vfn"); VFuncLoad = CGF.Builder.CreateAlignedLoad(PtrTy, VTableSlotPtr, CGF.getPointerAlign()); @@ -2048,7 +2074,13 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, VFunc = VFuncLoad; } - CGCallee Callee(GD, VFunc); + CGPointerAuthInfo PointerAuth; + if (Schema) { + assert(VTableSlotPtr && "virtual function pointer not set"); + GD = CGM.getItaniumVTableContext().findOriginalMethod(GD.getCanonicalDecl()); + PointerAuth = CGF.EmitPointerAuthInfo(Schema, VTableSlotPtr, GD, QualType()); + } + CGCallee Callee(GD, VFunc, PointerAuth); return Callee; } @@ -2147,6 +2179,7 @@ bool ItaniumCXXABI::canSpeculativelyEmitVTable(const CXXRecordDecl *RD) const { } static llvm::Value *performTypeAdjustment(CodeGenFunction &CGF, Address InitialPtr, + const CXXRecordDecl *UnadjustedClass, int64_t NonVirtualAdjustment, int64_t VirtualAdjustment, bool IsReturnAdjustment) { @@ -2164,8 +2197,8 @@ static llvm::Value *performTypeAdjustment(CodeGenFunction &CGF, // Perform the virtual adjustment if we have one. llvm::Value *ResultPtr; if (VirtualAdjustment) { - Address VTablePtrPtr = V.withElementType(CGF.Int8PtrTy); - llvm::Value *VTablePtr = CGF.Builder.CreateLoad(VTablePtrPtr); + llvm::Value *VTablePtr = + CGF.GetVTablePtr(V, CGF.Int8PtrTy, UnadjustedClass); llvm::Value *Offset; llvm::Value *OffsetPtr = CGF.Builder.CreateConstInBoundsGEP1_64( @@ -2200,18 +2233,20 @@ static llvm::Value *performTypeAdjustment(CodeGenFunction &CGF, return ResultPtr; } -llvm::Value *ItaniumCXXABI::performThisAdjustment(CodeGenFunction &CGF, - Address This, - const ThisAdjustment &TA) { - return performTypeAdjustment(CGF, This, TA.NonVirtual, - TA.Virtual.Itanium.VCallOffsetOffset, +llvm::Value * +ItaniumCXXABI::performThisAdjustment(CodeGenFunction &CGF, Address This, + const CXXRecordDecl *UnadjustedClass, + const ThunkInfo &TI) { + return performTypeAdjustment(CGF, This, UnadjustedClass, TI.This.NonVirtual, + TI.This.Virtual.Itanium.VCallOffsetOffset, /*IsReturnAdjustment=*/false); } llvm::Value * ItaniumCXXABI::performReturnAdjustment(CodeGenFunction &CGF, Address Ret, + const CXXRecordDecl *UnadjustedClass, const ReturnAdjustment &RA) { - return performTypeAdjustment(CGF, Ret, RA.NonVirtual, + return performTypeAdjustment(CGF, Ret, UnadjustedClass, RA.NonVirtual, RA.Virtual.Itanium.VBaseOffsetOffset, /*IsReturnAdjustment=*/true); } @@ -3694,6 +3729,10 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) { VTable, Two); } + if (auto &Schema = CGM.getCodeGenOpts().PointerAuth.CXXTypeInfoVTablePointer) + VTable = CGM.getConstantSignedPointer(VTable, Schema, nullptr, GlobalDecl(), + QualType(Ty, 0)); + Fields.push_back(VTable); } diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp index 9ab634f..cc6740e 100644 --- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp +++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp @@ -415,9 +415,11 @@ public: bool exportThunk() override { return false; } llvm::Value *performThisAdjustment(CodeGenFunction &CGF, Address This, - const ThisAdjustment &TA) override; + const CXXRecordDecl * /*UnadjustedClass*/, + const ThunkInfo &TI) override; llvm::Value *performReturnAdjustment(CodeGenFunction &CGF, Address Ret, + const CXXRecordDecl * /*UnadjustedClass*/, const ReturnAdjustment &RA) override; void EmitThreadLocalInitFuncs( @@ -2223,9 +2225,10 @@ void MicrosoftCXXABI::emitVBTableDefinition(const VPtrInfo &VBT, GV->setLinkage(llvm::GlobalVariable::AvailableExternallyLinkage); } -llvm::Value *MicrosoftCXXABI::performThisAdjustment(CodeGenFunction &CGF, - Address This, - const ThisAdjustment &TA) { +llvm::Value *MicrosoftCXXABI::performThisAdjustment( + CodeGenFunction &CGF, Address This, + const CXXRecordDecl * /*UnadjustedClass*/, const ThunkInfo &TI) { + const ThisAdjustment &TA = TI.This; if (TA.isEmpty()) return This.emitRawPointer(CGF); @@ -2275,9 +2278,10 @@ llvm::Value *MicrosoftCXXABI::performThisAdjustment(CodeGenFunction &CGF, return V; } -llvm::Value * -MicrosoftCXXABI::performReturnAdjustment(CodeGenFunction &CGF, Address Ret, - const ReturnAdjustment &RA) { +llvm::Value *MicrosoftCXXABI::performReturnAdjustment( + CodeGenFunction &CGF, Address Ret, + const CXXRecordDecl * /*UnadjustedClass*/, const ReturnAdjustment &RA) { + if (RA.isEmpty()) return Ret.emitRawPointer(CGF); diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp index a6d9f42..f42e28b 100644 --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -1468,6 +1468,17 @@ void CompilerInvocation::setDefaultPointerAuthOptions( // If you change anything here, be sure to update <ptrauth.h>. Opts.FunctionPointers = PointerAuthSchema(Key::ASIA, false, Discrimination::None); + + Opts.CXXVTablePointers = PointerAuthSchema( + Key::ASDA, LangOpts.PointerAuthVTPtrAddressDiscrimination, + LangOpts.PointerAuthVTPtrTypeDiscrimination ? Discrimination::Type + : Discrimination::None); + Opts.CXXTypeInfoVTablePointer = + PointerAuthSchema(Key::ASDA, false, Discrimination::None); + Opts.CXXVTTVTablePointers = + PointerAuthSchema(Key::ASDA, false, Discrimination::None); + Opts.CXXVirtualFunctionPointers = Opts.CXXVirtualVariadicFunctionPointers = + PointerAuthSchema(Key::ASIA, true, Discrimination::Decl); } } diff --git a/clang/lib/Headers/ptrauth.h b/clang/lib/Headers/ptrauth.h index 1a4bd02..40ac6dc 100644 --- a/clang/lib/Headers/ptrauth.h +++ b/clang/lib/Headers/ptrauth.h @@ -32,6 +32,10 @@ typedef enum { The extra data is always 0. */ ptrauth_key_function_pointer = ptrauth_key_process_independent_code, + /* The key used to sign C++ v-table pointers. + The extra data is always 0. */ + ptrauth_key_cxx_vtable_pointer = ptrauth_key_process_independent_data, + /* Other pointers signed under the ABI use private ABI rules. */ } ptrauth_key; @@ -205,6 +209,12 @@ typedef __UINTPTR_TYPE__ ptrauth_generic_signature_t; #define ptrauth_sign_generic_data(__value, __data) \ __builtin_ptrauth_sign_generic_data(__value, __data) +/* C++ vtable pointer signing class attribute */ +#define ptrauth_cxx_vtable_pointer(key, address_discrimination, \ + extra_discrimination...) \ + [[clang::ptrauth_vtable_pointer(key, address_discrimination, \ + extra_discrimination)]] + #else #define ptrauth_strip(__value, __key) \ @@ -271,6 +281,10 @@ typedef __UINTPTR_TYPE__ ptrauth_generic_signature_t; ((ptrauth_generic_signature_t)0); \ }) + +#define ptrauth_cxx_vtable_pointer(key, address_discrimination, \ + extra_discrimination...) + #endif /* __has_feature(ptrauth_intrinsics) */ #endif /* __PTRAUTH_H */ diff --git a/clang/lib/InstallAPI/Visitor.cpp b/clang/lib/InstallAPI/Visitor.cpp index 367ae53..a73ea0b 100644 --- a/clang/lib/InstallAPI/Visitor.cpp +++ b/clang/lib/InstallAPI/Visitor.cpp @@ -447,16 +447,16 @@ InstallAPIVisitor::getMangledCXXVTableName(const CXXRecordDecl *D) const { return getBackendMangledName(Name); } -std::string -InstallAPIVisitor::getMangledCXXThunk(const GlobalDecl &D, - const ThunkInfo &Thunk) const { +std::string InstallAPIVisitor::getMangledCXXThunk( + const GlobalDecl &D, const ThunkInfo &Thunk, bool ElideOverrideInfo) const { SmallString<256> Name; raw_svector_ostream NameStream(Name); const auto *Method = cast<CXXMethodDecl>(D.getDecl()); if (const auto *Dtor = dyn_cast<CXXDestructorDecl>(Method)) - MC->mangleCXXDtorThunk(Dtor, D.getDtorType(), Thunk.This, NameStream); + MC->mangleCXXDtorThunk(Dtor, D.getDtorType(), Thunk, ElideOverrideInfo, + NameStream); else - MC->mangleThunk(Method, Thunk, NameStream); + MC->mangleThunk(Method, Thunk, ElideOverrideInfo, NameStream); return getBackendMangledName(Name); } @@ -500,7 +500,8 @@ void InstallAPIVisitor::emitVTableSymbols(const CXXRecordDecl *D, return; for (const auto &Thunk : *Thunks) { - const std::string Name = getMangledCXXThunk(GD, Thunk); + const std::string Name = + getMangledCXXThunk(GD, Thunk, /*ElideOverrideInfo=*/true); auto [GR, FA] = Ctx.Slice->addGlobal(Name, RecordLinkage::Exported, GlobalRecord::Kind::Function, Avail, GD.getDecl(), Access); diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp index c528917..a07f7ad 100644 --- a/clang/lib/Parse/ParseDecl.cpp +++ b/clang/lib/Parse/ParseDecl.cpp @@ -368,6 +368,27 @@ static bool attributeIsTypeArgAttr(const IdentifierInfo &II) { #undef CLANG_ATTR_TYPE_ARG_LIST } +/// Determine whether the given attribute takes identifier arguments. +static bool attributeHasStrictIdentifierArgs(const IdentifierInfo &II) { +#define CLANG_ATTR_STRICT_IDENTIFIER_ARG_AT_INDEX_LIST + return (llvm::StringSwitch<uint64_t>(normalizeAttrName(II.getName())) +#include "clang/Parse/AttrParserStringSwitches.inc" + .Default(0)) != 0; +#undef CLANG_ATTR_STRICT_IDENTIFIER_ARG_AT_INDEX_LIST +} + +/// Determine whether the given attribute takes an identifier argument at a +/// specific index +static bool attributeHasStrictIdentifierArgAtIndex(const IdentifierInfo &II, + size_t argIndex) { +#define CLANG_ATTR_STRICT_IDENTIFIER_ARG_AT_INDEX_LIST + return (llvm::StringSwitch<uint64_t>(normalizeAttrName(II.getName())) +#include "clang/Parse/AttrParserStringSwitches.inc" + .Default(0)) & + (1ull << argIndex); +#undef CLANG_ATTR_STRICT_IDENTIFIER_ARG_AT_INDEX_LIST +} + /// Determine whether the given attribute requires parsing its arguments /// in an unevaluated context or not. static bool attributeParsedArgsUnevaluated(const IdentifierInfo &II) { @@ -546,7 +567,8 @@ unsigned Parser::ParseAttributeArgsCommon( } if (T.isUsable()) TheParsedType = T.get(); - } else if (AttributeHasVariadicIdentifierArg) { + } else if (AttributeHasVariadicIdentifierArg || + attributeHasStrictIdentifierArgs(*AttrName)) { // Parse variadic identifier arg. This can either consume identifiers or // expressions. Variadic identifier args do not support parameter packs // because those are typically used for attributes with enumeration @@ -557,6 +579,12 @@ unsigned Parser::ParseAttributeArgsCommon( if (ChangeKWThisToIdent && Tok.is(tok::kw_this)) Tok.setKind(tok::identifier); + if (Tok.is(tok::identifier) && attributeHasStrictIdentifierArgAtIndex( + *AttrName, ArgExprs.size())) { + ArgExprs.push_back(ParseIdentifierLoc()); + continue; + } + ExprResult ArgExpr; if (Tok.is(tok::identifier)) { ArgExprs.push_back(ParseIdentifierLoc()); diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 99b4003..7e6c7d7 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6281,6 +6281,116 @@ EnforceTCBLeafAttr *Sema::mergeEnforceTCBLeafAttr( *this, D, AL); } +static void handleVTablePointerAuthentication(Sema &S, Decl *D, + const ParsedAttr &AL) { + CXXRecordDecl *Decl = cast<CXXRecordDecl>(D); + const uint32_t NumArgs = AL.getNumArgs(); + if (NumArgs > 4) { + S.Diag(AL.getLoc(), diag::err_attribute_too_many_arguments) << AL << 4; + AL.setInvalid(); + } + + if (NumArgs == 0) { + S.Diag(AL.getLoc(), diag::err_attribute_too_few_arguments) << AL; + AL.setInvalid(); + return; + } + + if (D->getAttr<VTablePointerAuthenticationAttr>()) { + S.Diag(AL.getLoc(), diag::err_duplicated_vtable_pointer_auth) << Decl; + AL.setInvalid(); + } + + auto KeyType = VTablePointerAuthenticationAttr::VPtrAuthKeyType::DefaultKey; + if (AL.isArgIdent(0)) { + IdentifierLoc *IL = AL.getArgAsIdent(0); + if (!VTablePointerAuthenticationAttr::ConvertStrToVPtrAuthKeyType( + IL->Ident->getName(), KeyType)) { + S.Diag(IL->Loc, diag::err_invalid_authentication_key) << IL->Ident; + AL.setInvalid(); + } + if (KeyType == VTablePointerAuthenticationAttr::DefaultKey && + !S.getLangOpts().PointerAuthCalls) { + S.Diag(AL.getLoc(), diag::err_no_default_vtable_pointer_auth) << 0; + AL.setInvalid(); + } + } else { + S.Diag(AL.getLoc(), diag::err_attribute_argument_type) + << AL << AANT_ArgumentIdentifier; + return; + } + + auto AddressDiversityMode = VTablePointerAuthenticationAttr:: + AddressDiscriminationMode::DefaultAddressDiscrimination; + if (AL.getNumArgs() > 1) { + if (AL.isArgIdent(1)) { + IdentifierLoc *IL = AL.getArgAsIdent(1); + if (!VTablePointerAuthenticationAttr:: + ConvertStrToAddressDiscriminationMode(IL->Ident->getName(), + AddressDiversityMode)) { + S.Diag(IL->Loc, diag::err_invalid_address_discrimination) << IL->Ident; + AL.setInvalid(); + } + if (AddressDiversityMode == + VTablePointerAuthenticationAttr::DefaultAddressDiscrimination && + !S.getLangOpts().PointerAuthCalls) { + S.Diag(IL->Loc, diag::err_no_default_vtable_pointer_auth) << 1; + AL.setInvalid(); + } + } else { + S.Diag(AL.getLoc(), diag::err_attribute_argument_type) + << AL << AANT_ArgumentIdentifier; + } + } + + auto ED = VTablePointerAuthenticationAttr::ExtraDiscrimination:: + DefaultExtraDiscrimination; + if (AL.getNumArgs() > 2) { + if (AL.isArgIdent(2)) { + IdentifierLoc *IL = AL.getArgAsIdent(2); + if (!VTablePointerAuthenticationAttr::ConvertStrToExtraDiscrimination( + IL->Ident->getName(), ED)) { + S.Diag(IL->Loc, diag::err_invalid_extra_discrimination) << IL->Ident; + AL.setInvalid(); + } + if (ED == VTablePointerAuthenticationAttr::DefaultExtraDiscrimination && + !S.getLangOpts().PointerAuthCalls) { + S.Diag(AL.getLoc(), diag::err_no_default_vtable_pointer_auth) << 2; + AL.setInvalid(); + } + } else { + S.Diag(AL.getLoc(), diag::err_attribute_argument_type) + << AL << AANT_ArgumentIdentifier; + } + } + + uint32_t CustomDiscriminationValue = 0; + if (ED == VTablePointerAuthenticationAttr::CustomDiscrimination) { + if (NumArgs < 4) { + S.Diag(AL.getLoc(), diag::err_missing_custom_discrimination) << AL << 4; + AL.setInvalid(); + return; + } + if (NumArgs > 4) { + S.Diag(AL.getLoc(), diag::err_attribute_too_many_arguments) << AL << 4; + AL.setInvalid(); + } + + if (!AL.isArgExpr(3) || !S.checkUInt32Argument(AL, AL.getArgAsExpr(3), + CustomDiscriminationValue)) { + S.Diag(AL.getLoc(), diag::err_invalid_custom_discrimination); + AL.setInvalid(); + } + } else if (NumArgs > 3) { + S.Diag(AL.getLoc(), diag::err_attribute_too_many_arguments) << AL << 3; + AL.setInvalid(); + } + + Decl->addAttr(::new (S.Context) VTablePointerAuthenticationAttr( + S.Context, AL, KeyType, AddressDiversityMode, ED, + CustomDiscriminationValue)); +} + //===----------------------------------------------------------------------===// // Top Level Sema Entry Points //===----------------------------------------------------------------------===// @@ -7150,6 +7260,10 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_TypeNullable: handleNullableTypeAttr(S, D, AL); break; + + case ParsedAttr::AT_VTablePointerAuthentication: + handleVTablePointerAuthentication(S, D, AL); + break; } } diff --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp index 9b22010..ffa6bca 100644 --- a/clang/lib/Sema/SemaDeclCXX.cpp +++ b/clang/lib/Sema/SemaDeclCXX.cpp @@ -7116,6 +7116,10 @@ void Sema::CheckCompletedCXXClass(Scope *S, CXXRecordDecl *Record) { return false; }; + if (!Record->isInvalidDecl() && + Record->hasAttr<VTablePointerAuthenticationAttr>()) + checkIncorrectVTablePointerAuthenticationAttribute(*Record); + auto CompleteMemberFunction = [&](CXXMethodDecl *M) { // Check whether the explicitly-defaulted members are valid. bool Incomplete = CheckForDefaultedFunction(M); @@ -10500,6 +10504,39 @@ void Sema::checkIllFormedTrivialABIStruct(CXXRecordDecl &RD) { } } +void Sema::checkIncorrectVTablePointerAuthenticationAttribute( + CXXRecordDecl &RD) { + if (RequireCompleteType(RD.getLocation(), Context.getRecordType(&RD), + diag::err_incomplete_type_vtable_pointer_auth)) + return; + + const CXXRecordDecl *PrimaryBase = &RD; + if (PrimaryBase->hasAnyDependentBases()) + return; + + while (1) { + assert(PrimaryBase); + const CXXRecordDecl *Base = nullptr; + for (auto BasePtr : PrimaryBase->bases()) { + if (!BasePtr.getType()->getAsCXXRecordDecl()->isDynamicClass()) + continue; + Base = BasePtr.getType()->getAsCXXRecordDecl(); + break; + } + if (!Base || Base == PrimaryBase || !Base->isPolymorphic()) + break; + Diag(RD.getAttr<VTablePointerAuthenticationAttr>()->getLocation(), + diag::err_non_top_level_vtable_pointer_auth) + << &RD << Base; + PrimaryBase = Base; + } + + if (!RD.isPolymorphic()) + Diag(RD.getAttr<VTablePointerAuthenticationAttr>()->getLocation(), + diag::err_non_polymorphic_vtable_pointer_auth) + << &RD; +} + void Sema::ActOnFinishCXXMemberSpecification( Scope *S, SourceLocation RLoc, Decl *TagDecl, SourceLocation LBrac, SourceLocation RBrac, const ParsedAttributesView &AttrList) { diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index 4a2f3a6..db44cfe 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -14217,6 +14217,39 @@ QualType Sema::CheckAddressOfOperand(ExprResult &OrigOp, SourceLocation OpLoc) { QualType MPTy = Context.getMemberPointerType( op->getType(), Context.getTypeDeclType(MD->getParent()).getTypePtr()); + + if (getLangOpts().PointerAuthCalls && MD->isVirtual() && + !isUnevaluatedContext() && !MPTy->isDependentType()) { + // When pointer authentication is enabled, argument and return types of + // vitual member functions must be complete. This is because vitrual + // member function pointers are implemented using virtual dispatch + // thunks and the thunks cannot be emitted if the argument or return + // types are incomplete. + auto ReturnOrParamTypeIsIncomplete = [&](QualType T, + SourceLocation DeclRefLoc, + SourceLocation RetArgTypeLoc) { + if (RequireCompleteType(DeclRefLoc, T, diag::err_incomplete_type)) { + Diag(DeclRefLoc, + diag::note_ptrauth_virtual_function_pointer_incomplete_arg_ret); + Diag(RetArgTypeLoc, + diag::note_ptrauth_virtual_function_incomplete_arg_ret_type) + << T; + return true; + } + return false; + }; + QualType RetTy = MD->getReturnType(); + bool IsIncomplete = + !RetTy->isVoidType() && + ReturnOrParamTypeIsIncomplete( + RetTy, OpLoc, MD->getReturnTypeSourceRange().getBegin()); + for (auto *PVD : MD->parameters()) + IsIncomplete |= ReturnOrParamTypeIsIncomplete(PVD->getType(), OpLoc, + PVD->getBeginLoc()); + if (IsIncomplete) + return QualType(); + } + // Under the MS ABI, lock down the inheritance model now. if (Context.getTargetInfo().getCXXABI().isMicrosoft()) (void)isCompleteType(OpLoc, MPTy); |