diff options
Diffstat (limited to 'bolt/lib/Passes/PAuthGadgetScanner.cpp')
-rw-r--r-- | bolt/lib/Passes/PAuthGadgetScanner.cpp | 172 |
1 files changed, 102 insertions, 70 deletions
diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp index cfe4b6b..01b350b 100644 --- a/bolt/lib/Passes/PAuthGadgetScanner.cpp +++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp @@ -14,6 +14,7 @@ #include "bolt/Passes/PAuthGadgetScanner.h" #include "bolt/Core/ParallelUtilities.h" #include "bolt/Passes/DataflowAnalysis.h" +#include "bolt/Utils/CommandLineOpts.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/MC/MCInst.h" @@ -26,6 +27,11 @@ namespace llvm { namespace bolt { namespace PAuthGadgetScanner { +static cl::opt<bool> AuthTrapsOnFailure( + "auth-traps-on-failure", + cl::desc("Assume authentication instructions always trap on failure"), + cl::cat(opts::BinaryAnalysisCategory)); + [[maybe_unused]] static void traceInst(const BinaryContext &BC, StringRef Label, const MCInst &MI) { dbgs() << " " << Label << ": "; @@ -82,8 +88,8 @@ public: TrackedRegisters(ArrayRef<MCPhysReg> RegsToTrack) : Registers(RegsToTrack), RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) { - for (unsigned I = 0; I < RegsToTrack.size(); ++I) - RegToIndexMapping[RegsToTrack[I]] = I; + for (auto [MappedIndex, Reg] : llvm::enumerate(RegsToTrack)) + RegToIndexMapping[Reg] = MappedIndex; } ArrayRef<MCPhysReg> getRegisters() const { return Registers; } @@ -197,9 +203,9 @@ struct SrcState { SafeToDerefRegs &= StateIn.SafeToDerefRegs; TrustedRegs &= StateIn.TrustedRegs; - for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) - for (const MCInst *J : StateIn.LastInstWritingReg[I]) - LastInstWritingReg[I].insert(J); + for (auto [ThisSet, OtherSet] : + llvm::zip_equal(LastInstWritingReg, StateIn.LastInstWritingReg)) + ThisSet.insert_range(OtherSet); return *this; } @@ -218,11 +224,9 @@ struct SrcState { static void printInstsShort(raw_ostream &OS, ArrayRef<SetOfRelatedInsts> Insts) { OS << "Insts: "; - for (unsigned I = 0; I < Insts.size(); ++I) { - auto &Set = Insts[I]; + for (auto [I, PtrSet] : llvm::enumerate(Insts)) { OS << "[" << I << "]("; - for (const MCInst *MCInstP : Set) - OS << MCInstP << " "; + interleave(PtrSet, OS, " "); OS << ")"; } } @@ -364,6 +368,34 @@ protected: return Clobbered; } + std::optional<MCPhysReg> getRegMadeTrustedByChecking(const MCInst &Inst, + SrcState Cur) const { + // This function cannot return multiple registers. This is never the case + // on AArch64. + std::optional<MCPhysReg> RegCheckedByInst = + BC.MIB->getAuthCheckedReg(Inst, /*MayOverwrite=*/false); + if (RegCheckedByInst && Cur.SafeToDerefRegs[*RegCheckedByInst]) + return *RegCheckedByInst; + + auto It = CheckerSequenceInfo.find(&Inst); + if (It == CheckerSequenceInfo.end()) + return std::nullopt; + + MCPhysReg RegCheckedBySequence = It->second.first; + const MCInst *FirstCheckerInst = It->second.second; + + // FirstCheckerInst should belong to the same basic block (see the + // assertion in DataflowSrcSafetyAnalysis::run()), meaning it was + // deterministically processed a few steps before this instruction. + const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst); + + // The sequence checks the register, but it should be authenticated before. + if (!StateBeforeChecker.SafeToDerefRegs[RegCheckedBySequence]) + return std::nullopt; + + return RegCheckedBySequence; + } + // Returns all registers that can be treated as if they are written by an // authentication instruction. SmallVector<MCPhysReg> getRegsMadeSafeToDeref(const MCInst &Point, @@ -382,22 +414,43 @@ protected: // ... an address can be updated in a safe manner, producing the result // which is as trusted as the input address. if (auto DstAndSrc = BC.MIB->analyzeAddressArithmeticsForPtrAuth(Point)) { - if (Cur.SafeToDerefRegs[DstAndSrc->second]) - Regs.push_back(DstAndSrc->first); + auto [DstReg, SrcReg] = *DstAndSrc; + if (Cur.SafeToDerefRegs[SrcReg]) + Regs.push_back(DstReg); } + // Make sure explicit checker sequence keeps register safe-to-dereference + // when the register would be clobbered according to the regular rules: + // + // ; LR is safe to dereference here + // mov x16, x30 ; start of the sequence, LR is s-t-d right before + // xpaclri ; clobbers LR, LR is not safe anymore + // cmp x30, x16 + // b.eq 1f ; end of the sequence: LR is marked as trusted + // brk 0x1234 + // 1: + // ; at this point LR would be marked as trusted, + // ; but not safe-to-dereference + // + // or even just + // + // ; X1 is safe to dereference here + // ldr x0, [x1, #8]! + // ; X1 is trusted here, but it was clobbered due to address write-back + if (auto CheckedReg = getRegMadeTrustedByChecking(Point, Cur)) + Regs.push_back(*CheckedReg); + return Regs; } // Returns all registers made trusted by this instruction. SmallVector<MCPhysReg> getRegsMadeTrusted(const MCInst &Point, const SrcState &Cur) const { + assert(!AuthTrapsOnFailure && "Use getRegsMadeSafeToDeref instead"); SmallVector<MCPhysReg> Regs; // An authenticated pointer can be checked, or - std::optional<MCPhysReg> CheckedReg = - BC.MIB->getAuthCheckedReg(Point, /*MayOverwrite=*/false); - if (CheckedReg && Cur.SafeToDerefRegs[*CheckedReg]) + if (auto CheckedReg = getRegMadeTrustedByChecking(Point, Cur)) Regs.push_back(*CheckedReg); // ... a pointer can be authenticated by an instruction that always checks @@ -408,19 +461,6 @@ protected: if (AutReg && IsChecked) Regs.push_back(*AutReg); - if (CheckerSequenceInfo.contains(&Point)) { - MCPhysReg CheckedReg; - const MCInst *FirstCheckerInst; - std::tie(CheckedReg, FirstCheckerInst) = CheckerSequenceInfo.at(&Point); - - // FirstCheckerInst should belong to the same basic block (see the - // assertion in DataflowSrcSafetyAnalysis::run()), meaning it was - // deterministically processed a few steps before this instruction. - const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst); - if (StateBeforeChecker.SafeToDerefRegs[CheckedReg]) - Regs.push_back(CheckedReg); - } - // ... a safe address can be materialized, or if (auto NewAddrReg = BC.MIB->getMaterializedAddressRegForPtrAuth(Point)) Regs.push_back(*NewAddrReg); @@ -428,8 +468,9 @@ protected: // ... an address can be updated in a safe manner, producing the result // which is as trusted as the input address. if (auto DstAndSrc = BC.MIB->analyzeAddressArithmeticsForPtrAuth(Point)) { - if (Cur.TrustedRegs[DstAndSrc->second]) - Regs.push_back(DstAndSrc->first); + auto [DstReg, SrcReg] = *DstAndSrc; + if (Cur.TrustedRegs[SrcReg]) + Regs.push_back(DstReg); } return Regs; @@ -463,28 +504,11 @@ protected: BitVector Clobbered = getClobberedRegs(Point); SmallVector<MCPhysReg> NewSafeToDerefRegs = getRegsMadeSafeToDeref(Point, Cur); - SmallVector<MCPhysReg> NewTrustedRegs = getRegsMadeTrusted(Point, Cur); - - // Ideally, being trusted is a strictly stronger property than being - // safe-to-dereference. To simplify the computation of Next state, enforce - // this for NewSafeToDerefRegs and NewTrustedRegs. Additionally, this - // fixes the properly for "cumulative" register states in tricky cases - // like the following: - // - // ; LR is safe to dereference here - // mov x16, x30 ; start of the sequence, LR is s-t-d right before - // xpaclri ; clobbers LR, LR is not safe anymore - // cmp x30, x16 - // b.eq 1f ; end of the sequence: LR is marked as trusted - // brk 0x1234 - // 1: - // ; at this point LR would be marked as trusted, - // ; but not safe-to-dereference - // - for (auto TrustedReg : NewTrustedRegs) { - if (!is_contained(NewSafeToDerefRegs, TrustedReg)) - NewSafeToDerefRegs.push_back(TrustedReg); - } + // If authentication instructions trap on failure, safe-to-dereference + // registers are always trusted. + SmallVector<MCPhysReg> NewTrustedRegs = + AuthTrapsOnFailure ? NewSafeToDerefRegs + : getRegsMadeTrusted(Point, Cur); // Then, compute the state after this instruction is executed. SrcState Next = Cur; @@ -521,6 +545,11 @@ protected: dbgs() << ")\n"; }); + // Being trusted is a strictly stronger property than being + // safe-to-dereference. + assert(!Next.TrustedRegs.test(Next.SafeToDerefRegs) && + "SafeToDerefRegs should contain all TrustedRegs"); + return Next; } @@ -836,9 +865,9 @@ struct DstState { return (*this = StateIn); CannotEscapeUnchecked &= StateIn.CannotEscapeUnchecked; - for (unsigned I = 0; I < FirstInstLeakingReg.size(); ++I) - for (const MCInst *J : StateIn.FirstInstLeakingReg[I]) - FirstInstLeakingReg[I].insert(J); + for (auto [ThisSet, OtherSet] : + llvm::zip_equal(FirstInstLeakingReg, StateIn.FirstInstLeakingReg)) + ThisSet.insert_range(OtherSet); return *this; } @@ -1004,8 +1033,7 @@ protected: // ... an address can be updated in a safe manner, or if (auto DstAndSrc = BC.MIB->analyzeAddressArithmeticsForPtrAuth(Inst)) { - MCPhysReg DstReg, SrcReg; - std::tie(DstReg, SrcReg) = *DstAndSrc; + auto [DstReg, SrcReg] = *DstAndSrc; // Note that *all* registers containing the derived values must be safe, // both source and destination ones. No temporaries are supported at now. if (Cur.CannotEscapeUnchecked[SrcReg] && @@ -1045,7 +1073,7 @@ protected: // If this instruction terminates the program immediately, no // authentication oracles are possible past this point. if (BC.MIB->isTrap(Point)) { - LLVM_DEBUG({ traceInst(BC, "Trap instruction found", Point); }); + LLVM_DEBUG(traceInst(BC, "Trap instruction found", Point)); DstState Next(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters()); Next.CannotEscapeUnchecked.set(); return Next; @@ -1130,6 +1158,11 @@ public: } void run() override { + // As long as DstSafetyAnalysis is only computed to detect authentication + // oracles, it is a waste of time to compute it when authentication + // instructions are known to always trap on failure. + assert(!AuthTrapsOnFailure && + "DstSafetyAnalysis is useless with faulting auth"); for (BinaryBasicBlock &BB : Func) { if (auto CheckerInfo = BC.MIB->getAuthCheckedReg(BB)) { LLVM_DEBUG({ @@ -1215,7 +1248,7 @@ public: // starting to analyze Inst. if (BC.MIB->isCall(Inst) || BC.MIB->isBranch(Inst) || BC.MIB->isReturn(Inst)) { - LLVM_DEBUG({ traceInst(BC, "Control flow instruction", Inst); }); + LLVM_DEBUG(traceInst(BC, "Control flow instruction", Inst)); S = createUnsafeState(); } @@ -1360,7 +1393,7 @@ shouldReportUnsafeTailCall(const BinaryContext &BC, const BinaryFunction &BF, // such libc, ignore tail calls performed by ELF entry function. if (BC.StartFunctionAddress && *BC.StartFunctionAddress == Inst.getFunction()->getAddress()) { - LLVM_DEBUG({ dbgs() << " Skipping tail call in ELF entry function.\n"; }); + LLVM_DEBUG(dbgs() << " Skipping tail call in ELF entry function.\n"); return std::nullopt; } @@ -1434,7 +1467,7 @@ shouldReportAuthOracle(const BinaryContext &BC, const MCInstReference &Inst, }); if (S.empty()) { - LLVM_DEBUG({ dbgs() << " DstState is empty!\n"; }); + LLVM_DEBUG(dbgs() << " DstState is empty!\n"); return make_generic_report( Inst, "Warning: no state computed for an authentication instruction " "(possibly unreachable)"); @@ -1461,7 +1494,7 @@ collectRegsToTrack(ArrayRef<PartialReport<MCPhysReg>> Reports) { void FunctionAnalysisContext::findUnsafeUses( SmallVector<PartialReport<MCPhysReg>> &Reports) { auto Analysis = SrcSafetyAnalysis::create(BF, AllocatorId, {}); - LLVM_DEBUG({ dbgs() << "Running src register safety analysis...\n"; }); + LLVM_DEBUG(dbgs() << "Running src register safety analysis...\n"); Analysis->run(); LLVM_DEBUG({ dbgs() << "After src register safety analysis:\n"; @@ -1518,8 +1551,7 @@ void FunctionAnalysisContext::findUnsafeUses( const SrcState &S = Analysis->getStateBefore(Inst); if (S.empty()) { - LLVM_DEBUG( - { traceInst(BC, "Instruction has no state, skipping", Inst); }); + LLVM_DEBUG(traceInst(BC, "Instruction has no state, skipping", Inst)); assert(UnreachableBBReported && "Should be reported at least once"); (void)UnreachableBBReported; return; @@ -1546,8 +1578,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports( SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack(Reports); // Re-compute the analysis with register tracking. auto Analysis = SrcSafetyAnalysis::create(BF, AllocatorId, RegsToTrack); - LLVM_DEBUG( - { dbgs() << "\nRunning detailed src register safety analysis...\n"; }); + LLVM_DEBUG(dbgs() << "\nRunning detailed src register safety analysis...\n"); Analysis->run(); LLVM_DEBUG({ dbgs() << "After detailed src register safety analysis:\n"; @@ -1557,7 +1588,7 @@ void FunctionAnalysisContext::augmentUnsafeUseReports( // Augment gadget reports. for (auto &Report : Reports) { MCInstReference Location = Report.Issue->Location; - LLVM_DEBUG({ traceInst(BC, "Attaching clobbering info to", Location); }); + LLVM_DEBUG(traceInst(BC, "Attaching clobbering info to", Location)); assert(Report.RequestedDetails && "Should be removed by handleSimpleReports"); auto DetailedInfo = @@ -1571,9 +1602,11 @@ void FunctionAnalysisContext::findUnsafeDefs( SmallVector<PartialReport<MCPhysReg>> &Reports) { if (PacRetGadgetsOnly) return; + if (AuthTrapsOnFailure) + return; auto Analysis = DstSafetyAnalysis::create(BF, AllocatorId, {}); - LLVM_DEBUG({ dbgs() << "Running dst register safety analysis...\n"; }); + LLVM_DEBUG(dbgs() << "Running dst register safety analysis...\n"); Analysis->run(); LLVM_DEBUG({ dbgs() << "After dst register safety analysis:\n"; @@ -1596,8 +1629,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports( SmallVector<MCPhysReg> RegsToTrack = collectRegsToTrack(Reports); // Re-compute the analysis with register tracking. auto Analysis = DstSafetyAnalysis::create(BF, AllocatorId, RegsToTrack); - LLVM_DEBUG( - { dbgs() << "\nRunning detailed dst register safety analysis...\n"; }); + LLVM_DEBUG(dbgs() << "\nRunning detailed dst register safety analysis...\n"); Analysis->run(); LLVM_DEBUG({ dbgs() << "After detailed dst register safety analysis:\n"; @@ -1607,7 +1639,7 @@ void FunctionAnalysisContext::augmentUnsafeDefReports( // Augment gadget reports. for (auto &Report : Reports) { MCInstReference Location = Report.Issue->Location; - LLVM_DEBUG({ traceInst(BC, "Attaching leakage info to", Location); }); + LLVM_DEBUG(traceInst(BC, "Attaching leakage info to", Location)); assert(Report.RequestedDetails && "Should be removed by handleSimpleReports"); auto DetailedInfo = std::make_shared<LeakageInfo>( |