diff options
137 files changed, 3370 insertions, 2308 deletions
diff --git a/.github/new-prs-labeler.yml b/.github/new-prs-labeler.yml index c49fd1d..efdc42d 100644 --- a/.github/new-prs-labeler.yml +++ b/.github/new-prs-labeler.yml @@ -1096,8 +1096,8 @@ clang:openmp: - llvm/test/Transforms/OpenMP/** clang:temporal-safety: - - clang/include/clang/Analysis/Analyses/LifetimeSafety* - - clang/lib/Analysis/LifetimeSafety* + - clang/include/clang/Analysis/Analyses/LifetimeSafety/** + - clang/lib/Analysis/LifetimeSafety/** - clang/unittests/Analysis/LifetimeSafety* - clang/test/Sema/*lifetime-safety* - clang/test/Sema/*lifetime-analysis* diff --git a/clang-tools-extra/clang-doc/JSONGenerator.cpp b/clang-tools-extra/clang-doc/JSONGenerator.cpp index 6fba211..b17cc80 100644 --- a/clang-tools-extra/clang-doc/JSONGenerator.cpp +++ b/clang-tools-extra/clang-doc/JSONGenerator.cpp @@ -584,12 +584,20 @@ static SmallString<16> determineFileName(Info *I, SmallString<128> &Path) { FileName = RecordSymbolInfo->MangledName; } else if (I->USR == GlobalNamespaceID) FileName = "index"; - else + else if (I->IT == InfoType::IT_namespace) { + for (const auto &NS : I->Namespace) { + FileName += NS.Name; + FileName += "_"; + } + FileName += I->Name; + } else FileName = I->Name; sys::path::append(Path, FileName + ".json"); return FileName; } +// FIXME: Revert back to creating nested directories for namespaces instead of +// putting everything in a flat directory structure. 
Error JSONGenerator::generateDocs( StringRef RootDir, llvm::StringMap<std::unique_ptr<doc::Info>> Infos, const ClangDocContext &CDCtx) { diff --git a/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.cpp b/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.cpp index 843368e..aaf0594 100644 --- a/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.cpp +++ b/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.cpp @@ -40,8 +40,9 @@ SuspiciousIncludeCheck::SuspiciousIncludeCheck(StringRef Name, ClangTidyContext *Context) : ClangTidyCheck(Name, Context), HeaderFileExtensions(Context->getHeaderFileExtensions()), - ImplementationFileExtensions(Context->getImplementationFileExtensions()) { -} + ImplementationFileExtensions(Context->getImplementationFileExtensions()), + IgnoredRegexString(Options.get("IgnoredRegex").value_or(StringRef{})), + IgnoredRegex(IgnoredRegexString) {} void SuspiciousIncludeCheck::registerPPCallbacks( const SourceManager &SM, Preprocessor *PP, Preprocessor *ModuleExpanderPP) { @@ -49,6 +50,11 @@ void SuspiciousIncludeCheck::registerPPCallbacks( ::std::make_unique<SuspiciousIncludePPCallbacks>(*this, SM, PP)); } +void SuspiciousIncludeCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) { + if (!IgnoredRegexString.empty()) + Options.store(Opts, "IgnoredRegex", IgnoredRegexString); +} + void SuspiciousIncludePPCallbacks::InclusionDirective( SourceLocation HashLoc, const Token &IncludeTok, StringRef FileName, bool IsAngled, CharSourceRange FilenameRange, OptionalFileEntryRef File, @@ -57,6 +63,9 @@ void SuspiciousIncludePPCallbacks::InclusionDirective( if (IncludeTok.getIdentifierInfo()->getPPKeywordID() == tok::pp_import) return; + if (!Check.IgnoredRegexString.empty() && Check.IgnoredRegex.match(FileName)) + return; + SourceLocation DiagLoc = FilenameRange.getBegin().getLocWithOffset(1); const std::optional<StringRef> IFE = diff --git a/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.h 
b/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.h index 3aa9491e..50fc345 100644 --- a/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.h +++ b/clang-tools-extra/clang-tidy/bugprone/SuspiciousIncludeCheck.h @@ -10,7 +10,6 @@ #define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_BUGPRONE_SUSPICIOUSINCLUDECHECK_H #include "../ClangTidyCheck.h" -#include "../utils/FileExtensionsUtils.h" namespace clang::tidy::bugprone { @@ -28,9 +27,12 @@ public: SuspiciousIncludeCheck(StringRef Name, ClangTidyContext *Context); void registerPPCallbacks(const SourceManager &SM, Preprocessor *PP, Preprocessor *ModuleExpanderPP) override; + void storeOptions(ClangTidyOptions::OptionMap &Opts) override; FileExtensionsSet HeaderFileExtensions; FileExtensionsSet ImplementationFileExtensions; + StringRef IgnoredRegexString; + llvm::Regex IgnoredRegex; }; } // namespace clang::tidy::bugprone diff --git a/clang-tools-extra/docs/ReleaseNotes.rst b/clang-tools-extra/docs/ReleaseNotes.rst index 9aeda03..216d3f5 100644 --- a/clang-tools-extra/docs/ReleaseNotes.rst +++ b/clang-tools-extra/docs/ReleaseNotes.rst @@ -286,6 +286,10 @@ Changes in existing checks <clang-tidy/checks/bugprone/sizeof-expression>` check by fixing a crash on ``sizeof`` of an array of dependent type. +- Improved :doc:`bugprone-suspicious-include + <clang-tidy/checks/bugprone/suspicious-include>` check by adding + `IgnoredRegex` option. 
+ - Improved :doc:`bugprone-tagged-union-member-count <clang-tidy/checks/bugprone/tagged-union-member-count>` by fixing a false positive when enums or unions from system header files or the ``std`` diff --git a/clang-tools-extra/docs/clang-tidy/checks/bugprone/suspicious-include.rst b/clang-tools-extra/docs/clang-tidy/checks/bugprone/suspicious-include.rst index 669654f..4fbfa259 100644 --- a/clang-tools-extra/docs/clang-tidy/checks/bugprone/suspicious-include.rst +++ b/clang-tools-extra/docs/clang-tidy/checks/bugprone/suspicious-include.rst @@ -14,3 +14,11 @@ Examples: #include "Pterodactyl.h" // OK, .h files tend not to have definitions. #include "Velociraptor.cpp" // Warning, filename is suspicious. #include_next <stdio.c> // Warning, filename is suspicious. + +Options +------- + +.. option:: IgnoredRegex + + A regular expression for the file name to be ignored by the check. Default + is empty string. diff --git a/clang-tools-extra/test/clang-doc/json/multiple-namespaces.cpp b/clang-tools-extra/test/clang-doc/json/multiple-namespaces.cpp new file mode 100644 index 0000000..04fcfc1 --- /dev/null +++ b/clang-tools-extra/test/clang-doc/json/multiple-namespaces.cpp @@ -0,0 +1,20 @@ +// RUN: rm -rf %t && mkdir -p %t +// RUN: clang-doc --output=%t --format=json --executor=standalone %s +// RUN: FileCheck %s < %t/json/foo_tools.json --check-prefix=CHECK-FOO +// RUN: FileCheck %s < %t/json/bar_tools.json --check-prefix=CHECK-BAR + +namespace foo { + namespace tools { + class FooTools {}; + } // namespace tools +} // namespace foo + +namespace bar { + namespace tools { + class BarTools {}; + } // namespace tools +} // namespace bar + +// CHECK-FOO: "Name": "tools" + +// CHECK-BAR: "Name": "tools" diff --git a/clang-tools-extra/test/clang-doc/json/nested-namespace.cpp b/clang-tools-extra/test/clang-doc/json/nested-namespace.cpp index b19afc1..cf19e1e 100644 --- a/clang-tools-extra/test/clang-doc/json/nested-namespace.cpp +++ 
b/clang-tools-extra/test/clang-doc/json/nested-namespace.cpp @@ -1,7 +1,7 @@ // RUN: rm -rf %t && mkdir -p %t // RUN: clang-doc --output=%t --format=json --executor=standalone %s // RUN: FileCheck %s < %t/json/nested.json --check-prefix=NESTED -// RUN: FileCheck %s < %t/json/inner.json --check-prefix=INNER +// RUN: FileCheck %s < %t/json/nested_inner.json --check-prefix=INNER namespace nested { int Global; diff --git a/clang-tools-extra/test/clang-tidy/checkers/Inputs/Headers/moc_foo.cpp b/clang-tools-extra/test/clang-tidy/checkers/Inputs/Headers/moc_foo.cpp new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/clang-tools-extra/test/clang-tidy/checkers/Inputs/Headers/moc_foo.cpp diff --git a/clang-tools-extra/test/clang-tidy/checkers/bugprone/suspicious-include.cpp b/clang-tools-extra/test/clang-tidy/checkers/bugprone/suspicious-include.cpp index 969d0bf..4f2acbc 100644 --- a/clang-tools-extra/test/clang-tidy/checkers/bugprone/suspicious-include.cpp +++ b/clang-tools-extra/test/clang-tidy/checkers/bugprone/suspicious-include.cpp @@ -1,4 +1,6 @@ -// RUN: %check_clang_tidy %s bugprone-suspicious-include %t -- -- -isystem %clang_tidy_headers -fmodules +// RUN: %check_clang_tidy %s bugprone-suspicious-include %t -- \ +// RUN: -config="{CheckOptions: {bugprone-suspicious-include.IgnoredRegex: 'moc_.*'}"} -- \ +// RUN: -isystem %clang_tidy_headers -fmodules // clang-format off @@ -22,3 +24,6 @@ // CHECK-MESSAGES: [[@LINE+1]]:14: warning: suspicious #include of file with '.cxx' extension # include <c.cxx> + +// CHECK-MESSAGES-NOT: warning: +#include "moc_foo.cpp" diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety.h deleted file mode 100644 index e54fc26..0000000 --- a/clang/include/clang/Analysis/Analyses/LifetimeSafety.h +++ /dev/null @@ -1,183 +0,0 @@ -//===- LifetimeSafety.h - C++ Lifetime Safety Analysis -*----------- C++-*-===// -// -// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the entry point for a dataflow-based static analysis -// that checks for C++ lifetime violations. -// -// The analysis is based on the concepts of "origins" and "loans" to track -// pointer lifetimes and detect issues like use-after-free and dangling -// pointers. See the RFC for more details: -// https://discourse.llvm.org/t/rfc-intra-procedural-lifetime-analysis-in-clang/86291 -// -//===----------------------------------------------------------------------===// -#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H -#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H -#include "clang/Analysis/AnalysisDeclContext.h" -#include "clang/Analysis/CFG.h" -#include "clang/Basic/SourceLocation.h" -#include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/ImmutableMap.h" -#include "llvm/ADT/ImmutableSet.h" -#include "llvm/ADT/StringMap.h" -#include <memory> - -namespace clang::lifetimes { - -/// Enum to track the confidence level of a potential error. -enum class Confidence : uint8_t { - None, - Maybe, // Reported as a potential error (-Wlifetime-safety-strict) - Definite // Reported as a definite error (-Wlifetime-safety-permissive) -}; - -enum class LivenessKind : uint8_t { - Dead, // Not alive - Maybe, // Live on some path but not all paths (may-be-live) - Must // Live on all paths (must-be-live) -}; - -class LifetimeSafetyReporter { -public: - LifetimeSafetyReporter() = default; - virtual ~LifetimeSafetyReporter() = default; - - virtual void reportUseAfterFree(const Expr *IssueExpr, const Expr *UseExpr, - SourceLocation FreeLoc, - Confidence Confidence) {} -}; - -/// The main entry point for the analysis. 
-void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC, - LifetimeSafetyReporter *Reporter); - -namespace internal { -// Forward declarations of internal types. -class Fact; -class FactManager; -class LoanPropagationAnalysis; -class ExpiredLoansAnalysis; -class LiveOriginAnalysis; -struct LifetimeFactory; - -/// A generic, type-safe wrapper for an ID, distinguished by its `Tag` type. -/// Used for giving ID to loans and origins. -template <typename Tag> struct ID { - uint32_t Value = 0; - - bool operator==(const ID<Tag> &Other) const { return Value == Other.Value; } - bool operator!=(const ID<Tag> &Other) const { return !(*this == Other); } - bool operator<(const ID<Tag> &Other) const { return Value < Other.Value; } - ID<Tag> operator++(int) { - ID<Tag> Tmp = *this; - ++Value; - return Tmp; - } - void Profile(llvm::FoldingSetNodeID &IDBuilder) const { - IDBuilder.AddInteger(Value); - } -}; - -using LoanID = ID<struct LoanTag>; -using OriginID = ID<struct OriginTag>; -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, LoanID ID) { - return OS << ID.Value; -} -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, OriginID ID) { - return OS << ID.Value; -} - -// Using LLVM's immutable collections is efficient for dataflow analysis -// as it avoids deep copies during state transitions. -// TODO(opt): Consider using a bitset to represent the set of loans. -using LoanSet = llvm::ImmutableSet<LoanID>; -using OriginSet = llvm::ImmutableSet<OriginID>; -using OriginLoanMap = llvm::ImmutableMap<OriginID, LoanSet>; - -/// A `ProgramPoint` identifies a location in the CFG by pointing to a specific -/// `Fact`. identified by a lifetime-related event (`Fact`). -/// -/// A `ProgramPoint` has "after" semantics: it represents the location -/// immediately after its corresponding `Fact`. -using ProgramPoint = const Fact *; - -/// Running the lifetime safety analysis and querying its results. It -/// encapsulates the various dataflow analyses. 
-class LifetimeSafetyAnalysis { -public: - LifetimeSafetyAnalysis(AnalysisDeclContext &AC, - LifetimeSafetyReporter *Reporter); - ~LifetimeSafetyAnalysis(); - - void run(); - - /// Returns the set of loans an origin holds at a specific program point. - LoanSet getLoansAtPoint(OriginID OID, ProgramPoint PP) const; - - /// Returns the set of origins that are live at a specific program point, - /// along with the confidence level of their liveness. - /// - /// An origin is considered live if there are potential future uses of that - /// origin after the given program point. The confidence level indicates - /// whether the origin is definitely live (Definite) due to being domintated - /// by a set of uses or only possibly live (Maybe) only on some but not all - /// control flow paths. - std::vector<std::pair<OriginID, LivenessKind>> - getLiveOriginsAtPoint(ProgramPoint PP) const; - - /// Finds the OriginID for a given declaration. - /// Returns a null optional if not found. - std::optional<OriginID> getOriginIDForDecl(const ValueDecl *D) const; - - /// Finds the LoanID's for the loan created with the specific variable as - /// their Path. - std::vector<LoanID> getLoanIDForVar(const VarDecl *VD) const; - - /// Retrieves program points that were specially marked in the source code - /// for testing. - /// - /// The analysis recognizes special function calls of the form - /// `void("__lifetime_test_point_<name>")` as test points. This method returns - /// a map from the annotation string (<name>) to the corresponding - /// `ProgramPoint`. This allows test harnesses to query the analysis state at - /// user-defined locations in the code. - /// \note This is intended for testing only. 
- llvm::StringMap<ProgramPoint> getTestPoints() const; - -private: - AnalysisDeclContext &AC; - LifetimeSafetyReporter *Reporter; - std::unique_ptr<LifetimeFactory> Factory; - std::unique_ptr<FactManager> FactMgr; - std::unique_ptr<LoanPropagationAnalysis> LoanPropagation; - std::unique_ptr<LiveOriginAnalysis> LiveOrigins; -}; -} // namespace internal -} // namespace clang::lifetimes - -namespace llvm { -template <typename Tag> -struct DenseMapInfo<clang::lifetimes::internal::ID<Tag>> { - using ID = clang::lifetimes::internal::ID<Tag>; - - static inline ID getEmptyKey() { - return {DenseMapInfo<uint32_t>::getEmptyKey()}; - } - - static inline ID getTombstoneKey() { - return {DenseMapInfo<uint32_t>::getTombstoneKey()}; - } - - static unsigned getHashValue(const ID &Val) { - return DenseMapInfo<uint32_t>::getHashValue(Val.Value); - } - - static bool isEqual(const ID &LHS, const ID &RHS) { return LHS == RHS; } -}; -} // namespace llvm - -#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/Checker.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Checker.h new file mode 100644 index 0000000..03636be --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Checker.h @@ -0,0 +1,35 @@ +//===- Checker.h - C++ Lifetime Safety Analysis -*----------- C++-*-=========// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines and enforces the lifetime safety policy. It detects +// use-after-free errors by examining loan expiration points and checking if +// any live origins hold the expired loans. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_CHECKER_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_CHECKER_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h" + +namespace clang::lifetimes::internal { + +/// Runs the lifetime checker, which detects use-after-free errors by +/// examining loan expiration points and checking if any live origins hold +/// the expired loan. +void runLifetimeChecker(const LoanPropagationAnalysis &LoanPropagation, + const LiveOriginsAnalysis &LiveOrigins, + const FactManager &FactMgr, AnalysisDeclContext &ADC, + LifetimeSafetyReporter *Reporter); + +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_CHECKER_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/Facts.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Facts.h new file mode 100644 index 0000000..6a90aeb --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Facts.h @@ -0,0 +1,232 @@ +//===- Facts.h - Lifetime Analysis Facts and Fact Manager ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Facts, which are atomic lifetime-relevant events (such as +// loan issuance, loan expiration, origin flow, and use), and the FactManager, +// which manages the storage and retrieval of facts for each CFG block. 
+// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTS_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTS_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Loans.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Origins.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include <cstdint> + +namespace clang::lifetimes::internal { +/// An abstract base class for a single, atomic lifetime-relevant event. +class Fact { + +public: + enum class Kind : uint8_t { + /// A new loan is issued from a borrow expression (e.g., &x). + Issue, + /// A loan expires as its underlying storage is freed (e.g., variable goes + /// out of scope). + Expire, + /// An origin is propagated from a source to a destination (e.g., p = q). + /// This can also optionally kill the destination origin before flowing into + /// it. Otherwise, the source's loan set is merged into the destination's + /// loan set. + OriginFlow, + /// An origin escapes the function by flowing into the return value. + ReturnOfOrigin, + /// An origin is used (eg. appears as l-value expression like DeclRefExpr). + Use, + /// A marker for a specific point in the code, for testing. + TestPoint, + }; + +private: + Kind K; + +protected: + Fact(Kind K) : K(K) {} + +public: + virtual ~Fact() = default; + Kind getKind() const { return K; } + + template <typename T> const T *getAs() const { + if (T::classof(this)) + return static_cast<const T *>(this); + return nullptr; + } + + virtual void dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &) const; +}; + +/// A `ProgramPoint` identifies a location in the CFG by pointing to a specific +/// `Fact`. identified by a lifetime-related event (`Fact`). 
+/// +/// A `ProgramPoint` has "after" semantics: it represents the location +/// immediately after its corresponding `Fact`. +using ProgramPoint = const Fact *; + +class IssueFact : public Fact { + LoanID LID; + OriginID OID; + +public: + static bool classof(const Fact *F) { return F->getKind() == Kind::Issue; } + + IssueFact(LoanID LID, OriginID OID) : Fact(Kind::Issue), LID(LID), OID(OID) {} + LoanID getLoanID() const { return LID; } + OriginID getOriginID() const { return OID; } + void dump(llvm::raw_ostream &OS, const LoanManager &LM, + const OriginManager &OM) const override; +}; + +class ExpireFact : public Fact { + LoanID LID; + SourceLocation ExpiryLoc; + +public: + static bool classof(const Fact *F) { return F->getKind() == Kind::Expire; } + + ExpireFact(LoanID LID, SourceLocation ExpiryLoc) + : Fact(Kind::Expire), LID(LID), ExpiryLoc(ExpiryLoc) {} + + LoanID getLoanID() const { return LID; } + SourceLocation getExpiryLoc() const { return ExpiryLoc; } + + void dump(llvm::raw_ostream &OS, const LoanManager &LM, + const OriginManager &) const override; +}; + +class OriginFlowFact : public Fact { + OriginID OIDDest; + OriginID OIDSrc; + // True if the destination origin should be killed (i.e., its current loans + // cleared) before the source origin's loans are flowed into it. 
+ bool KillDest; + +public: + static bool classof(const Fact *F) { + return F->getKind() == Kind::OriginFlow; + } + + OriginFlowFact(OriginID OIDDest, OriginID OIDSrc, bool KillDest) + : Fact(Kind::OriginFlow), OIDDest(OIDDest), OIDSrc(OIDSrc), + KillDest(KillDest) {} + + OriginID getDestOriginID() const { return OIDDest; } + OriginID getSrcOriginID() const { return OIDSrc; } + bool getKillDest() const { return KillDest; } + + void dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const override; +}; + +class ReturnOfOriginFact : public Fact { + OriginID OID; + +public: + static bool classof(const Fact *F) { + return F->getKind() == Kind::ReturnOfOrigin; + } + + ReturnOfOriginFact(OriginID OID) : Fact(Kind::ReturnOfOrigin), OID(OID) {} + OriginID getReturnedOriginID() const { return OID; } + void dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const override; +}; + +class UseFact : public Fact { + const Expr *UseExpr; + // True if this use is a write operation (e.g., left-hand side of assignment). + // Write operations are exempted from use-after-free checks. + bool IsWritten = false; + +public: + static bool classof(const Fact *F) { return F->getKind() == Kind::Use; } + + UseFact(const Expr *UseExpr) : Fact(Kind::Use), UseExpr(UseExpr) {} + + OriginID getUsedOrigin(const OriginManager &OM) const { + // TODO: Remove const cast and make OriginManager::get as const. + return const_cast<OriginManager &>(OM).get(*UseExpr); + } + const Expr *getUseExpr() const { return UseExpr; } + void markAsWritten() { IsWritten = true; } + bool isWritten() const { return IsWritten; } + + void dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const override; +}; + +/// A dummy-fact used to mark a specific point in the code for testing. +/// It is generated by recognizing a `void("__lifetime_test_point_...")` cast. 
+class TestPointFact : public Fact { + StringRef Annotation; + +public: + static bool classof(const Fact *F) { return F->getKind() == Kind::TestPoint; } + + explicit TestPointFact(StringRef Annotation) + : Fact(Kind::TestPoint), Annotation(Annotation) {} + + StringRef getAnnotation() const { return Annotation; } + + void dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &) const override; +}; + +class FactManager { +public: + llvm::ArrayRef<const Fact *> getFacts(const CFGBlock *B) const { + auto It = BlockToFactsMap.find(B); + if (It != BlockToFactsMap.end()) + return It->second; + return {}; + } + + void addBlockFacts(const CFGBlock *B, llvm::ArrayRef<Fact *> NewFacts) { + if (!NewFacts.empty()) + BlockToFactsMap[B].assign(NewFacts.begin(), NewFacts.end()); + } + + template <typename FactType, typename... Args> + FactType *createFact(Args &&...args) { + void *Mem = FactAllocator.Allocate<FactType>(); + return new (Mem) FactType(std::forward<Args>(args)...); + } + + void dump(const CFG &Cfg, AnalysisDeclContext &AC) const; + + /// Retrieves program points that were specially marked in the source code + /// for testing. + /// + /// The analysis recognizes special function calls of the form + /// `void("__lifetime_test_point_<name>")` as test points. This method returns + /// a map from the annotation string (<name>) to the corresponding + /// `ProgramPoint`. This allows test harnesses to query the analysis state at + /// user-defined locations in the code. + /// \note This is intended for testing only. 
+ llvm::StringMap<ProgramPoint> getTestPoints() const; + + LoanManager &getLoanMgr() { return LoanMgr; } + const LoanManager &getLoanMgr() const { return LoanMgr; } + OriginManager &getOriginMgr() { return OriginMgr; } + const OriginManager &getOriginMgr() const { return OriginMgr; } + +private: + LoanManager LoanMgr; + OriginManager OriginMgr; + llvm::DenseMap<const clang::CFGBlock *, llvm::SmallVector<const Fact *>> + BlockToFactsMap; + llvm::BumpPtrAllocator FactAllocator; +}; +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTS_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/FactsGenerator.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/FactsGenerator.h new file mode 100644 index 0000000..5e58abe --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/FactsGenerator.h @@ -0,0 +1,106 @@ +//===- FactsGenerator.h - Lifetime Facts Generation -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FactsGenerator, which traverses the AST to generate +// lifetime-relevant facts (such as loan issuance, expiration, origin flow, +// and use) from CFG statements. These facts are used by the dataflow analyses +// to track pointer lifetimes and detect use-after-free errors. 
+// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTSGENERATOR_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTSGENERATOR_H + +#include "clang/AST/StmtVisitor.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Origins.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "llvm/ADT/SmallVector.h" + +namespace clang::lifetimes::internal { + +class FactsGenerator : public ConstStmtVisitor<FactsGenerator> { + using Base = ConstStmtVisitor<FactsGenerator>; + +public: + FactsGenerator(FactManager &FactMgr, AnalysisDeclContext &AC) + : FactMgr(FactMgr), AC(AC) {} + + void run(); + + void VisitDeclStmt(const DeclStmt *DS); + void VisitDeclRefExpr(const DeclRefExpr *DRE); + void VisitCXXConstructExpr(const CXXConstructExpr *CCE); + void VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE); + void VisitCallExpr(const CallExpr *CE); + void VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *N); + void VisitImplicitCastExpr(const ImplicitCastExpr *ICE); + void VisitUnaryOperator(const UnaryOperator *UO); + void VisitReturnStmt(const ReturnStmt *RS); + void VisitBinaryOperator(const BinaryOperator *BO); + void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE); + void VisitCXXFunctionalCastExpr(const CXXFunctionalCastExpr *FCE); + void VisitInitListExpr(const InitListExpr *ILE); + void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *MTE); + +private: + void handleDestructor(const CFGAutomaticObjDtor &DtorOpt); + + void handleGSLPointerConstruction(const CXXConstructExpr *CCE); + + /// Checks if a call-like expression creates a borrow by passing a value to a + /// reference parameter, creating an IssueFact if it does. + /// \param IsGslConstruction True if this is a GSL construction where all + /// argument origins should flow to the returned origin. 
+ void handleFunctionCall(const Expr *Call, const FunctionDecl *FD, + ArrayRef<const Expr *> Args, + bool IsGslConstruction = false); + + template <typename Destination, typename Source> + void flowOrigin(const Destination &D, const Source &S) { + OriginID DestOID = FactMgr.getOriginMgr().getOrCreate(D); + OriginID SrcOID = FactMgr.getOriginMgr().get(S); + CurrentBlockFacts.push_back(FactMgr.createFact<OriginFlowFact>( + DestOID, SrcOID, /*KillDest=*/false)); + } + + template <typename Destination, typename Source> + void killAndFlowOrigin(const Destination &D, const Source &S) { + OriginID DestOID = FactMgr.getOriginMgr().getOrCreate(D); + OriginID SrcOID = FactMgr.getOriginMgr().get(S); + CurrentBlockFacts.push_back( + FactMgr.createFact<OriginFlowFact>(DestOID, SrcOID, /*KillDest=*/true)); + } + + /// Checks if the expression is a `void("__lifetime_test_point_...")` cast. + /// If so, creates a `TestPointFact` and returns true. + bool handleTestPoint(const CXXFunctionalCastExpr *FCE); + + void handleAssignment(const Expr *LHSExpr, const Expr *RHSExpr); + + // A DeclRefExpr will be treated as a use of the referenced decl. It will be + // checked for use-after-free unless it is later marked as being written to + // (e.g. on the left-hand side of an assignment). + void handleUse(const DeclRefExpr *DRE); + + void markUseAsWrite(const DeclRefExpr *DRE); + + FactManager &FactMgr; + AnalysisDeclContext &AC; + llvm::SmallVector<Fact *> CurrentBlockFacts; + // To distinguish between reads and writes for use-after-free checks, this map + // stores the `UseFact` for each `DeclRefExpr`. We initially identify all + // `DeclRefExpr`s as "read" uses. When an assignment is processed, the use + // corresponding to the left-hand side is updated to be a "write", thereby + // exempting it from the check. 
+ llvm::DenseMap<const DeclRefExpr *, UseFact *> UseFacts; +}; + +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_FACTSGENERATOR_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeAnnotations.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h index 229d16c..f02969e 100644 --- a/clang/include/clang/Analysis/Analyses/LifetimeAnnotations.h +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h @@ -12,8 +12,7 @@ #include "clang/AST/DeclCXX.h" -namespace clang { -namespace lifetimes { +namespace clang ::lifetimes { /// Returns the most recent declaration of the method to ensure all /// lifetime-bound attributes from redeclarations are considered. @@ -38,7 +37,7 @@ bool isAssignmentOperatorLifetimeBound(const CXXMethodDecl *CMD); /// lifetimebound, either due to an explicit lifetimebound attribute on the /// method or because it's a normal assignment operator. bool implicitObjectParamIsLifetimeBound(const FunctionDecl *FD); -} // namespace lifetimes -} // namespace clang + +} // namespace clang::lifetimes #endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMEANNOTATIONS_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h new file mode 100644 index 0000000..91ffbb1 --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h @@ -0,0 +1,87 @@ +//===- LifetimeSafety.h - C++ Lifetime Safety Analysis -*----------- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the main entry point and orchestrator for the C++ Lifetime +// Safety Analysis. 
It coordinates the entire analysis pipeline: fact +// generation, loan propagation, live origins analysis, and enforcement of +// lifetime safety policy. +// +// The analysis is based on the concepts of "origins" and "loans" to track +// pointer lifetimes and detect issues like use-after-free and dangling +// pointers. See the RFC for more details: +// https://discourse.llvm.org/t/rfc-intra-procedural-lifetime-analysis-in-clang/86291 +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h" +#include "clang/Analysis/AnalysisDeclContext.h" + +namespace clang::lifetimes { + +/// Enum to track the confidence level of a potential error. +enum class Confidence : uint8_t { + None, + Maybe, // Reported as a potential error (-Wlifetime-safety-strict) + Definite // Reported as a definite error (-Wlifetime-safety-permissive) +}; + +class LifetimeSafetyReporter { +public: + LifetimeSafetyReporter() = default; + virtual ~LifetimeSafetyReporter() = default; + + virtual void reportUseAfterFree(const Expr *IssueExpr, const Expr *UseExpr, + SourceLocation FreeLoc, + Confidence Confidence) {} +}; + +/// The main entry point for the analysis. +void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC, + LifetimeSafetyReporter *Reporter); + +namespace internal { +/// An object to hold the factories for immutable collections, ensuring +/// that all created states share the same underlying memory management. 
+struct LifetimeFactory { + OriginLoanMap::Factory OriginMapFactory{/*canonicalize=*/false}; + LoanSet::Factory LoanSetFactory{/*canonicalize=*/false}; + LivenessMap::Factory LivenessMapFactory{/*canonicalize=*/false}; +}; + +/// Running the lifetime safety analysis and querying its results. It +/// encapsulates the various dataflow analyses. +class LifetimeSafetyAnalysis { +public: + LifetimeSafetyAnalysis(AnalysisDeclContext &AC, + LifetimeSafetyReporter *Reporter); + + void run(); + + /// \note These are provided only for testing purposes. + LoanPropagationAnalysis &getLoanPropagation() const { + return *LoanPropagation; + } + LiveOriginsAnalysis &getLiveOrigins() const { return *LiveOrigins; } + FactManager &getFactManager() { return FactMgr; } + +private: + AnalysisDeclContext &AC; + LifetimeSafetyReporter *Reporter; + LifetimeFactory Factory; + FactManager FactMgr; + std::unique_ptr<LiveOriginsAnalysis> LiveOrigins; + std::unique_ptr<LoanPropagationAnalysis> LoanPropagation; +}; +} // namespace internal +} // namespace clang::lifetimes + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h new file mode 100644 index 0000000..c4f5f0e --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h @@ -0,0 +1,97 @@ +//===- LiveOrigins.h - Live Origins Analysis -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the LiveOriginAnalysis, a backward dataflow analysis that +// determines which origins are "live" at each program point. 
An origin is +// "live" at a program point if there's a potential future use of a pointer it +// is associated with. Liveness is "generated" by a use of an origin (e.g., a +// `UseFact` from a read of a pointer) and is "killed" (i.e., it stops being +// live) when the origin is replaced by flowing a different origin into it +// (e.g., an OriginFlow from an assignment that kills the destination). +// +// This information is used for detecting use-after-free errors, as it allows us +// to check if a live origin holds a loan to an object that has already expired. +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LIVE_ORIGINS_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LIVE_ORIGINS_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Origins.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/ImmutableMap.h" +#include "llvm/Support/Debug.h" + +namespace clang::lifetimes::internal { + +enum class LivenessKind : uint8_t { + Dead, // Not alive + Maybe, // Live on some path but not all paths (may-be-live) + Must // Live on all paths (must-be-live) +}; + +/// Information about why an origin is live at a program point. +struct LivenessInfo { + /// The use that makes the origin live. If liveness is propagated from + /// multiple uses along different paths, this will point to the use appearing + /// earlier in the translation unit. + /// This is 'null' when the origin is not live. + const UseFact *CausingUseFact; + + /// The kind of liveness of the origin. + /// `Must`: The origin is live on all control-flow paths from the current + /// point to the function's exit (i.e. the current point is dominated by a set + /// of uses). + /// `Maybe`: indicates it is live on some but not all paths. 
+ /// + /// This determines the diagnostic's confidence level. + /// `Must`-be-alive at expiration implies a definite use-after-free, + /// while `Maybe`-be-alive suggests a potential one on some paths. + LivenessKind Kind; + + LivenessInfo() : CausingUseFact(nullptr), Kind(LivenessKind::Dead) {} + LivenessInfo(const UseFact *UF, LivenessKind K) + : CausingUseFact(UF), Kind(K) {} + + bool operator==(const LivenessInfo &Other) const { + return CausingUseFact == Other.CausingUseFact && Kind == Other.Kind; + } + bool operator!=(const LivenessInfo &Other) const { return !(*this == Other); } + + void Profile(llvm::FoldingSetNodeID &IDBuilder) const { + IDBuilder.AddPointer(CausingUseFact); + IDBuilder.Add(Kind); + } +}; + +using LivenessMap = llvm::ImmutableMap<OriginID, LivenessInfo>; + +class LiveOriginsAnalysis { +public: + LiveOriginsAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, + LivenessMap::Factory &SF); + ~LiveOriginsAnalysis(); + + /// Returns the set of origins that are live at a specific program point, + /// along with the the details of the liveness. + LivenessMap getLiveOriginsAt(ProgramPoint P) const; + + // Dump liveness values on all test points in the program. + void dump(llvm::raw_ostream &OS, + llvm::StringMap<ProgramPoint> TestPoints) const; + +private: + class Impl; + std::unique_ptr<Impl> PImpl; +}; + +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LIVE_ORIGINS_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h new file mode 100644 index 0000000..447d05c --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h @@ -0,0 +1,48 @@ +//===- LoanPropagation.h - Loan Propagation Analysis -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the LoanPropagationAnalysis, a forward dataflow analysis +// that tracks which loans each origin holds at each program point. Loans +// represent borrows of storage locations and are propagated through the +// program as pointers are copied or assigned. +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOAN_PROPAGATION_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOAN_PROPAGATION_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "llvm/ADT/ImmutableMap.h" +#include "llvm/ADT/ImmutableSet.h" + +namespace clang::lifetimes::internal { + +// Using LLVM's immutable collections is efficient for dataflow analysis +// as it avoids deep copies during state transitions. +// TODO(opt): Consider using a bitset to represent the set of loans. 
+using LoanSet = llvm::ImmutableSet<LoanID>; +using OriginLoanMap = llvm::ImmutableMap<OriginID, LoanSet>; + +class LoanPropagationAnalysis { +public: + LoanPropagationAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, + OriginLoanMap::Factory &OriginLoanMapFactory, + LoanSet::Factory &LoanSetFactory); + ~LoanPropagationAnalysis(); + + LoanSet getLoans(OriginID OID, ProgramPoint P) const; + +private: + class Impl; + std::unique_ptr<Impl> PImpl; +}; + +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOAN_PROPAGATION_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/Loans.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Loans.h new file mode 100644 index 0000000..7f5cf03 --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Loans.h @@ -0,0 +1,80 @@ +//===- Loans.h - Loan and Access Path Definitions --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Loan and AccessPath structures, which represent +// borrows of storage locations, and the LoanManager, which manages the +// creation and retrieval of loans during lifetime analysis. 
+// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOANS_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOANS_H + +#include "clang/AST/Decl.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Utils.h" +#include "llvm/Support/raw_ostream.h" + +namespace clang::lifetimes::internal { + +using LoanID = utils::ID<struct LoanTag>; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, LoanID ID) { + return OS << ID.Value; +} + +/// Represents the storage location being borrowed, e.g., a specific stack +/// variable. +/// TODO: Model access paths of other types, e.g., s.field, heap and globals. +struct AccessPath { + const clang::ValueDecl *D; + + AccessPath(const clang::ValueDecl *D) : D(D) {} +}; + +/// Information about a single borrow, or "Loan". A loan is created when a +/// reference or pointer is created. +struct Loan { + /// TODO: Represent opaque loans. + /// TODO: Represent nullptr: loans to no path. Accessing it UB! Currently it + /// is represented as empty LoanSet + LoanID ID; + AccessPath Path; + /// The expression that creates the loan, e.g., &x. + const Expr *IssueExpr; + + Loan(LoanID id, AccessPath path, const Expr *IssueExpr) + : ID(id), Path(path), IssueExpr(IssueExpr) {} + + void dump(llvm::raw_ostream &OS) const; +}; + +/// Manages the creation, storage and retrieval of loans. +class LoanManager { +public: + LoanManager() = default; + + Loan &addLoan(AccessPath Path, const Expr *IssueExpr) { + AllLoans.emplace_back(getNextLoanID(), Path, IssueExpr); + return AllLoans.back(); + } + + const Loan &getLoan(LoanID ID) const { + assert(ID.Value < AllLoans.size()); + return AllLoans[ID.Value]; + } + llvm::ArrayRef<Loan> getLoans() const { return AllLoans; } + +private: + LoanID getNextLoanID() { return NextLoanID++; } + + LoanID NextLoanID{0}; + /// TODO(opt): Profile and evaluate the usefullness of small buffer + /// optimisation. 
+ llvm::SmallVector<Loan> AllLoans; +}; +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_LOANS_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/Origins.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Origins.h new file mode 100644 index 0000000..ba138b0 --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Origins.h @@ -0,0 +1,91 @@ +//===- Origins.h - Origin and Origin Management ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Origins, which represent the set of possible loans a +// pointer-like object could hold, and the OriginManager, which manages the +// creation, storage, and retrieval of origins for variables and expressions. +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_ORIGINS_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_ORIGINS_H + +#include "clang/AST/Decl.h" +#include "clang/AST/Expr.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Utils.h" + +namespace clang::lifetimes::internal { + +using OriginID = utils::ID<struct OriginTag>; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, OriginID ID) { + return OS << ID.Value; +} + +/// An Origin is a symbolic identifier that represents the set of possible +/// loans a pointer-like object could hold at any given time. +/// TODO: Enhance the origin model to handle complex types, pointer +/// indirection and reborrowing. The plan is to move from a single origin per +/// variable/expression to a "list of origins" governed by the Type. +/// For example, the type 'int**' would have two origins. 
+/// See discussion: +/// https://github.com/llvm/llvm-project/pull/142313/commits/0cd187b01e61b200d92ca0b640789c1586075142#r2137644238 +struct Origin { + OriginID ID; + /// A pointer to the AST node that this origin represents. This union + /// distinguishes between origins from declarations (variables or parameters) + /// and origins from expressions. + llvm::PointerUnion<const clang::ValueDecl *, const clang::Expr *> Ptr; + + Origin(OriginID ID, const clang::ValueDecl *D) : ID(ID), Ptr(D) {} + Origin(OriginID ID, const clang::Expr *E) : ID(ID), Ptr(E) {} + + const clang::ValueDecl *getDecl() const { + return Ptr.dyn_cast<const clang::ValueDecl *>(); + } + const clang::Expr *getExpr() const { + return Ptr.dyn_cast<const clang::Expr *>(); + } +}; + +/// Manages the creation, storage, and retrieval of origins for pointer-like +/// variables and expressions. +class OriginManager { +public: + OriginManager() = default; + + Origin &addOrigin(OriginID ID, const clang::ValueDecl &D); + Origin &addOrigin(OriginID ID, const clang::Expr &E); + + // TODO: Mark this method as const once we remove the call to getOrCreate. + OriginID get(const Expr &E); + + OriginID get(const ValueDecl &D); + + OriginID getOrCreate(const Expr &E); + + const Origin &getOrigin(OriginID ID) const; + + llvm::ArrayRef<Origin> getOrigins() const { return AllOrigins; } + + OriginID getOrCreate(const ValueDecl &D); + + void dump(OriginID OID, llvm::raw_ostream &OS) const; + +private: + OriginID getNextOriginID() { return NextOriginID++; } + + OriginID NextOriginID{0}; + /// TODO(opt): Profile and evaluate the usefullness of small buffer + /// optimisation. 
+ llvm::SmallVector<Origin> AllOrigins; + llvm::DenseMap<const clang::ValueDecl *, OriginID> DeclToOriginID; + llvm::DenseMap<const clang::Expr *, OriginID> ExprToOriginID; +}; +} // namespace clang::lifetimes::internal + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_ORIGINS_H diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety/Utils.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Utils.h new file mode 100644 index 0000000..4183cab --- /dev/null +++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety/Utils.h @@ -0,0 +1,118 @@ +//===- Utils.h - Utility Functions for Lifetime Safety --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file provides utilities for the lifetime safety analysis, including +// join operations for LLVM's immutable data structures. +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_UTILS_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_UTILS_H + +#include "llvm/ADT/ImmutableMap.h" +#include "llvm/ADT/ImmutableSet.h" + +namespace clang::lifetimes::internal::utils { + +/// A generic, type-safe wrapper for an ID, distinguished by its `Tag` type. +/// Used for giving ID to loans and origins. +template <typename Tag> struct ID { + uint32_t Value = 0; + + bool operator==(const ID<Tag> &Other) const { return Value == Other.Value; } + bool operator!=(const ID<Tag> &Other) const { return !(*this == Other); } + bool operator<(const ID<Tag> &Other) const { return Value < Other.Value; } + ID<Tag> operator++(int) { + ID<Tag> Tmp = *this; + ++Value; + return Tmp; + } + void Profile(llvm::FoldingSetNodeID &IDBuilder) const { + IDBuilder.AddInteger(Value); + } +}; + +/// Computes the union of two ImmutableSets. 
+template <typename T> +static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A, + llvm::ImmutableSet<T> B, + typename llvm::ImmutableSet<T>::Factory &F) { + if (A.getHeight() < B.getHeight()) + std::swap(A, B); + for (const T &E : B) + A = F.add(A, E); + return A; +} + +/// Describes the strategy for joining two `ImmutableMap` instances, primarily +/// differing in how they handle keys that are unique to one of the maps. +/// +/// A `Symmetric` join is universally correct, while an `Asymmetric` join +/// serves as a performance optimization. The latter is applicable only when the +/// join operation possesses a left identity element, allowing for a more +/// efficient, one-sided merge. +enum class JoinKind { + /// A symmetric join applies the `JoinValues` operation to keys unique to + /// either map, ensuring that values from both maps contribute to the result. + Symmetric, + /// An asymmetric join preserves keys unique to the first map as-is, while + /// applying the `JoinValues` operation only to keys unique to the second map. + Asymmetric, +}; + +/// Computes the key-wise union of two ImmutableMaps. +// TODO(opt): This key-wise join is a performance bottleneck. A more +// efficient merge could be implemented using a Patricia Trie or HAMT +// instead of the current AVL-tree-based ImmutableMap. +template <typename K, typename V, typename Joiner> +static llvm::ImmutableMap<K, V> +join(const llvm::ImmutableMap<K, V> &A, const llvm::ImmutableMap<K, V> &B, + typename llvm::ImmutableMap<K, V>::Factory &F, Joiner JoinValues, + JoinKind Kind) { + if (A.getHeight() < B.getHeight()) + return join(B, A, F, JoinValues, Kind); + + // For each element in B, join it with the corresponding element in A + // (or with an empty value if it doesn't exist in A). 
+ llvm::ImmutableMap<K, V> Res = A; + for (const auto &Entry : B) { + const K &Key = Entry.first; + const V &ValB = Entry.second; + Res = F.add(Res, Key, JoinValues(A.lookup(Key), &ValB)); + } + if (Kind == JoinKind::Symmetric) { + for (const auto &Entry : A) { + const K &Key = Entry.first; + const V &ValA = Entry.second; + if (!B.contains(Key)) + Res = F.add(Res, Key, JoinValues(&ValA, nullptr)); + } + } + return Res; +} +} // namespace clang::lifetimes::internal::utils + +namespace llvm { +template <typename Tag> +struct DenseMapInfo<clang::lifetimes::internal::utils::ID<Tag>> { + using ID = clang::lifetimes::internal::utils::ID<Tag>; + + static inline ID getEmptyKey() { + return {DenseMapInfo<uint32_t>::getEmptyKey()}; + } + + static inline ID getTombstoneKey() { + return {DenseMapInfo<uint32_t>::getTombstoneKey()}; + } + + static unsigned getHashValue(const ID &Val) { + return DenseMapInfo<uint32_t>::getHashValue(Val.Value); + } + + static bool isEqual(const ID &LHS, const ID &RHS) { return LHS == RHS; } +}; +} // namespace llvm + +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_UTILS_H diff --git a/clang/include/clang/Basic/CodeGenOptions.def b/clang/include/clang/Basic/CodeGenOptions.def index d924cb4..90e1f8d 100644 --- a/clang/include/clang/Basic/CodeGenOptions.def +++ b/clang/include/clang/Basic/CodeGenOptions.def @@ -72,6 +72,8 @@ CODEGENOPT(EnableNoundefAttrs, 1, 0, Benign) ///< Enable emitting `noundef` attr CODEGENOPT(DebugPassManager, 1, 0, Benign) ///< Prints debug information for the new ///< pass manager. CODEGENOPT(DisableRedZone , 1, 0, Benign) ///< Set when -mno-red-zone is enabled. +CODEGENOPT(CallGraphSection, 1, 0, Benign) ///< Emit a call graph section into the + ///< object file. CODEGENOPT(EmitCallSiteInfo, 1, 0, Benign) ///< Emit call site info only in the case of ///< '-g' + 'O>0' level. 
CODEGENOPT(IndirectTlsSegRefs, 1, 0, Benign) ///< Set when -mno-tls-direct-seg-refs diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index c2f2ac5..a55a523 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -4534,6 +4534,12 @@ defm data_sections : BoolFOption<"data-sections", PosFlag<SetTrue, [], [ClangOption, CC1Option], "Place each data in its own section">, NegFlag<SetFalse>>; +defm experimental_call_graph_section + : BoolFOption<"experimental-call-graph-section", + CodeGenOpts<"CallGraphSection">, DefaultFalse, + PosFlag<SetTrue, [], [ClangOption, CC1Option], + "Emit a call graph section">, + NegFlag<SetFalse>>; defm stack_size_section : BoolFOption<"stack-size-section", CodeGenOpts<"StackSizeSection">, DefaultFalse, PosFlag<SetTrue, [], [ClangOption, CC1Option], diff --git a/clang/lib/Analysis/CMakeLists.txt b/clang/lib/Analysis/CMakeLists.txt index 5a26f3e..1dbd415 100644 --- a/clang/lib/Analysis/CMakeLists.txt +++ b/clang/lib/Analysis/CMakeLists.txt @@ -21,8 +21,6 @@ add_clang_library(clangAnalysis FixitUtil.cpp IntervalPartition.cpp IssueHash.cpp - LifetimeAnnotations.cpp - LifetimeSafety.cpp LiveVariables.cpp MacroExpansionContext.cpp ObjCNoReturn.cpp @@ -51,3 +49,4 @@ add_clang_library(clangAnalysis add_subdirectory(plugins) add_subdirectory(FlowSensitive) +add_subdirectory(LifetimeSafety) diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp deleted file mode 100644 index 6196ec3..0000000 --- a/clang/lib/Analysis/LifetimeSafety.cpp +++ /dev/null @@ -1,1546 +0,0 @@ -//===- LifetimeSafety.cpp - C++ Lifetime Safety Analysis -*--------- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#include "clang/Analysis/Analyses/LifetimeSafety.h" -#include "clang/AST/Decl.h" -#include "clang/AST/Expr.h" -#include "clang/AST/StmtVisitor.h" -#include "clang/AST/Type.h" -#include "clang/Analysis/Analyses/LifetimeAnnotations.h" -#include "clang/Analysis/Analyses/PostOrderCFGView.h" -#include "clang/Analysis/AnalysisDeclContext.h" -#include "clang/Analysis/CFG.h" -#include "clang/Analysis/FlowSensitive/DataflowWorklist.h" -#include "llvm/ADT/FoldingSet.h" -#include "llvm/ADT/ImmutableMap.h" -#include "llvm/ADT/ImmutableSet.h" -#include "llvm/ADT/PointerUnion.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/TimeProfiler.h" -#include <cstdint> -#include <memory> -#include <optional> - -namespace clang::lifetimes { -namespace internal { - -/// Represents the storage location being borrowed, e.g., a specific stack -/// variable. -/// TODO: Model access paths of other types, e.g., s.field, heap and globals. -struct AccessPath { - const clang::ValueDecl *D; - - AccessPath(const clang::ValueDecl *D) : D(D) {} -}; - -/// Information about a single borrow, or "Loan". A loan is created when a -/// reference or pointer is created. -struct Loan { - /// TODO: Represent opaque loans. - /// TODO: Represent nullptr: loans to no path. Accessing it UB! Currently it - /// is represented as empty LoanSet - LoanID ID; - AccessPath Path; - /// The expression that creates the loan, e.g., &x. 
- const Expr *IssueExpr; - - Loan(LoanID id, AccessPath path, const Expr *IssueExpr) - : ID(id), Path(path), IssueExpr(IssueExpr) {} - - void dump(llvm::raw_ostream &OS) const { - OS << ID << " (Path: "; - OS << Path.D->getNameAsString() << ")"; - } -}; - -/// An Origin is a symbolic identifier that represents the set of possible -/// loans a pointer-like object could hold at any given time. -/// TODO: Enhance the origin model to handle complex types, pointer -/// indirection and reborrowing. The plan is to move from a single origin per -/// variable/expression to a "list of origins" governed by the Type. -/// For example, the type 'int**' would have two origins. -/// See discussion: -/// https://github.com/llvm/llvm-project/pull/142313/commits/0cd187b01e61b200d92ca0b640789c1586075142#r2137644238 -struct Origin { - OriginID ID; - /// A pointer to the AST node that this origin represents. This union - /// distinguishes between origins from declarations (variables or parameters) - /// and origins from expressions. - llvm::PointerUnion<const clang::ValueDecl *, const clang::Expr *> Ptr; - - Origin(OriginID ID, const clang::ValueDecl *D) : ID(ID), Ptr(D) {} - Origin(OriginID ID, const clang::Expr *E) : ID(ID), Ptr(E) {} - - const clang::ValueDecl *getDecl() const { - return Ptr.dyn_cast<const clang::ValueDecl *>(); - } - const clang::Expr *getExpr() const { - return Ptr.dyn_cast<const clang::Expr *>(); - } -}; - -/// Manages the creation, storage and retrieval of loans. 
-class LoanManager { -public: - LoanManager() = default; - - Loan &addLoan(AccessPath Path, const Expr *IssueExpr) { - AllLoans.emplace_back(getNextLoanID(), Path, IssueExpr); - return AllLoans.back(); - } - - const Loan &getLoan(LoanID ID) const { - assert(ID.Value < AllLoans.size()); - return AllLoans[ID.Value]; - } - llvm::ArrayRef<Loan> getLoans() const { return AllLoans; } - -private: - LoanID getNextLoanID() { return NextLoanID++; } - - LoanID NextLoanID{0}; - /// TODO(opt): Profile and evaluate the usefullness of small buffer - /// optimisation. - llvm::SmallVector<Loan> AllLoans; -}; - -/// Manages the creation, storage, and retrieval of origins for pointer-like -/// variables and expressions. -class OriginManager { -public: - OriginManager() = default; - - Origin &addOrigin(OriginID ID, const clang::ValueDecl &D) { - AllOrigins.emplace_back(ID, &D); - return AllOrigins.back(); - } - Origin &addOrigin(OriginID ID, const clang::Expr &E) { - AllOrigins.emplace_back(ID, &E); - return AllOrigins.back(); - } - - // TODO: Mark this method as const once we remove the call to getOrCreate. - OriginID get(const Expr &E) { - auto It = ExprToOriginID.find(&E); - if (It != ExprToOriginID.end()) - return It->second; - // If the expression itself has no specific origin, and it's a reference - // to a declaration, its origin is that of the declaration it refers to. - // For pointer types, where we don't pre-emptively create an origin for the - // DeclRefExpr itself. - if (const auto *DRE = dyn_cast<DeclRefExpr>(&E)) - return get(*DRE->getDecl()); - // TODO: This should be an assert(It != ExprToOriginID.end()). The current - // implementation falls back to getOrCreate to avoid crashing on - // yet-unhandled pointer expressions, creating an empty origin for them. - return getOrCreate(E); - } - - OriginID get(const ValueDecl &D) { - auto It = DeclToOriginID.find(&D); - // TODO: This should be an assert(It != DeclToOriginID.end()). 
The current - // implementation falls back to getOrCreate to avoid crashing on - // yet-unhandled pointer expressions, creating an empty origin for them. - if (It == DeclToOriginID.end()) - return getOrCreate(D); - - return It->second; - } - - OriginID getOrCreate(const Expr &E) { - auto It = ExprToOriginID.find(&E); - if (It != ExprToOriginID.end()) - return It->second; - - OriginID NewID = getNextOriginID(); - addOrigin(NewID, E); - ExprToOriginID[&E] = NewID; - return NewID; - } - - const Origin &getOrigin(OriginID ID) const { - assert(ID.Value < AllOrigins.size()); - return AllOrigins[ID.Value]; - } - - llvm::ArrayRef<Origin> getOrigins() const { return AllOrigins; } - - OriginID getOrCreate(const ValueDecl &D) { - auto It = DeclToOriginID.find(&D); - if (It != DeclToOriginID.end()) - return It->second; - OriginID NewID = getNextOriginID(); - addOrigin(NewID, D); - DeclToOriginID[&D] = NewID; - return NewID; - } - - void dump(OriginID OID, llvm::raw_ostream &OS) const { - OS << OID << " ("; - Origin O = getOrigin(OID); - if (const ValueDecl *VD = O.getDecl()) - OS << "Decl: " << VD->getNameAsString(); - else if (const Expr *E = O.getExpr()) - OS << "Expr: " << E->getStmtClassName(); - else - OS << "Unknown"; - OS << ")"; - } - -private: - OriginID getNextOriginID() { return NextOriginID++; } - - OriginID NextOriginID{0}; - /// TODO(opt): Profile and evaluate the usefullness of small buffer - /// optimisation. - llvm::SmallVector<Origin> AllOrigins; - llvm::DenseMap<const clang::ValueDecl *, OriginID> DeclToOriginID; - llvm::DenseMap<const clang::Expr *, OriginID> ExprToOriginID; -}; - -/// An abstract base class for a single, atomic lifetime-relevant event. -class Fact { - -public: - enum class Kind : uint8_t { - /// A new loan is issued from a borrow expression (e.g., &x). - Issue, - /// A loan expires as its underlying storage is freed (e.g., variable goes - /// out of scope). 
- Expire, - /// An origin is propagated from a source to a destination (e.g., p = q). - /// This can also optionally kill the destination origin before flowing into - /// it. Otherwise, the source's loan set is merged into the destination's - /// loan set. - OriginFlow, - /// An origin escapes the function by flowing into the return value. - ReturnOfOrigin, - /// An origin is used (eg. appears as l-value expression like DeclRefExpr). - Use, - /// A marker for a specific point in the code, for testing. - TestPoint, - }; - -private: - Kind K; - -protected: - Fact(Kind K) : K(K) {} - -public: - virtual ~Fact() = default; - Kind getKind() const { return K; } - - template <typename T> const T *getAs() const { - if (T::classof(this)) - return static_cast<const T *>(this); - return nullptr; - } - - virtual void dump(llvm::raw_ostream &OS, const LoanManager &, - const OriginManager &) const { - OS << "Fact (Kind: " << static_cast<int>(K) << ")\n"; - } -}; - -class IssueFact : public Fact { - LoanID LID; - OriginID OID; - -public: - static bool classof(const Fact *F) { return F->getKind() == Kind::Issue; } - - IssueFact(LoanID LID, OriginID OID) : Fact(Kind::Issue), LID(LID), OID(OID) {} - LoanID getLoanID() const { return LID; } - OriginID getOriginID() const { return OID; } - void dump(llvm::raw_ostream &OS, const LoanManager &LM, - const OriginManager &OM) const override { - OS << "Issue ("; - LM.getLoan(getLoanID()).dump(OS); - OS << ", ToOrigin: "; - OM.dump(getOriginID(), OS); - OS << ")\n"; - } -}; - -class ExpireFact : public Fact { - LoanID LID; - SourceLocation ExpiryLoc; - -public: - static bool classof(const Fact *F) { return F->getKind() == Kind::Expire; } - - ExpireFact(LoanID LID, SourceLocation ExpiryLoc) - : Fact(Kind::Expire), LID(LID), ExpiryLoc(ExpiryLoc) {} - - LoanID getLoanID() const { return LID; } - SourceLocation getExpiryLoc() const { return ExpiryLoc; } - - void dump(llvm::raw_ostream &OS, const LoanManager &LM, - const OriginManager &) const 
override { - OS << "Expire ("; - LM.getLoan(getLoanID()).dump(OS); - OS << ")\n"; - } -}; - -class OriginFlowFact : public Fact { - OriginID OIDDest; - OriginID OIDSrc; - // True if the destination origin should be killed (i.e., its current loans - // cleared) before the source origin's loans are flowed into it. - bool KillDest; - -public: - static bool classof(const Fact *F) { - return F->getKind() == Kind::OriginFlow; - } - - OriginFlowFact(OriginID OIDDest, OriginID OIDSrc, bool KillDest) - : Fact(Kind::OriginFlow), OIDDest(OIDDest), OIDSrc(OIDSrc), - KillDest(KillDest) {} - - OriginID getDestOriginID() const { return OIDDest; } - OriginID getSrcOriginID() const { return OIDSrc; } - bool getKillDest() const { return KillDest; } - - void dump(llvm::raw_ostream &OS, const LoanManager &, - const OriginManager &OM) const override { - OS << "OriginFlow (Dest: "; - OM.dump(getDestOriginID(), OS); - OS << ", Src: "; - OM.dump(getSrcOriginID(), OS); - OS << (getKillDest() ? "" : ", Merge"); - OS << ")\n"; - } -}; - -class ReturnOfOriginFact : public Fact { - OriginID OID; - -public: - static bool classof(const Fact *F) { - return F->getKind() == Kind::ReturnOfOrigin; - } - - ReturnOfOriginFact(OriginID OID) : Fact(Kind::ReturnOfOrigin), OID(OID) {} - OriginID getReturnedOriginID() const { return OID; } - void dump(llvm::raw_ostream &OS, const LoanManager &, - const OriginManager &OM) const override { - OS << "ReturnOfOrigin ("; - OM.dump(getReturnedOriginID(), OS); - OS << ")\n"; - } -}; - -class UseFact : public Fact { - const Expr *UseExpr; - // True if this use is a write operation (e.g., left-hand side of assignment). - // Write operations are exempted from use-after-free checks. 
- bool IsWritten = false; - -public: - static bool classof(const Fact *F) { return F->getKind() == Kind::Use; } - - UseFact(const Expr *UseExpr) : Fact(Kind::Use), UseExpr(UseExpr) {} - - OriginID getUsedOrigin(const OriginManager &OM) const { - // TODO: Remove const cast and make OriginManager::get as const. - return const_cast<OriginManager &>(OM).get(*UseExpr); - } - const Expr *getUseExpr() const { return UseExpr; } - void markAsWritten() { IsWritten = true; } - bool isWritten() const { return IsWritten; } - - void dump(llvm::raw_ostream &OS, const LoanManager &, - const OriginManager &OM) const override { - OS << "Use ("; - OM.dump(getUsedOrigin(OM), OS); - OS << ", " << (isWritten() ? "Write" : "Read") << ")\n"; - } -}; - -/// A dummy-fact used to mark a specific point in the code for testing. -/// It is generated by recognizing a `void("__lifetime_test_point_...")` cast. -class TestPointFact : public Fact { - StringRef Annotation; - -public: - static bool classof(const Fact *F) { return F->getKind() == Kind::TestPoint; } - - explicit TestPointFact(StringRef Annotation) - : Fact(Kind::TestPoint), Annotation(Annotation) {} - - StringRef getAnnotation() const { return Annotation; } - - void dump(llvm::raw_ostream &OS, const LoanManager &, - const OriginManager &) const override { - OS << "TestPoint (Annotation: \"" << getAnnotation() << "\")\n"; - } -}; - -class FactManager { -public: - llvm::ArrayRef<const Fact *> getFacts(const CFGBlock *B) const { - auto It = BlockToFactsMap.find(B); - if (It != BlockToFactsMap.end()) - return It->second; - return {}; - } - - void addBlockFacts(const CFGBlock *B, llvm::ArrayRef<Fact *> NewFacts) { - if (!NewFacts.empty()) - BlockToFactsMap[B].assign(NewFacts.begin(), NewFacts.end()); - } - - template <typename FactType, typename... 
Args> - FactType *createFact(Args &&...args) { - void *Mem = FactAllocator.Allocate<FactType>(); - return new (Mem) FactType(std::forward<Args>(args)...); - } - - void dump(const CFG &Cfg, AnalysisDeclContext &AC) const { - llvm::dbgs() << "==========================================\n"; - llvm::dbgs() << " Lifetime Analysis Facts:\n"; - llvm::dbgs() << "==========================================\n"; - if (const Decl *D = AC.getDecl()) - if (const auto *ND = dyn_cast<NamedDecl>(D)) - llvm::dbgs() << "Function: " << ND->getQualifiedNameAsString() << "\n"; - // Print blocks in the order as they appear in code for a stable ordering. - for (const CFGBlock *B : *AC.getAnalysis<PostOrderCFGView>()) { - llvm::dbgs() << " Block B" << B->getBlockID() << ":\n"; - auto It = BlockToFactsMap.find(B); - if (It != BlockToFactsMap.end()) { - for (const Fact *F : It->second) { - llvm::dbgs() << " "; - F->dump(llvm::dbgs(), LoanMgr, OriginMgr); - } - } - llvm::dbgs() << " End of Block\n"; - } - } - - LoanManager &getLoanMgr() { return LoanMgr; } - OriginManager &getOriginMgr() { return OriginMgr; } - -private: - LoanManager LoanMgr; - OriginManager OriginMgr; - llvm::DenseMap<const clang::CFGBlock *, llvm::SmallVector<const Fact *>> - BlockToFactsMap; - llvm::BumpPtrAllocator FactAllocator; -}; - -class FactGenerator : public ConstStmtVisitor<FactGenerator> { - using Base = ConstStmtVisitor<FactGenerator>; - -public: - FactGenerator(FactManager &FactMgr, AnalysisDeclContext &AC) - : FactMgr(FactMgr), AC(AC) {} - - void run() { - llvm::TimeTraceScope TimeProfile("FactGenerator"); - // Iterate through the CFG blocks in reverse post-order to ensure that - // initializations and destructions are processed in the correct sequence. 
- for (const CFGBlock *Block : *AC.getAnalysis<PostOrderCFGView>()) { - CurrentBlockFacts.clear(); - for (unsigned I = 0; I < Block->size(); ++I) { - const CFGElement &Element = Block->Elements[I]; - if (std::optional<CFGStmt> CS = Element.getAs<CFGStmt>()) - Visit(CS->getStmt()); - else if (std::optional<CFGAutomaticObjDtor> DtorOpt = - Element.getAs<CFGAutomaticObjDtor>()) - handleDestructor(*DtorOpt); - } - FactMgr.addBlockFacts(Block, CurrentBlockFacts); - } - } - - void VisitDeclStmt(const DeclStmt *DS) { - for (const Decl *D : DS->decls()) - if (const auto *VD = dyn_cast<VarDecl>(D)) - if (hasOrigin(VD)) - if (const Expr *InitExpr = VD->getInit()) - killAndFlowOrigin(*VD, *InitExpr); - } - - void VisitDeclRefExpr(const DeclRefExpr *DRE) { - handleUse(DRE); - // For non-pointer/non-view types, a reference to the variable's storage - // is a borrow. We create a loan for it. - // For pointer/view types, we stick to the existing model for now and do - // not create an extra origin for the l-value expression itself. - - // TODO: A single origin for a `DeclRefExpr` for a pointer or view type is - // not sufficient to model the different levels of indirection. The current - // single-origin model cannot distinguish between a loan to the variable's - // storage and a loan to what it points to. A multi-origin model would be - // required for this. 
- if (!isPointerType(DRE->getType())) { - if (const Loan *L = createLoan(DRE)) { - OriginID ExprOID = FactMgr.getOriginMgr().getOrCreate(*DRE); - CurrentBlockFacts.push_back( - FactMgr.createFact<IssueFact>(L->ID, ExprOID)); - } - } - } - - void VisitCXXConstructExpr(const CXXConstructExpr *CCE) { - if (isGslPointerType(CCE->getType())) { - handleGSLPointerConstruction(CCE); - return; - } - } - - void VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) { - // Specifically for conversion operators, - // like `std::string_view p = std::string{};` - if (isGslPointerType(MCE->getType()) && - isa<CXXConversionDecl>(MCE->getCalleeDecl())) { - // The argument is the implicit object itself. - handleFunctionCall(MCE, MCE->getMethodDecl(), - {MCE->getImplicitObjectArgument()}, - /*IsGslConstruction=*/true); - } - if (const CXXMethodDecl *Method = MCE->getMethodDecl()) { - // Construct the argument list, with the implicit 'this' object as the - // first argument. - llvm::SmallVector<const Expr *, 4> Args; - Args.push_back(MCE->getImplicitObjectArgument()); - Args.append(MCE->getArgs(), MCE->getArgs() + MCE->getNumArgs()); - - handleFunctionCall(MCE, Method, Args, /*IsGslConstruction=*/false); - } - } - - void VisitCallExpr(const CallExpr *CE) { - handleFunctionCall(CE, CE->getDirectCallee(), - {CE->getArgs(), CE->getNumArgs()}); - } - - void VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *N) { - /// TODO: Handle nullptr expr as a special 'null' loan. Uninitialized - /// pointers can use the same type of loan. - FactMgr.getOriginMgr().getOrCreate(*N); - } - - void VisitImplicitCastExpr(const ImplicitCastExpr *ICE) { - if (!hasOrigin(ICE)) - return; - // An ImplicitCastExpr node itself gets an origin, which flows from the - // origin of its sub-expression (after stripping its own parens/casts). 
- killAndFlowOrigin(*ICE, *ICE->getSubExpr()); - } - - void VisitUnaryOperator(const UnaryOperator *UO) { - if (UO->getOpcode() == UO_AddrOf) { - const Expr *SubExpr = UO->getSubExpr(); - // Taking address of a pointer-type expression is not yet supported and - // will be supported in multi-origin model. - if (isPointerType(SubExpr->getType())) - return; - // The origin of an address-of expression (e.g., &x) is the origin of - // its sub-expression (x). This fact will cause the dataflow analysis - // to propagate any loans held by the sub-expression's origin to the - // origin of this UnaryOperator expression. - killAndFlowOrigin(*UO, *SubExpr); - } - } - - void VisitReturnStmt(const ReturnStmt *RS) { - if (const Expr *RetExpr = RS->getRetValue()) { - if (hasOrigin(RetExpr)) { - OriginID OID = FactMgr.getOriginMgr().getOrCreate(*RetExpr); - CurrentBlockFacts.push_back( - FactMgr.createFact<ReturnOfOriginFact>(OID)); - } - } - } - - void VisitBinaryOperator(const BinaryOperator *BO) { - if (BO->isAssignmentOp()) - handleAssignment(BO->getLHS(), BO->getRHS()); - } - - void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) { - // Assignment operators have special "kill-then-propagate" semantics - // and are handled separately. - if (OCE->isAssignmentOp() && OCE->getNumArgs() == 2) { - handleAssignment(OCE->getArg(0), OCE->getArg(1)); - return; - } - handleFunctionCall(OCE, OCE->getDirectCallee(), - {OCE->getArgs(), OCE->getNumArgs()}, - /*IsGslConstruction=*/false); - } - - void VisitCXXFunctionalCastExpr(const CXXFunctionalCastExpr *FCE) { - // Check if this is a test point marker. If so, we are done with this - // expression. 
- if (handleTestPoint(FCE)) - return; - if (isGslPointerType(FCE->getType())) - killAndFlowOrigin(*FCE, *FCE->getSubExpr()); - } - - void VisitInitListExpr(const InitListExpr *ILE) { - if (!hasOrigin(ILE)) - return; - // For list initialization with a single element, like `View{...}`, the - // origin of the list itself is the origin of its single element. - if (ILE->getNumInits() == 1) - killAndFlowOrigin(*ILE, *ILE->getInit(0)); - } - - void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *MTE) { - if (!hasOrigin(MTE)) - return; - // A temporary object's origin is the same as the origin of the - // expression that initializes it. - killAndFlowOrigin(*MTE, *MTE->getSubExpr()); - } - - void handleDestructor(const CFGAutomaticObjDtor &DtorOpt) { - /// TODO: Also handle trivial destructors (e.g., for `int` - /// variables) which will never have a CFGAutomaticObjDtor node. - /// TODO: Handle loans to temporaries. - /// TODO: Consider using clang::CFG::BuildOptions::AddLifetime to reuse the - /// lifetime ends. - const VarDecl *DestructedVD = DtorOpt.getVarDecl(); - if (!DestructedVD) - return; - // Iterate through all loans to see if any expire. - /// TODO(opt): Do better than a linear search to find loans associated with - /// 'DestructedVD'. - for (const Loan &L : FactMgr.getLoanMgr().getLoans()) { - const AccessPath &LoanPath = L.Path; - // Check if the loan is for a stack variable and if that variable - // is the one being destructed. - if (LoanPath.D == DestructedVD) - CurrentBlockFacts.push_back(FactMgr.createFact<ExpireFact>( - L.ID, DtorOpt.getTriggerStmt()->getEndLoc())); - } - } - -private: - static bool isGslPointerType(QualType QT) { - if (const auto *RD = QT->getAsCXXRecordDecl()) { - // We need to check the template definition for specializations. 
- if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) - return CTSD->getSpecializedTemplate() - ->getTemplatedDecl() - ->hasAttr<PointerAttr>(); - return RD->hasAttr<PointerAttr>(); - } - return false; - } - - static bool isPointerType(QualType QT) { - return QT->isPointerOrReferenceType() || isGslPointerType(QT); - } - // Check if a type has an origin. - static bool hasOrigin(const Expr *E) { - return E->isGLValue() || isPointerType(E->getType()); - } - - static bool hasOrigin(const VarDecl *VD) { - return isPointerType(VD->getType()); - } - - void handleGSLPointerConstruction(const CXXConstructExpr *CCE) { - assert(isGslPointerType(CCE->getType())); - if (CCE->getNumArgs() != 1) - return; - if (hasOrigin(CCE->getArg(0))) - killAndFlowOrigin(*CCE, *CCE->getArg(0)); - else - // This could be a new borrow. - handleFunctionCall(CCE, CCE->getConstructor(), - {CCE->getArgs(), CCE->getNumArgs()}, - /*IsGslConstruction=*/true); - } - - /// Checks if a call-like expression creates a borrow by passing a value to a - /// reference parameter, creating an IssueFact if it does. - /// \param IsGslConstruction True if this is a GSL construction where all - /// argument origins should flow to the returned origin. - void handleFunctionCall(const Expr *Call, const FunctionDecl *FD, - ArrayRef<const Expr *> Args, - bool IsGslConstruction = false) { - // Ignore functions returning values with no origin. - if (!FD || !hasOrigin(Call)) - return; - auto IsArgLifetimeBound = [FD](unsigned I) -> bool { - const ParmVarDecl *PVD = nullptr; - if (const auto *Method = dyn_cast<CXXMethodDecl>(FD); - Method && Method->isInstance()) { - if (I == 0) - // For the 'this' argument, the attribute is on the method itself. - return implicitObjectParamIsLifetimeBound(Method); - if ((I - 1) < Method->getNumParams()) - // For explicit arguments, find the corresponding parameter - // declaration. 
- PVD = Method->getParamDecl(I - 1); - } else if (I < FD->getNumParams()) - // For free functions or static methods. - PVD = FD->getParamDecl(I); - return PVD ? PVD->hasAttr<clang::LifetimeBoundAttr>() : false; - }; - if (Args.empty()) - return; - bool killedSrc = false; - for (unsigned I = 0; I < Args.size(); ++I) - if (IsGslConstruction || IsArgLifetimeBound(I)) { - if (!killedSrc) { - killedSrc = true; - killAndFlowOrigin(*Call, *Args[I]); - } else - flowOrigin(*Call, *Args[I]); - } - } - - /// Creates a loan for the storage path of a given declaration reference. - /// This function should be called whenever a DeclRefExpr represents a borrow. - /// \param DRE The declaration reference expression that initiates the borrow. - /// \return The new Loan on success, nullptr otherwise. - const Loan *createLoan(const DeclRefExpr *DRE) { - if (const auto *VD = dyn_cast<ValueDecl>(DRE->getDecl())) { - AccessPath Path(VD); - // The loan is created at the location of the DeclRefExpr. - return &FactMgr.getLoanMgr().addLoan(Path, DRE); - } - return nullptr; - } - - template <typename Destination, typename Source> - void flowOrigin(const Destination &D, const Source &S) { - OriginID DestOID = FactMgr.getOriginMgr().getOrCreate(D); - OriginID SrcOID = FactMgr.getOriginMgr().get(S); - CurrentBlockFacts.push_back(FactMgr.createFact<OriginFlowFact>( - DestOID, SrcOID, /*KillDest=*/false)); - } - - template <typename Destination, typename Source> - void killAndFlowOrigin(const Destination &D, const Source &S) { - OriginID DestOID = FactMgr.getOriginMgr().getOrCreate(D); - OriginID SrcOID = FactMgr.getOriginMgr().get(S); - CurrentBlockFacts.push_back( - FactMgr.createFact<OriginFlowFact>(DestOID, SrcOID, /*KillDest=*/true)); - } - - /// Checks if the expression is a `void("__lifetime_test_point_...")` cast. - /// If so, creates a `TestPointFact` and returns true. 
- bool handleTestPoint(const CXXFunctionalCastExpr *FCE) { - if (!FCE->getType()->isVoidType()) - return false; - - const auto *SubExpr = FCE->getSubExpr()->IgnoreParenImpCasts(); - if (const auto *SL = dyn_cast<StringLiteral>(SubExpr)) { - llvm::StringRef LiteralValue = SL->getString(); - const std::string Prefix = "__lifetime_test_point_"; - - if (LiteralValue.starts_with(Prefix)) { - StringRef Annotation = LiteralValue.drop_front(Prefix.length()); - CurrentBlockFacts.push_back( - FactMgr.createFact<TestPointFact>(Annotation)); - return true; - } - } - return false; - } - - void handleAssignment(const Expr *LHSExpr, const Expr *RHSExpr) { - if (!hasOrigin(LHSExpr)) - return; - // Find the underlying variable declaration for the left-hand side. - if (const auto *DRE_LHS = - dyn_cast<DeclRefExpr>(LHSExpr->IgnoreParenImpCasts())) { - markUseAsWrite(DRE_LHS); - if (const auto *VD_LHS = dyn_cast<ValueDecl>(DRE_LHS->getDecl())) { - // Kill the old loans of the destination origin and flow the new loans - // from the source origin. - killAndFlowOrigin(*VD_LHS, *RHSExpr); - } - } - } - - // A DeclRefExpr will be treated as a use of the referenced decl. It will be - // checked for use-after-free unless it is later marked as being written to - // (e.g. on the left-hand side of an assignment). - void handleUse(const DeclRefExpr *DRE) { - if (isPointerType(DRE->getType())) { - UseFact *UF = FactMgr.createFact<UseFact>(DRE); - CurrentBlockFacts.push_back(UF); - assert(!UseFacts.contains(DRE)); - UseFacts[DRE] = UF; - } - } - - void markUseAsWrite(const DeclRefExpr *DRE) { - if (!isPointerType(DRE->getType())) - return; - assert(UseFacts.contains(DRE)); - UseFacts[DRE]->markAsWritten(); - } - - FactManager &FactMgr; - AnalysisDeclContext &AC; - llvm::SmallVector<Fact *> CurrentBlockFacts; - // To distinguish between reads and writes for use-after-free checks, this map - // stores the `UseFact` for each `DeclRefExpr`. We initially identify all - // `DeclRefExpr`s as "read" uses. 
When an assignment is processed, the use - // corresponding to the left-hand side is updated to be a "write", thereby - // exempting it from the check. - llvm::DenseMap<const DeclRefExpr *, UseFact *> UseFacts; -}; - -// ========================================================================= // -// Generic Dataflow Analysis -// ========================================================================= // - -enum class Direction { Forward, Backward }; - -/// A `ProgramPoint` identifies a location in the CFG by pointing to a specific -/// `Fact`. identified by a lifetime-related event (`Fact`). -/// -/// A `ProgramPoint` has "after" semantics: it represents the location -/// immediately after its corresponding `Fact`. -using ProgramPoint = const Fact *; - -/// A generic, policy-based driver for dataflow analyses. It combines -/// the dataflow runner and the transferer logic into a single class hierarchy. -/// -/// The derived class is expected to provide: -/// - A `Lattice` type. -/// - `StringRef getAnalysisName() const` -/// - `Lattice getInitialState();` The initial state of the analysis. -/// - `Lattice join(Lattice, Lattice);` Merges states from multiple CFG paths. -/// - `Lattice transfer(Lattice, const FactType&);` Defines how a single -/// lifetime-relevant `Fact` transforms the lattice state. Only overloads -/// for facts relevant to the analysis need to be implemented. -/// -/// \tparam Derived The CRTP derived class that implements the specific -/// analysis. -/// \tparam LatticeType The dataflow lattice used by the analysis. -/// \tparam Dir The direction of the analysis (Forward or Backward). -/// TODO: Maybe use the dataflow framework! The framework might need changes -/// to support the current comparison done at block-entry. 
-template <typename Derived, typename LatticeType, Direction Dir> -class DataflowAnalysis { -public: - using Lattice = LatticeType; - using Base = DataflowAnalysis<Derived, Lattice, Dir>; - -private: - const CFG &Cfg; - AnalysisDeclContext &AC; - - /// The dataflow state before a basic block is processed. - llvm::DenseMap<const CFGBlock *, Lattice> InStates; - /// The dataflow state after a basic block is processed. - llvm::DenseMap<const CFGBlock *, Lattice> OutStates; - /// The dataflow state at a Program Point. - /// In a forward analysis, this is the state after the Fact at that point has - /// been applied, while in a backward analysis, it is the state before. - llvm::DenseMap<ProgramPoint, Lattice> PerPointStates; - - static constexpr bool isForward() { return Dir == Direction::Forward; } - -protected: - FactManager &AllFacts; - - explicit DataflowAnalysis(const CFG &C, AnalysisDeclContext &AC, - FactManager &F) - : Cfg(C), AC(AC), AllFacts(F) {} - -public: - void run() { - Derived &D = static_cast<Derived &>(*this); - llvm::TimeTraceScope Time(D.getAnalysisName()); - - using Worklist = - std::conditional_t<Dir == Direction::Forward, ForwardDataflowWorklist, - BackwardDataflowWorklist>; - Worklist W(Cfg, AC); - - const CFGBlock *Start = isForward() ? &Cfg.getEntry() : &Cfg.getExit(); - InStates[Start] = D.getInitialState(); - W.enqueueBlock(Start); - - while (const CFGBlock *B = W.dequeue()) { - Lattice StateIn = *getInState(B); - Lattice StateOut = transferBlock(B, StateIn); - OutStates[B] = StateOut; - for (const CFGBlock *AdjacentB : isForward() ? B->succs() : B->preds()) { - if (!AdjacentB) - continue; - std::optional<Lattice> OldInState = getInState(AdjacentB); - Lattice NewInState = - !OldInState ? StateOut : D.join(*OldInState, StateOut); - // Enqueue the adjacent block if its in-state has changed or if we have - // never seen it. 
- if (!OldInState || NewInState != *OldInState) { - InStates[AdjacentB] = NewInState; - W.enqueueBlock(AdjacentB); - } - } - } - } - -protected: - Lattice getState(ProgramPoint P) const { return PerPointStates.lookup(P); } - - std::optional<Lattice> getInState(const CFGBlock *B) const { - auto It = InStates.find(B); - if (It == InStates.end()) - return std::nullopt; - return It->second; - } - - Lattice getOutState(const CFGBlock *B) const { return OutStates.lookup(B); } - - void dump() const { - const Derived *D = static_cast<const Derived *>(this); - llvm::dbgs() << "==========================================\n"; - llvm::dbgs() << D->getAnalysisName() << " results:\n"; - llvm::dbgs() << "==========================================\n"; - const CFGBlock &B = isForward() ? Cfg.getExit() : Cfg.getEntry(); - getOutState(&B).dump(llvm::dbgs()); - } - -private: - /// Computes the state at one end of a block by applying all its facts - /// sequentially to a given state from the other end. - Lattice transferBlock(const CFGBlock *Block, Lattice State) { - auto Facts = AllFacts.getFacts(Block); - if constexpr (isForward()) { - for (const Fact *F : Facts) { - State = transferFact(State, F); - PerPointStates[F] = State; - } - } else { - for (const Fact *F : llvm::reverse(Facts)) { - // In backward analysis, capture the state before applying the fact. 
- PerPointStates[F] = State; - State = transferFact(State, F); - } - } - return State; - } - - Lattice transferFact(Lattice In, const Fact *F) { - assert(F); - Derived *D = static_cast<Derived *>(this); - switch (F->getKind()) { - case Fact::Kind::Issue: - return D->transfer(In, *F->getAs<IssueFact>()); - case Fact::Kind::Expire: - return D->transfer(In, *F->getAs<ExpireFact>()); - case Fact::Kind::OriginFlow: - return D->transfer(In, *F->getAs<OriginFlowFact>()); - case Fact::Kind::ReturnOfOrigin: - return D->transfer(In, *F->getAs<ReturnOfOriginFact>()); - case Fact::Kind::Use: - return D->transfer(In, *F->getAs<UseFact>()); - case Fact::Kind::TestPoint: - return D->transfer(In, *F->getAs<TestPointFact>()); - } - llvm_unreachable("Unknown fact kind"); - } - -public: - Lattice transfer(Lattice In, const IssueFact &) { return In; } - Lattice transfer(Lattice In, const ExpireFact &) { return In; } - Lattice transfer(Lattice In, const OriginFlowFact &) { return In; } - Lattice transfer(Lattice In, const ReturnOfOriginFact &) { return In; } - Lattice transfer(Lattice In, const UseFact &) { return In; } - Lattice transfer(Lattice In, const TestPointFact &) { return In; } -}; - -namespace utils { - -/// Computes the union of two ImmutableSets. -template <typename T> -static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A, - llvm::ImmutableSet<T> B, - typename llvm::ImmutableSet<T>::Factory &F) { - if (A.getHeight() < B.getHeight()) - std::swap(A, B); - for (const T &E : B) - A = F.add(A, E); - return A; -} - -/// Describes the strategy for joining two `ImmutableMap` instances, primarily -/// differing in how they handle keys that are unique to one of the maps. -/// -/// A `Symmetric` join is universally correct, while an `Asymmetric` join -/// serves as a performance optimization. The latter is applicable only when the -/// join operation possesses a left identity element, allowing for a more -/// efficient, one-sided merge. 
-enum class JoinKind { - /// A symmetric join applies the `JoinValues` operation to keys unique to - /// either map, ensuring that values from both maps contribute to the result. - Symmetric, - /// An asymmetric join preserves keys unique to the first map as-is, while - /// applying the `JoinValues` operation only to keys unique to the second map. - Asymmetric, -}; - -/// Computes the key-wise union of two ImmutableMaps. -// TODO(opt): This key-wise join is a performance bottleneck. A more -// efficient merge could be implemented using a Patricia Trie or HAMT -// instead of the current AVL-tree-based ImmutableMap. -template <typename K, typename V, typename Joiner> -static llvm::ImmutableMap<K, V> -join(const llvm::ImmutableMap<K, V> &A, const llvm::ImmutableMap<K, V> &B, - typename llvm::ImmutableMap<K, V>::Factory &F, Joiner JoinValues, - JoinKind Kind) { - if (A.getHeight() < B.getHeight()) - return join(B, A, F, JoinValues, Kind); - - // For each element in B, join it with the corresponding element in A - // (or with an empty value if it doesn't exist in A). - llvm::ImmutableMap<K, V> Res = A; - for (const auto &Entry : B) { - const K &Key = Entry.first; - const V &ValB = Entry.second; - Res = F.add(Res, Key, JoinValues(A.lookup(Key), &ValB)); - } - if (Kind == JoinKind::Symmetric) { - for (const auto &Entry : A) { - const K &Key = Entry.first; - const V &ValA = Entry.second; - if (!B.contains(Key)) - Res = F.add(Res, Key, JoinValues(&ValA, nullptr)); - } - } - return Res; -} -} // namespace utils - -// ========================================================================= // -// Loan Propagation Analysis -// ========================================================================= // - -/// Represents the dataflow lattice for loan propagation. -/// -/// This lattice tracks which loans each origin may hold at a given program -/// point.The lattice has a finite height: An origin's loan set is bounded by -/// the total number of loans in the function. 
-/// TODO(opt): To reduce the lattice size, propagate origins of declarations, -/// not expressions, because expressions are not visible across blocks. -struct LoanPropagationLattice { - /// The map from an origin to the set of loans it contains. - OriginLoanMap Origins = OriginLoanMap(nullptr); - - explicit LoanPropagationLattice(const OriginLoanMap &S) : Origins(S) {} - LoanPropagationLattice() = default; - - bool operator==(const LoanPropagationLattice &Other) const { - return Origins == Other.Origins; - } - bool operator!=(const LoanPropagationLattice &Other) const { - return !(*this == Other); - } - - void dump(llvm::raw_ostream &OS) const { - OS << "LoanPropagationLattice State:\n"; - if (Origins.isEmpty()) - OS << " <empty>\n"; - for (const auto &Entry : Origins) { - if (Entry.second.isEmpty()) - OS << " Origin " << Entry.first << " contains no loans\n"; - for (const LoanID &LID : Entry.second) - OS << " Origin " << Entry.first << " contains Loan " << LID << "\n"; - } - } -}; - -/// The analysis that tracks which loans belong to which origins. -class LoanPropagationAnalysis - : public DataflowAnalysis<LoanPropagationAnalysis, LoanPropagationLattice, - Direction::Forward> { - OriginLoanMap::Factory &OriginLoanMapFactory; - LoanSet::Factory &LoanSetFactory; - -public: - LoanPropagationAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, - OriginLoanMap::Factory &OriginLoanMapFactory, - LoanSet::Factory &LoanSetFactory) - : DataflowAnalysis(C, AC, F), OriginLoanMapFactory(OriginLoanMapFactory), - LoanSetFactory(LoanSetFactory) {} - - using Base::transfer; - - StringRef getAnalysisName() const { return "LoanPropagation"; } - - Lattice getInitialState() { return Lattice{}; } - - /// Merges two lattices by taking the union of loans for each origin. - // TODO(opt): Keep the state small by removing origins which become dead. 
- Lattice join(Lattice A, Lattice B) { - OriginLoanMap JoinedOrigins = utils::join( - A.Origins, B.Origins, OriginLoanMapFactory, - [&](const LoanSet *S1, const LoanSet *S2) { - assert((S1 || S2) && "unexpectedly merging 2 empty sets"); - if (!S1) - return *S2; - if (!S2) - return *S1; - return utils::join(*S1, *S2, LoanSetFactory); - }, - // Asymmetric join is a performance win. For origins present only on one - // branch, the loan set can be carried over as-is. - utils::JoinKind::Asymmetric); - return Lattice(JoinedOrigins); - } - - /// A new loan is issued to the origin. Old loans are erased. - Lattice transfer(Lattice In, const IssueFact &F) { - OriginID OID = F.getOriginID(); - LoanID LID = F.getLoanID(); - return LoanPropagationLattice(OriginLoanMapFactory.add( - In.Origins, OID, - LoanSetFactory.add(LoanSetFactory.getEmptySet(), LID))); - } - - /// A flow from source to destination. If `KillDest` is true, this replaces - /// the destination's loans with the source's. Otherwise, the source's loans - /// are merged into the destination's. - Lattice transfer(Lattice In, const OriginFlowFact &F) { - OriginID DestOID = F.getDestOriginID(); - OriginID SrcOID = F.getSrcOriginID(); - - LoanSet DestLoans = - F.getKillDest() ? 
LoanSetFactory.getEmptySet() : getLoans(In, DestOID); - LoanSet SrcLoans = getLoans(In, SrcOID); - LoanSet MergedLoans = utils::join(DestLoans, SrcLoans, LoanSetFactory); - - return LoanPropagationLattice( - OriginLoanMapFactory.add(In.Origins, DestOID, MergedLoans)); - } - - LoanSet getLoans(OriginID OID, ProgramPoint P) const { - return getLoans(getState(P), OID); - } - -private: - LoanSet getLoans(Lattice L, OriginID OID) const { - if (auto *Loans = L.Origins.lookup(OID)) - return *Loans; - return LoanSetFactory.getEmptySet(); - } -}; - -// ========================================================================= // -// Live Origins Analysis -// ========================================================================= // -// -// A backward dataflow analysis that determines which origins are "live" at each -// program point. An origin is "live" at a program point if there's a potential -// future use of the pointer it represents. Liveness is "generated" by a read of -// origin's loan set (e.g., a `UseFact`) and is "killed" (i.e., it stops being -// live) when its loan set is overwritten (e.g. a OriginFlow killing the -// destination origin). -// -// This information is used for detecting use-after-free errors, as it allows us -// to check if a live origin holds a loan to an object that has already expired. -// ========================================================================= // - -/// Information about why an origin is live at a program point. -struct LivenessInfo { - /// The use that makes the origin live. If liveness is propagated from - /// multiple uses along different paths, this will point to the use appearing - /// earlier in the translation unit. - /// This is 'null' when the origin is not live. - const UseFact *CausingUseFact; - /// The kind of liveness of the origin. - /// `Must`: The origin is live on all control-flow paths from the current - /// point to the function's exit (i.e. the current point is dominated by a set - /// of uses). 
- /// `Maybe`: indicates it is live on some but not all paths. - /// - /// This determines the diagnostic's confidence level. - /// `Must`-be-alive at expiration implies a definite use-after-free, - /// while `Maybe`-be-alive suggests a potential one on some paths. - LivenessKind Kind; - - LivenessInfo() : CausingUseFact(nullptr), Kind(LivenessKind::Dead) {} - LivenessInfo(const UseFact *UF, LivenessKind K) - : CausingUseFact(UF), Kind(K) {} - - bool operator==(const LivenessInfo &Other) const { - return CausingUseFact == Other.CausingUseFact && Kind == Other.Kind; - } - bool operator!=(const LivenessInfo &Other) const { return !(*this == Other); } - - void Profile(llvm::FoldingSetNodeID &IDBuilder) const { - IDBuilder.AddPointer(CausingUseFact); - IDBuilder.Add(Kind); - } -}; - -using LivenessMap = llvm::ImmutableMap<OriginID, LivenessInfo>; - -/// The dataflow lattice for origin liveness analysis. -/// It tracks which origins are live, why they're live (which UseFact), -/// and the confidence level of that liveness. 
-struct LivenessLattice { - LivenessMap LiveOrigins; - - LivenessLattice() : LiveOrigins(nullptr) {}; - - explicit LivenessLattice(LivenessMap L) : LiveOrigins(L) {} - - bool operator==(const LivenessLattice &Other) const { - return LiveOrigins == Other.LiveOrigins; - } - - bool operator!=(const LivenessLattice &Other) const { - return !(*this == Other); - } - - void dump(llvm::raw_ostream &OS, const OriginManager &OM) const { - if (LiveOrigins.isEmpty()) - OS << " <empty>\n"; - for (const auto &Entry : LiveOrigins) { - OriginID OID = Entry.first; - const LivenessInfo &Info = Entry.second; - OS << " "; - OM.dump(OID, OS); - OS << " is "; - switch (Info.Kind) { - case LivenessKind::Must: - OS << "definitely"; - break; - case LivenessKind::Maybe: - OS << "maybe"; - break; - case LivenessKind::Dead: - llvm_unreachable("liveness kind of live origins should not be dead."); - } - OS << " live at this point\n"; - } - } -}; - -/// The analysis that tracks which origins are live, with granular information -/// about the causing use fact and confidence level. This is a backward -/// analysis. -class LiveOriginAnalysis - : public DataflowAnalysis<LiveOriginAnalysis, LivenessLattice, - Direction::Backward> { - FactManager &FactMgr; - LivenessMap::Factory &Factory; - -public: - LiveOriginAnalysis(const CFG &C, AnalysisDeclContext &AC, FactManager &F, - LivenessMap::Factory &SF) - : DataflowAnalysis(C, AC, F), FactMgr(F), Factory(SF) {} - using DataflowAnalysis<LiveOriginAnalysis, Lattice, - Direction::Backward>::transfer; - - StringRef getAnalysisName() const { return "LiveOrigins"; } - - Lattice getInitialState() { return Lattice(Factory.getEmptyMap()); } - - /// Merges two lattices by combining liveness information. - /// When the same origin has different confidence levels, we take the lower - /// one. - Lattice join(Lattice L1, Lattice L2) const { - LivenessMap Merged = L1.LiveOrigins; - // Take the earliest UseFact to make the join hermetic and commutative. 
- auto CombineUseFact = [](const UseFact &A, - const UseFact &B) -> const UseFact * { - return A.getUseExpr()->getExprLoc() < B.getUseExpr()->getExprLoc() ? &A - : &B; - }; - auto CombineLivenessKind = [](LivenessKind K1, - LivenessKind K2) -> LivenessKind { - assert(K1 != LivenessKind::Dead && "LivenessKind should not be dead."); - assert(K2 != LivenessKind::Dead && "LivenessKind should not be dead."); - // Only return "Must" if both paths are "Must", otherwise Maybe. - if (K1 == LivenessKind::Must && K2 == LivenessKind::Must) - return LivenessKind::Must; - return LivenessKind::Maybe; - }; - auto CombineLivenessInfo = [&](const LivenessInfo *L1, - const LivenessInfo *L2) -> LivenessInfo { - assert((L1 || L2) && "unexpectedly merging 2 empty sets"); - if (!L1) - return LivenessInfo(L2->CausingUseFact, LivenessKind::Maybe); - if (!L2) - return LivenessInfo(L1->CausingUseFact, LivenessKind::Maybe); - return LivenessInfo( - CombineUseFact(*L1->CausingUseFact, *L2->CausingUseFact), - CombineLivenessKind(L1->Kind, L2->Kind)); - }; - return Lattice(utils::join( - L1.LiveOrigins, L2.LiveOrigins, Factory, CombineLivenessInfo, - // A symmetric join is required here. If an origin is live on one - // branch but not the other, its confidence must be demoted to `Maybe`. - utils::JoinKind::Symmetric)); - } - - /// A read operation makes the origin live with definite confidence, as it - /// dominates this program point. A write operation kills the liveness of - /// the origin since it overwrites the value. - Lattice transfer(Lattice In, const UseFact &UF) { - OriginID OID = UF.getUsedOrigin(FactMgr.getOriginMgr()); - // Write kills liveness. - if (UF.isWritten()) - return Lattice(Factory.remove(In.LiveOrigins, OID)); - // Read makes origin live with definite confidence (dominates this point). - return Lattice(Factory.add(In.LiveOrigins, OID, - LivenessInfo(&UF, LivenessKind::Must))); - } - - /// Issuing a new loan to an origin kills its liveness. 
- Lattice transfer(Lattice In, const IssueFact &IF) { - return Lattice(Factory.remove(In.LiveOrigins, IF.getOriginID())); - } - - /// An OriginFlow kills the liveness of the destination origin if `KillDest` - /// is true. Otherwise, it propagates liveness from destination to source. - Lattice transfer(Lattice In, const OriginFlowFact &OF) { - if (!OF.getKillDest()) - return In; - return Lattice(Factory.remove(In.LiveOrigins, OF.getDestOriginID())); - } - - LivenessMap getLiveOrigins(ProgramPoint P) const { - return getState(P).LiveOrigins; - } - - // Dump liveness values on all test points in the program. - void dump(llvm::raw_ostream &OS, const LifetimeSafetyAnalysis &LSA) const { - llvm::dbgs() << "==========================================\n"; - llvm::dbgs() << getAnalysisName() << " results:\n"; - llvm::dbgs() << "==========================================\n"; - for (const auto &Entry : LSA.getTestPoints()) { - OS << "TestPoint: " << Entry.getKey() << "\n"; - getState(Entry.getValue()).dump(OS, FactMgr.getOriginMgr()); - } - } -}; - -// ========================================================================= // -// Lifetime checker and Error reporter -// ========================================================================= // - -/// Struct to store the complete context for a potential lifetime violation. -struct PendingWarning { - SourceLocation ExpiryLoc; // Where the loan expired. - const Expr *UseExpr; // Where the origin holding this loan was used. 
- Confidence ConfidenceLevel; -}; - -class LifetimeChecker { -private: - llvm::DenseMap<LoanID, PendingWarning> FinalWarningsMap; - LoanPropagationAnalysis &LoanPropagation; - LiveOriginAnalysis &LiveOrigins; - FactManager &FactMgr; - AnalysisDeclContext &ADC; - LifetimeSafetyReporter *Reporter; - -public: - LifetimeChecker(LoanPropagationAnalysis &LPA, LiveOriginAnalysis &LOA, - FactManager &FM, AnalysisDeclContext &ADC, - LifetimeSafetyReporter *Reporter) - : LoanPropagation(LPA), LiveOrigins(LOA), FactMgr(FM), ADC(ADC), - Reporter(Reporter) {} - - void run() { - llvm::TimeTraceScope TimeProfile("LifetimeChecker"); - for (const CFGBlock *B : *ADC.getAnalysis<PostOrderCFGView>()) - for (const Fact *F : FactMgr.getFacts(B)) - if (const auto *EF = F->getAs<ExpireFact>()) - checkExpiry(EF); - issuePendingWarnings(); - } - - /// Checks for use-after-free errors when a loan expires. - /// - /// This method examines all live origins at the expiry point and determines - /// if any of them hold the expiring loan. If so, it creates a pending - /// warning with the appropriate confidence level based on the liveness - /// information. The confidence reflects whether the origin is definitely - /// or maybe live at this point. - /// - /// Note: This implementation considers only the confidence of origin - /// liveness. Future enhancements could also consider the confidence of loan - /// propagation (e.g., a loan may only be held on some execution paths). - void checkExpiry(const ExpireFact *EF) { - LoanID ExpiredLoan = EF->getLoanID(); - LivenessMap Origins = LiveOrigins.getLiveOrigins(EF); - Confidence CurConfidence = Confidence::None; - const UseFact *BadUse = nullptr; - for (auto &[OID, LiveInfo] : Origins) { - LoanSet HeldLoans = LoanPropagation.getLoans(OID, EF); - if (!HeldLoans.contains(ExpiredLoan)) - continue; - // Loan is defaulted. 
- Confidence NewConfidence = livenessKindToConfidence(LiveInfo.Kind); - if (CurConfidence < NewConfidence) { - CurConfidence = NewConfidence; - BadUse = LiveInfo.CausingUseFact; - } - } - if (!BadUse) - return; - // We have a use-after-free. - Confidence LastConf = FinalWarningsMap.lookup(ExpiredLoan).ConfidenceLevel; - if (LastConf >= CurConfidence) - return; - FinalWarningsMap[ExpiredLoan] = {/*ExpiryLoc=*/EF->getExpiryLoc(), - /*UseExpr=*/BadUse->getUseExpr(), - /*ConfidenceLevel=*/CurConfidence}; - } - - static Confidence livenessKindToConfidence(LivenessKind K) { - switch (K) { - case LivenessKind::Must: - return Confidence::Definite; - case LivenessKind::Maybe: - return Confidence::Maybe; - case LivenessKind::Dead: - return Confidence::None; - } - llvm_unreachable("unknown liveness kind"); - } - - void issuePendingWarnings() { - if (!Reporter) - return; - for (const auto &[LID, Warning] : FinalWarningsMap) { - const Loan &L = FactMgr.getLoanMgr().getLoan(LID); - const Expr *IssueExpr = L.IssueExpr; - Reporter->reportUseAfterFree(IssueExpr, Warning.UseExpr, - Warning.ExpiryLoc, Warning.ConfidenceLevel); - } - } -}; - -// ========================================================================= // -// LifetimeSafetyAnalysis Class Implementation -// ========================================================================= // - -/// An object to hold the factories for immutable collections, ensuring -/// that all created states share the same underlying memory management. -struct LifetimeFactory { - llvm::BumpPtrAllocator Allocator; - OriginLoanMap::Factory OriginMapFactory{Allocator, /*canonicalize=*/false}; - LoanSet::Factory LoanSetFactory{Allocator, /*canonicalize=*/false}; - LivenessMap::Factory LivenessMapFactory{Allocator, /*canonicalize=*/false}; -}; - -// We need this here for unique_ptr with forward declared class. 
-LifetimeSafetyAnalysis::~LifetimeSafetyAnalysis() = default; - -LifetimeSafetyAnalysis::LifetimeSafetyAnalysis(AnalysisDeclContext &AC, - LifetimeSafetyReporter *Reporter) - : AC(AC), Reporter(Reporter), Factory(std::make_unique<LifetimeFactory>()), - FactMgr(std::make_unique<FactManager>()) {} - -void LifetimeSafetyAnalysis::run() { - llvm::TimeTraceScope TimeProfile("LifetimeSafetyAnalysis"); - - const CFG &Cfg = *AC.getCFG(); - DEBUG_WITH_TYPE("PrintCFG", Cfg.dump(AC.getASTContext().getLangOpts(), - /*ShowColors=*/true)); - - FactGenerator FactGen(*FactMgr, AC); - FactGen.run(); - DEBUG_WITH_TYPE("LifetimeFacts", FactMgr->dump(Cfg, AC)); - - /// TODO(opt): Consider optimizing individual blocks before running the - /// dataflow analysis. - /// 1. Expression Origins: These are assigned once and read at most once, - /// forming simple chains. These chains can be compressed into a single - /// assignment. - /// 2. Block-Local Loans: Origins of expressions are never read by other - /// blocks; only Decls are visible. Therefore, loans in a block that - /// never reach an Origin associated with a Decl can be safely dropped by - /// the analysis. - /// 3. Collapse ExpireFacts belonging to same source location into a single - /// Fact. 
- LoanPropagation = std::make_unique<LoanPropagationAnalysis>( - Cfg, AC, *FactMgr, Factory->OriginMapFactory, Factory->LoanSetFactory); - LoanPropagation->run(); - - LiveOrigins = std::make_unique<LiveOriginAnalysis>( - Cfg, AC, *FactMgr, Factory->LivenessMapFactory); - LiveOrigins->run(); - DEBUG_WITH_TYPE("LiveOrigins", LiveOrigins->dump(llvm::dbgs(), *this)); - - LifetimeChecker Checker(*LoanPropagation, *LiveOrigins, *FactMgr, AC, - Reporter); - Checker.run(); -} - -LoanSet LifetimeSafetyAnalysis::getLoansAtPoint(OriginID OID, - ProgramPoint PP) const { - assert(LoanPropagation && "Analysis has not been run."); - return LoanPropagation->getLoans(OID, PP); -} - -std::optional<OriginID> -LifetimeSafetyAnalysis::getOriginIDForDecl(const ValueDecl *D) const { - assert(FactMgr && "FactManager not initialized"); - // This assumes the OriginManager's `get` can find an existing origin. - // We might need a `find` method on OriginManager to avoid `getOrCreate` logic - // in a const-query context if that becomes an issue. 
- return FactMgr->getOriginMgr().get(*D); -} - -std::vector<LoanID> -LifetimeSafetyAnalysis::getLoanIDForVar(const VarDecl *VD) const { - assert(FactMgr && "FactManager not initialized"); - std::vector<LoanID> Result; - for (const Loan &L : FactMgr->getLoanMgr().getLoans()) - if (L.Path.D == VD) - Result.push_back(L.ID); - return Result; -} - -std::vector<std::pair<OriginID, LivenessKind>> -LifetimeSafetyAnalysis::getLiveOriginsAtPoint(ProgramPoint PP) const { - assert(LiveOrigins && "LiveOriginAnalysis has not been run."); - std::vector<std::pair<OriginID, LivenessKind>> Result; - for (auto &[OID, Info] : LiveOrigins->getLiveOrigins(PP)) - Result.push_back({OID, Info.Kind}); - return Result; -} - -llvm::StringMap<ProgramPoint> LifetimeSafetyAnalysis::getTestPoints() const { - assert(FactMgr && "FactManager not initialized"); - llvm::StringMap<ProgramPoint> AnnotationToPointMap; - for (const CFGBlock *Block : *AC.getCFG()) { - for (const Fact *F : FactMgr->getFacts(Block)) { - if (const auto *TPF = F->getAs<TestPointFact>()) { - StringRef PointName = TPF->getAnnotation(); - assert(AnnotationToPointMap.find(PointName) == - AnnotationToPointMap.end() && - "more than one test points with the same name"); - AnnotationToPointMap[PointName] = F; - } - } - } - return AnnotationToPointMap; -} -} // namespace internal - -void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC, - LifetimeSafetyReporter *Reporter) { - internal::LifetimeSafetyAnalysis Analysis(AC, Reporter); - Analysis.run(); -} -} // namespace clang::lifetimes diff --git a/clang/lib/Analysis/LifetimeSafety/CMakeLists.txt b/clang/lib/Analysis/LifetimeSafety/CMakeLists.txt new file mode 100644 index 0000000..5874e84 --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/CMakeLists.txt @@ -0,0 +1,17 @@ +add_clang_library(clangAnalysisLifetimeSafety + Checker.cpp + Facts.cpp + FactsGenerator.cpp + LifetimeAnnotations.cpp + LifetimeSafety.cpp + LiveOrigins.cpp + Loans.cpp + LoanPropagation.cpp + Origins.cpp + + 
LINK_LIBS + clangAST + clangAnalysis + clangBasic + ) + diff --git a/clang/lib/Analysis/LifetimeSafety/Checker.cpp b/clang/lib/Analysis/LifetimeSafety/Checker.cpp new file mode 100644 index 0000000..c443c3a --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/Checker.cpp @@ -0,0 +1,130 @@ +//===- Checker.cpp - C++ Lifetime Safety Checker ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LifetimeChecker, which detects use-after-free +// errors by checking if live origins hold loans that have expired. +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/Checker.h" +#include "clang/AST/Expr.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Loans.h" +#include "clang/Analysis/Analyses/PostOrderCFGView.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Basic/SourceLocation.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/TimeProfiler.h" + +namespace clang::lifetimes::internal { + +static Confidence livenessKindToConfidence(LivenessKind K) { + switch (K) { + case LivenessKind::Must: + return Confidence::Definite; + case LivenessKind::Maybe: + return Confidence::Maybe; + case LivenessKind::Dead: + return Confidence::None; + } + llvm_unreachable("unknown liveness kind"); +} + +namespace { + +/// Struct to store the complete context for a potential lifetime violation. 
+struct PendingWarning { + SourceLocation ExpiryLoc; // Where the loan expired. + const Expr *UseExpr; // Where the origin holding this loan was used. + Confidence ConfidenceLevel; +}; + +class LifetimeChecker { +private: + llvm::DenseMap<LoanID, PendingWarning> FinalWarningsMap; + const LoanPropagationAnalysis &LoanPropagation; + const LiveOriginsAnalysis &LiveOrigins; + const FactManager &FactMgr; + LifetimeSafetyReporter *Reporter; + +public: + LifetimeChecker(const LoanPropagationAnalysis &LoanPropagation, + const LiveOriginsAnalysis &LiveOrigins, const FactManager &FM, + AnalysisDeclContext &ADC, LifetimeSafetyReporter *Reporter) + : LoanPropagation(LoanPropagation), LiveOrigins(LiveOrigins), FactMgr(FM), + Reporter(Reporter) { + for (const CFGBlock *B : *ADC.getAnalysis<PostOrderCFGView>()) + for (const Fact *F : FactMgr.getFacts(B)) + if (const auto *EF = F->getAs<ExpireFact>()) + checkExpiry(EF); + issuePendingWarnings(); + } + + /// Checks for use-after-free errors when a loan expires. + /// + /// This method examines all live origins at the expiry point and determines + /// if any of them hold the expiring loan. If so, it creates a pending + /// warning with the appropriate confidence level based on the liveness + /// information. The confidence reflects whether the origin is definitely + /// or maybe live at this point. + /// + /// Note: This implementation considers only the confidence of origin + /// liveness. Future enhancements could also consider the confidence of loan + /// propagation (e.g., a loan may only be held on some execution paths). + void checkExpiry(const ExpireFact *EF) { + LoanID ExpiredLoan = EF->getLoanID(); + LivenessMap Origins = LiveOrigins.getLiveOriginsAt(EF); + Confidence CurConfidence = Confidence::None; + const UseFact *BadUse = nullptr; + for (auto &[OID, LiveInfo] : Origins) { + LoanSet HeldLoans = LoanPropagation.getLoans(OID, EF); + if (!HeldLoans.contains(ExpiredLoan)) + continue; + // Loan is defaulted. 
+ Confidence NewConfidence = livenessKindToConfidence(LiveInfo.Kind); + if (CurConfidence < NewConfidence) { + CurConfidence = NewConfidence; + BadUse = LiveInfo.CausingUseFact; + } + } + if (!BadUse) + return; + // We have a use-after-free. + Confidence LastConf = FinalWarningsMap.lookup(ExpiredLoan).ConfidenceLevel; + if (LastConf >= CurConfidence) + return; + FinalWarningsMap[ExpiredLoan] = {/*ExpiryLoc=*/EF->getExpiryLoc(), + /*UseExpr=*/BadUse->getUseExpr(), + /*ConfidenceLevel=*/CurConfidence}; + } + + void issuePendingWarnings() { + if (!Reporter) + return; + for (const auto &[LID, Warning] : FinalWarningsMap) { + const Loan &L = FactMgr.getLoanMgr().getLoan(LID); + const Expr *IssueExpr = L.IssueExpr; + Reporter->reportUseAfterFree(IssueExpr, Warning.UseExpr, + Warning.ExpiryLoc, Warning.ConfidenceLevel); + } + } +}; +} // namespace + +void runLifetimeChecker(const LoanPropagationAnalysis &LP, + const LiveOriginsAnalysis &LO, + const FactManager &FactMgr, AnalysisDeclContext &ADC, + LifetimeSafetyReporter *Reporter) { + llvm::TimeTraceScope TimeProfile("LifetimeChecker"); + LifetimeChecker Checker(LP, LO, FactMgr, ADC, Reporter); +} + +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeSafety/Dataflow.h b/clang/lib/Analysis/LifetimeSafety/Dataflow.h new file mode 100644 index 0000000..2f7bcb6 --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/Dataflow.h @@ -0,0 +1,188 @@ +//===- Dataflow.h - Generic Dataflow Analysis Framework --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a generic, policy-based driver for dataflow analyses. 
+// It provides a flexible framework that combines the dataflow runner and +// transfer functions, allowing derived classes to implement specific analyses +// by defining their lattice, join, and transfer functions. +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_DATAFLOW_H +#define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_DATAFLOW_H + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "clang/Analysis/FlowSensitive/DataflowWorklist.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/TimeProfiler.h" +#include <optional> + +namespace clang::lifetimes::internal { + +enum class Direction { Forward, Backward }; + +/// A `ProgramPoint` identifies a location in the CFG by pointing to a specific +/// `Fact`. identified by a lifetime-related event (`Fact`). +/// +/// A `ProgramPoint` has "after" semantics: it represents the location +/// immediately after its corresponding `Fact`. +using ProgramPoint = const Fact *; + +/// A generic, policy-based driver for dataflow analyses. It combines +/// the dataflow runner and the transferer logic into a single class hierarchy. +/// +/// The derived class is expected to provide: +/// - A `Lattice` type. +/// - `StringRef getAnalysisName() const` +/// - `Lattice getInitialState();` The initial state of the analysis. +/// - `Lattice join(Lattice, Lattice);` Merges states from multiple CFG paths. +/// - `Lattice transfer(Lattice, const FactType&);` Defines how a single +/// lifetime-relevant `Fact` transforms the lattice state. Only overloads +/// for facts relevant to the analysis need to be implemented. +/// +/// \tparam Derived The CRTP derived class that implements the specific +/// analysis. +/// \tparam LatticeType The dataflow lattice used by the analysis. 
+/// \tparam Dir The direction of the analysis (Forward or Backward). +/// TODO: Maybe use the dataflow framework! The framework might need changes +/// to support the current comparison done at block-entry. +template <typename Derived, typename LatticeType, Direction Dir> +class DataflowAnalysis { +public: + using Lattice = LatticeType; + using Base = DataflowAnalysis<Derived, Lattice, Dir>; + +private: + const CFG &Cfg; + AnalysisDeclContext &AC; + + /// The dataflow state before a basic block is processed. + llvm::DenseMap<const CFGBlock *, Lattice> InStates; + /// The dataflow state after a basic block is processed. + llvm::DenseMap<const CFGBlock *, Lattice> OutStates; + /// The dataflow state at a Program Point. + /// In a forward analysis, this is the state after the Fact at that point has + /// been applied, while in a backward analysis, it is the state before. + llvm::DenseMap<ProgramPoint, Lattice> PerPointStates; + + static constexpr bool isForward() { return Dir == Direction::Forward; } + +protected: + FactManager &FactMgr; + + explicit DataflowAnalysis(const CFG &Cfg, AnalysisDeclContext &AC, + FactManager &FactMgr) + : Cfg(Cfg), AC(AC), FactMgr(FactMgr) {} + +public: + void run() { + Derived &D = static_cast<Derived &>(*this); + llvm::TimeTraceScope Time(D.getAnalysisName()); + + using Worklist = + std::conditional_t<Dir == Direction::Forward, ForwardDataflowWorklist, + BackwardDataflowWorklist>; + Worklist W(Cfg, AC); + + const CFGBlock *Start = isForward() ? &Cfg.getEntry() : &Cfg.getExit(); + InStates[Start] = D.getInitialState(); + W.enqueueBlock(Start); + + while (const CFGBlock *B = W.dequeue()) { + Lattice StateIn = *getInState(B); + Lattice StateOut = transferBlock(B, StateIn); + OutStates[B] = StateOut; + for (const CFGBlock *AdjacentB : isForward() ? B->succs() : B->preds()) { + if (!AdjacentB) + continue; + std::optional<Lattice> OldInState = getInState(AdjacentB); + Lattice NewInState = + !OldInState ? 
StateOut : D.join(*OldInState, StateOut); + // Enqueue the adjacent block if its in-state has changed or if we have + // never seen it. + if (!OldInState || NewInState != *OldInState) { + InStates[AdjacentB] = NewInState; + W.enqueueBlock(AdjacentB); + } + } + } + } + +protected: + Lattice getState(ProgramPoint P) const { return PerPointStates.lookup(P); } + + std::optional<Lattice> getInState(const CFGBlock *B) const { + auto It = InStates.find(B); + if (It == InStates.end()) + return std::nullopt; + return It->second; + } + + Lattice getOutState(const CFGBlock *B) const { return OutStates.lookup(B); } + + void dump() const { + const Derived *D = static_cast<const Derived *>(this); + llvm::dbgs() << "==========================================\n"; + llvm::dbgs() << D->getAnalysisName() << " results:\n"; + llvm::dbgs() << "==========================================\n"; + const CFGBlock &B = isForward() ? Cfg.getExit() : Cfg.getEntry(); + getOutState(&B).dump(llvm::dbgs()); + } + +private: + /// Computes the state at one end of a block by applying all its facts + /// sequentially to a given state from the other end. + Lattice transferBlock(const CFGBlock *Block, Lattice State) { + auto Facts = FactMgr.getFacts(Block); + if constexpr (isForward()) { + for (const Fact *F : Facts) { + State = transferFact(State, F); + PerPointStates[F] = State; + } + } else { + for (const Fact *F : llvm::reverse(Facts)) { + // In backward analysis, capture the state before applying the fact. 
+ PerPointStates[F] = State; + State = transferFact(State, F); + } + } + return State; + } + + Lattice transferFact(Lattice In, const Fact *F) { + assert(F); + Derived *D = static_cast<Derived *>(this); + switch (F->getKind()) { + case Fact::Kind::Issue: + return D->transfer(In, *F->getAs<IssueFact>()); + case Fact::Kind::Expire: + return D->transfer(In, *F->getAs<ExpireFact>()); + case Fact::Kind::OriginFlow: + return D->transfer(In, *F->getAs<OriginFlowFact>()); + case Fact::Kind::ReturnOfOrigin: + return D->transfer(In, *F->getAs<ReturnOfOriginFact>()); + case Fact::Kind::Use: + return D->transfer(In, *F->getAs<UseFact>()); + case Fact::Kind::TestPoint: + return D->transfer(In, *F->getAs<TestPointFact>()); + } + llvm_unreachable("Unknown fact kind"); + } + +public: + Lattice transfer(Lattice In, const IssueFact &) { return In; } + Lattice transfer(Lattice In, const ExpireFact &) { return In; } + Lattice transfer(Lattice In, const OriginFlowFact &) { return In; } + Lattice transfer(Lattice In, const ReturnOfOriginFact &) { return In; } + Lattice transfer(Lattice In, const UseFact &) { return In; } + Lattice transfer(Lattice In, const TestPointFact &) { return In; } +}; +} // namespace clang::lifetimes::internal +#endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_DATAFLOW_H diff --git a/clang/lib/Analysis/LifetimeSafety/Facts.cpp b/clang/lib/Analysis/LifetimeSafety/Facts.cpp new file mode 100644 index 0000000..1aea64f --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/Facts.cpp @@ -0,0 +1,102 @@ +//===- Facts.cpp - Lifetime Analysis Facts Implementation -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/AST/Decl.h" +#include "clang/Analysis/Analyses/PostOrderCFGView.h" + +namespace clang::lifetimes::internal { + +void Fact::dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &) const { + OS << "Fact (Kind: " << static_cast<int>(K) << ")\n"; +} + +void IssueFact::dump(llvm::raw_ostream &OS, const LoanManager &LM, + const OriginManager &OM) const { + OS << "Issue ("; + LM.getLoan(getLoanID()).dump(OS); + OS << ", ToOrigin: "; + OM.dump(getOriginID(), OS); + OS << ")\n"; +} + +void ExpireFact::dump(llvm::raw_ostream &OS, const LoanManager &LM, + const OriginManager &) const { + OS << "Expire ("; + LM.getLoan(getLoanID()).dump(OS); + OS << ")\n"; +} + +void OriginFlowFact::dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const { + OS << "OriginFlow (Dest: "; + OM.dump(getDestOriginID(), OS); + OS << ", Src: "; + OM.dump(getSrcOriginID(), OS); + OS << (getKillDest() ? "" : ", Merge"); + OS << ")\n"; +} + +void ReturnOfOriginFact::dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const { + OS << "ReturnOfOrigin ("; + OM.dump(getReturnedOriginID(), OS); + OS << ")\n"; +} + +void UseFact::dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &OM) const { + OS << "Use ("; + OM.dump(getUsedOrigin(OM), OS); + OS << ", " << (isWritten() ? 
"Write" : "Read") << ")\n"; +} + +void TestPointFact::dump(llvm::raw_ostream &OS, const LoanManager &, + const OriginManager &) const { + OS << "TestPoint (Annotation: \"" << getAnnotation() << "\")\n"; +} + +llvm::StringMap<ProgramPoint> FactManager::getTestPoints() const { + llvm::StringMap<ProgramPoint> AnnotationToPointMap; + for (const CFGBlock *Block : BlockToFactsMap.keys()) { + for (const Fact *F : getFacts(Block)) { + if (const auto *TPF = F->getAs<TestPointFact>()) { + StringRef PointName = TPF->getAnnotation(); + assert(AnnotationToPointMap.find(PointName) == + AnnotationToPointMap.end() && + "more than one test points with the same name"); + AnnotationToPointMap[PointName] = F; + } + } + } + return AnnotationToPointMap; +} + +void FactManager::dump(const CFG &Cfg, AnalysisDeclContext &AC) const { + llvm::dbgs() << "==========================================\n"; + llvm::dbgs() << " Lifetime Analysis Facts:\n"; + llvm::dbgs() << "==========================================\n"; + if (const Decl *D = AC.getDecl()) + if (const auto *ND = dyn_cast<NamedDecl>(D)) + llvm::dbgs() << "Function: " << ND->getQualifiedNameAsString() << "\n"; + // Print blocks in the order as they appear in code for a stable ordering. 
+ for (const CFGBlock *B : *AC.getAnalysis<PostOrderCFGView>()) { + llvm::dbgs() << " Block B" << B->getBlockID() << ":\n"; + auto It = BlockToFactsMap.find(B); + if (It != BlockToFactsMap.end()) { + for (const Fact *F : It->second) { + llvm::dbgs() << " "; + F->dump(llvm::dbgs(), LoanMgr, OriginMgr); + } + } + llvm::dbgs() << " End of Block\n"; + } +} + +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeSafety/FactsGenerator.cpp b/clang/lib/Analysis/LifetimeSafety/FactsGenerator.cpp new file mode 100644 index 0000000..485308f --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/FactsGenerator.cpp @@ -0,0 +1,348 @@ +//===- FactsGenerator.cpp - Lifetime Facts Generation -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/FactsGenerator.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h" +#include "clang/Analysis/Analyses/PostOrderCFGView.h" +#include "llvm/Support/TimeProfiler.h" + +namespace clang::lifetimes::internal { + +static bool isGslPointerType(QualType QT) { + if (const auto *RD = QT->getAsCXXRecordDecl()) { + // We need to check the template definition for specializations. + if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) + return CTSD->getSpecializedTemplate() + ->getTemplatedDecl() + ->hasAttr<PointerAttr>(); + return RD->hasAttr<PointerAttr>(); + } + return false; +} + +static bool isPointerType(QualType QT) { + return QT->isPointerOrReferenceType() || isGslPointerType(QT); +} +// Check if a type has an origin. 
+static bool hasOrigin(const Expr *E) { + return E->isGLValue() || isPointerType(E->getType()); +} + +static bool hasOrigin(const VarDecl *VD) { + return isPointerType(VD->getType()); +} + +/// Creates a loan for the storage path of a given declaration reference. +/// This function should be called whenever a DeclRefExpr represents a borrow. +/// \param DRE The declaration reference expression that initiates the borrow. +/// \return The new Loan on success, nullptr otherwise. +static const Loan *createLoan(FactManager &FactMgr, const DeclRefExpr *DRE) { + if (const auto *VD = dyn_cast<ValueDecl>(DRE->getDecl())) { + AccessPath Path(VD); + // The loan is created at the location of the DeclRefExpr. + return &FactMgr.getLoanMgr().addLoan(Path, DRE); + } + return nullptr; +} + +void FactsGenerator::run() { + llvm::TimeTraceScope TimeProfile("FactGenerator"); + // Iterate through the CFG blocks in reverse post-order to ensure that + // initializations and destructions are processed in the correct sequence. + for (const CFGBlock *Block : *AC.getAnalysis<PostOrderCFGView>()) { + CurrentBlockFacts.clear(); + for (unsigned I = 0; I < Block->size(); ++I) { + const CFGElement &Element = Block->Elements[I]; + if (std::optional<CFGStmt> CS = Element.getAs<CFGStmt>()) + Visit(CS->getStmt()); + else if (std::optional<CFGAutomaticObjDtor> DtorOpt = + Element.getAs<CFGAutomaticObjDtor>()) + handleDestructor(*DtorOpt); + } + FactMgr.addBlockFacts(Block, CurrentBlockFacts); + } +} + +void FactsGenerator::VisitDeclStmt(const DeclStmt *DS) { + for (const Decl *D : DS->decls()) + if (const auto *VD = dyn_cast<VarDecl>(D)) + if (hasOrigin(VD)) + if (const Expr *InitExpr = VD->getInit()) + killAndFlowOrigin(*VD, *InitExpr); +} + +void FactsGenerator::VisitDeclRefExpr(const DeclRefExpr *DRE) { + handleUse(DRE); + // For non-pointer/non-view types, a reference to the variable's storage + // is a borrow. We create a loan for it. 
+ // For pointer/view types, we stick to the existing model for now and do + // not create an extra origin for the l-value expression itself. + + // TODO: A single origin for a `DeclRefExpr` for a pointer or view type is + // not sufficient to model the different levels of indirection. The current + // single-origin model cannot distinguish between a loan to the variable's + // storage and a loan to what it points to. A multi-origin model would be + // required for this. + if (!isPointerType(DRE->getType())) { + if (const Loan *L = createLoan(FactMgr, DRE)) { + OriginID ExprOID = FactMgr.getOriginMgr().getOrCreate(*DRE); + CurrentBlockFacts.push_back( + FactMgr.createFact<IssueFact>(L->ID, ExprOID)); + } + } +} + +void FactsGenerator::VisitCXXConstructExpr(const CXXConstructExpr *CCE) { + if (isGslPointerType(CCE->getType())) { + handleGSLPointerConstruction(CCE); + return; + } +} + +void FactsGenerator::VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) { + // Specifically for conversion operators, + // like `std::string_view p = std::string{};` + if (isGslPointerType(MCE->getType()) && + isa_and_present<CXXConversionDecl>(MCE->getCalleeDecl())) { + // The argument is the implicit object itself. + handleFunctionCall(MCE, MCE->getMethodDecl(), + {MCE->getImplicitObjectArgument()}, + /*IsGslConstruction=*/true); + } + if (const CXXMethodDecl *Method = MCE->getMethodDecl()) { + // Construct the argument list, with the implicit 'this' object as the + // first argument.
+ llvm::SmallVector<const Expr *, 4> Args; + Args.push_back(MCE->getImplicitObjectArgument()); + Args.append(MCE->getArgs(), MCE->getArgs() + MCE->getNumArgs()); + + handleFunctionCall(MCE, Method, Args, /*IsGslConstruction=*/false); + } +} + +void FactsGenerator::VisitCallExpr(const CallExpr *CE) { + handleFunctionCall(CE, CE->getDirectCallee(), + {CE->getArgs(), CE->getNumArgs()}); +} + +void FactsGenerator::VisitCXXNullPtrLiteralExpr( + const CXXNullPtrLiteralExpr *N) { + /// TODO: Handle nullptr expr as a special 'null' loan. Uninitialized + /// pointers can use the same type of loan. + FactMgr.getOriginMgr().getOrCreate(*N); +} + +void FactsGenerator::VisitImplicitCastExpr(const ImplicitCastExpr *ICE) { + if (!hasOrigin(ICE)) + return; + // An ImplicitCastExpr node itself gets an origin, which flows from the + // origin of its sub-expression (after stripping its own parens/casts). + killAndFlowOrigin(*ICE, *ICE->getSubExpr()); +} + +void FactsGenerator::VisitUnaryOperator(const UnaryOperator *UO) { + if (UO->getOpcode() == UO_AddrOf) { + const Expr *SubExpr = UO->getSubExpr(); + // Taking address of a pointer-type expression is not yet supported and + // will be supported in multi-origin model. + if (isPointerType(SubExpr->getType())) + return; + // The origin of an address-of expression (e.g., &x) is the origin of + // its sub-expression (x). This fact will cause the dataflow analysis + // to propagate any loans held by the sub-expression's origin to the + // origin of this UnaryOperator expression. 
+ killAndFlowOrigin(*UO, *SubExpr); + } +} + +void FactsGenerator::VisitReturnStmt(const ReturnStmt *RS) { + if (const Expr *RetExpr = RS->getRetValue()) { + if (hasOrigin(RetExpr)) { + OriginID OID = FactMgr.getOriginMgr().getOrCreate(*RetExpr); + CurrentBlockFacts.push_back(FactMgr.createFact<ReturnOfOriginFact>(OID)); + } + } +} + +void FactsGenerator::VisitBinaryOperator(const BinaryOperator *BO) { + if (BO->isAssignmentOp()) + handleAssignment(BO->getLHS(), BO->getRHS()); +} + +void FactsGenerator::VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) { + // Assignment operators have special "kill-then-propagate" semantics + // and are handled separately. + if (OCE->isAssignmentOp() && OCE->getNumArgs() == 2) { + handleAssignment(OCE->getArg(0), OCE->getArg(1)); + return; + } + handleFunctionCall(OCE, OCE->getDirectCallee(), + {OCE->getArgs(), OCE->getNumArgs()}, + /*IsGslConstruction=*/false); +} + +void FactsGenerator::VisitCXXFunctionalCastExpr( + const CXXFunctionalCastExpr *FCE) { + // Check if this is a test point marker. If so, we are done with this + // expression. + if (handleTestPoint(FCE)) + return; + if (isGslPointerType(FCE->getType())) + killAndFlowOrigin(*FCE, *FCE->getSubExpr()); +} + +void FactsGenerator::VisitInitListExpr(const InitListExpr *ILE) { + if (!hasOrigin(ILE)) + return; + // For list initialization with a single element, like `View{...}`, the + // origin of the list itself is the origin of its single element. + if (ILE->getNumInits() == 1) + killAndFlowOrigin(*ILE, *ILE->getInit(0)); +} + +void FactsGenerator::VisitMaterializeTemporaryExpr( + const MaterializeTemporaryExpr *MTE) { + if (!hasOrigin(MTE)) + return; + // A temporary object's origin is the same as the origin of the + // expression that initializes it. 
+ killAndFlowOrigin(*MTE, *MTE->getSubExpr()); +} + +void FactsGenerator::handleDestructor(const CFGAutomaticObjDtor &DtorOpt) { + /// TODO: Also handle trivial destructors (e.g., for `int` + /// variables) which will never have a CFGAutomaticObjDtor node. + /// TODO: Handle loans to temporaries. + /// TODO: Consider using clang::CFG::BuildOptions::AddLifetime to reuse the + /// lifetime ends. + const VarDecl *DestructedVD = DtorOpt.getVarDecl(); + if (!DestructedVD) + return; + // Iterate through all loans to see if any expire. + /// TODO(opt): Do better than a linear search to find loans associated with + /// 'DestructedVD'. + for (const Loan &L : FactMgr.getLoanMgr().getLoans()) { + const AccessPath &LoanPath = L.Path; + // Check if the loan is for a stack variable and if that variable + // is the one being destructed. + if (LoanPath.D == DestructedVD) + CurrentBlockFacts.push_back(FactMgr.createFact<ExpireFact>( + L.ID, DtorOpt.getTriggerStmt()->getEndLoc())); + } +} + +void FactsGenerator::handleGSLPointerConstruction(const CXXConstructExpr *CCE) { + assert(isGslPointerType(CCE->getType())); + if (CCE->getNumArgs() != 1) + return; + if (hasOrigin(CCE->getArg(0))) + killAndFlowOrigin(*CCE, *CCE->getArg(0)); + else + // This could be a new borrow. + handleFunctionCall(CCE, CCE->getConstructor(), + {CCE->getArgs(), CCE->getNumArgs()}, + /*IsGslConstruction=*/true); +} + +/// Checks if a call-like expression creates a borrow by passing a value to a +/// reference parameter, creating an IssueFact if it does. +/// \param IsGslConstruction True if this is a GSL construction where all +/// argument origins should flow to the returned origin. +void FactsGenerator::handleFunctionCall(const Expr *Call, + const FunctionDecl *FD, + ArrayRef<const Expr *> Args, + bool IsGslConstruction) { + // Ignore functions returning values with no origin. 
+ if (!FD || !hasOrigin(Call)) + return; + auto IsArgLifetimeBound = [FD](unsigned I) -> bool { + const ParmVarDecl *PVD = nullptr; + if (const auto *Method = dyn_cast<CXXMethodDecl>(FD); + Method && Method->isInstance()) { + if (I == 0) + // For the 'this' argument, the attribute is on the method itself. + return implicitObjectParamIsLifetimeBound(Method); + if ((I - 1) < Method->getNumParams()) + // For explicit arguments, find the corresponding parameter + // declaration. + PVD = Method->getParamDecl(I - 1); + } else if (I < FD->getNumParams()) + // For free functions or static methods. + PVD = FD->getParamDecl(I); + return PVD ? PVD->hasAttr<clang::LifetimeBoundAttr>() : false; + }; + if (Args.empty()) + return; + bool killedSrc = false; + for (unsigned I = 0; I < Args.size(); ++I) + if (IsGslConstruction || IsArgLifetimeBound(I)) { + if (!killedSrc) { + killedSrc = true; + killAndFlowOrigin(*Call, *Args[I]); + } else + flowOrigin(*Call, *Args[I]); + } +} + +/// Checks if the expression is a `void("__lifetime_test_point_...")` cast. +/// If so, creates a `TestPointFact` and returns true. +bool FactsGenerator::handleTestPoint(const CXXFunctionalCastExpr *FCE) { + if (!FCE->getType()->isVoidType()) + return false; + + const auto *SubExpr = FCE->getSubExpr()->IgnoreParenImpCasts(); + if (const auto *SL = dyn_cast<StringLiteral>(SubExpr)) { + llvm::StringRef LiteralValue = SL->getString(); + const std::string Prefix = "__lifetime_test_point_"; + + if (LiteralValue.starts_with(Prefix)) { + StringRef Annotation = LiteralValue.drop_front(Prefix.length()); + CurrentBlockFacts.push_back( + FactMgr.createFact<TestPointFact>(Annotation)); + return true; + } + } + return false; +} + +void FactsGenerator::handleAssignment(const Expr *LHSExpr, + const Expr *RHSExpr) { + if (!hasOrigin(LHSExpr)) + return; + // Find the underlying variable declaration for the left-hand side. 
+ if (const auto *DRE_LHS = + dyn_cast<DeclRefExpr>(LHSExpr->IgnoreParenImpCasts())) { + markUseAsWrite(DRE_LHS); + if (const auto *VD_LHS = dyn_cast<ValueDecl>(DRE_LHS->getDecl())) { + // Kill the old loans of the destination origin and flow the new loans + // from the source origin. + killAndFlowOrigin(*VD_LHS, *RHSExpr); + } + } +} + +// A DeclRefExpr will be treated as a use of the referenced decl. It will be +// checked for use-after-free unless it is later marked as being written to +// (e.g. on the left-hand side of an assignment). +void FactsGenerator::handleUse(const DeclRefExpr *DRE) { + if (isPointerType(DRE->getType())) { + UseFact *UF = FactMgr.createFact<UseFact>(DRE); + CurrentBlockFacts.push_back(UF); + assert(!UseFacts.contains(DRE)); + UseFacts[DRE] = UF; + } +} + +void FactsGenerator::markUseAsWrite(const DeclRefExpr *DRE) { + if (!isPointerType(DRE->getType())) + return; + assert(UseFacts.contains(DRE)); + UseFacts[DRE]->markAsWritten(); +} + +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeAnnotations.cpp b/clang/lib/Analysis/LifetimeSafety/LifetimeAnnotations.cpp index e791224..ad61a42 100644 --- a/clang/lib/Analysis/LifetimeAnnotations.cpp +++ b/clang/lib/Analysis/LifetimeSafety/LifetimeAnnotations.cpp @@ -5,7 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "clang/Analysis/Analyses/LifetimeAnnotations.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" @@ -13,8 +13,7 @@ #include "clang/AST/Type.h" #include "clang/AST/TypeLoc.h" -namespace clang { -namespace lifetimes { +namespace clang::lifetimes { const FunctionDecl * getDeclWithMergedLifetimeBoundAttrs(const FunctionDecl *FD) { @@ -71,5 +70,4 @@ bool implicitObjectParamIsLifetimeBound(const FunctionDecl *FD) { return 
isNormalAssignmentOperator(FD); } -} // namespace lifetimes -} // namespace clang +} // namespace clang::lifetimes diff --git a/clang/lib/Analysis/LifetimeSafety/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety/LifetimeSafety.cpp new file mode 100644 index 0000000..00c7ed90 --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/LifetimeSafety.cpp @@ -0,0 +1,77 @@ +//===- LifetimeSafety.cpp - C++ Lifetime Safety Analysis -*--------- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the main LifetimeSafetyAnalysis class, which coordinates +// the various components (fact generation, loan propagation, live origins +// analysis, and checking) to detect lifetime safety violations in C++ code. +// +//===----------------------------------------------------------------------===// +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h" +#include "clang/AST/Decl.h" +#include "clang/AST/Expr.h" +#include "clang/AST/Type.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Checker.h" +#include "clang/Analysis/Analyses/LifetimeSafety/Facts.h" +#include "clang/Analysis/Analyses/LifetimeSafety/FactsGenerator.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h" +#include "clang/Analysis/AnalysisDeclContext.h" +#include "clang/Analysis/CFG.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/TimeProfiler.h" +#include <memory> + +namespace clang::lifetimes { +namespace internal { + +LifetimeSafetyAnalysis::LifetimeSafetyAnalysis(AnalysisDeclContext &AC, + LifetimeSafetyReporter *Reporter) + : AC(AC), Reporter(Reporter) {} + 
+void LifetimeSafetyAnalysis::run() { + llvm::TimeTraceScope TimeProfile("LifetimeSafetyAnalysis"); + + const CFG &Cfg = *AC.getCFG(); + DEBUG_WITH_TYPE("PrintCFG", Cfg.dump(AC.getASTContext().getLangOpts(), + /*ShowColors=*/true)); + + FactsGenerator FactGen(FactMgr, AC); + FactGen.run(); + DEBUG_WITH_TYPE("LifetimeFacts", FactMgr.dump(Cfg, AC)); + + /// TODO(opt): Consider optimizing individual blocks before running the + /// dataflow analysis. + /// 1. Expression Origins: These are assigned once and read at most once, + /// forming simple chains. These chains can be compressed into a single + /// assignment. + /// 2. Block-Local Loans: Origins of expressions are never read by other + /// blocks; only Decls are visible. Therefore, loans in a block that + /// never reach an Origin associated with a Decl can be safely dropped by + /// the analysis. + /// 3. Collapse ExpireFacts belonging to same source location into a single + /// Fact. + LoanPropagation = std::make_unique<LoanPropagationAnalysis>( + Cfg, AC, FactMgr, Factory.OriginMapFactory, Factory.LoanSetFactory); + + LiveOrigins = std::make_unique<LiveOriginsAnalysis>( + Cfg, AC, FactMgr, Factory.LivenessMapFactory); + DEBUG_WITH_TYPE("LiveOrigins", + LiveOrigins->dump(llvm::dbgs(), FactMgr.getTestPoints())); + + runLifetimeChecker(*LoanPropagation, *LiveOrigins, FactMgr, AC, Reporter); +} +} // namespace internal + +void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC, + LifetimeSafetyReporter *Reporter) { + internal::LifetimeSafetyAnalysis Analysis(AC, Reporter); + Analysis.run(); +} +} // namespace clang::lifetimes diff --git a/clang/lib/Analysis/LifetimeSafety/LiveOrigins.cpp b/clang/lib/Analysis/LifetimeSafety/LiveOrigins.cpp new file mode 100644 index 0000000..cddb3f3c --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/LiveOrigins.cpp @@ -0,0 +1,180 @@ +//===- LiveOrigins.cpp - Live Origins Analysis -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with 
LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/LiveOrigins.h" +#include "Dataflow.h" +#include "llvm/Support/ErrorHandling.h" + +namespace clang::lifetimes::internal { +namespace { + +/// The dataflow lattice for origin liveness analysis. +/// It tracks which origins are live, why they're live (which UseFact), +/// and the confidence level of that liveness. +struct Lattice { + LivenessMap LiveOrigins; + + Lattice() : LiveOrigins(nullptr) {}; + + explicit Lattice(LivenessMap L) : LiveOrigins(L) {} + + bool operator==(const Lattice &Other) const { + return LiveOrigins == Other.LiveOrigins; + } + + bool operator!=(const Lattice &Other) const { return !(*this == Other); } + + void dump(llvm::raw_ostream &OS, const OriginManager &OM) const { + if (LiveOrigins.isEmpty()) + OS << " <empty>\n"; + for (const auto &Entry : LiveOrigins) { + OriginID OID = Entry.first; + const LivenessInfo &Info = Entry.second; + OS << " "; + OM.dump(OID, OS); + OS << " is "; + switch (Info.Kind) { + case LivenessKind::Must: + OS << "definitely"; + break; + case LivenessKind::Maybe: + OS << "maybe"; + break; + case LivenessKind::Dead: + llvm_unreachable("liveness kind of live origins should not be dead."); + } + OS << " live at this point\n"; + } + } +}; + +/// The analysis that tracks which origins are live, with granular information +/// about the causing use fact and confidence level. This is a backward +/// analysis. 
+class AnalysisImpl + : public DataflowAnalysis<AnalysisImpl, Lattice, Direction::Backward> { + +public: + AnalysisImpl(const CFG &C, AnalysisDeclContext &AC, FactManager &F, + LivenessMap::Factory &SF) + : DataflowAnalysis(C, AC, F), FactMgr(F), Factory(SF) {} + using DataflowAnalysis<AnalysisImpl, Lattice, Direction::Backward>::transfer; + + StringRef getAnalysisName() const { return "LiveOrigins"; } + + Lattice getInitialState() { return Lattice(Factory.getEmptyMap()); } + + /// Merges two lattices by combining liveness information. + /// When the same origin has different confidence levels, we take the lower + /// one. + Lattice join(Lattice L1, Lattice L2) const { + LivenessMap Merged = L1.LiveOrigins; + // Take the earliest UseFact to make the join hermetic and commutative. + auto CombineUseFact = [](const UseFact &A, + const UseFact &B) -> const UseFact * { + return A.getUseExpr()->getExprLoc() < B.getUseExpr()->getExprLoc() ? &A + : &B; + }; + auto CombineLivenessKind = [](LivenessKind K1, + LivenessKind K2) -> LivenessKind { + assert(K1 != LivenessKind::Dead && "LivenessKind should not be dead."); + assert(K2 != LivenessKind::Dead && "LivenessKind should not be dead."); + // Only return "Must" if both paths are "Must", otherwise Maybe. + if (K1 == LivenessKind::Must && K2 == LivenessKind::Must) + return LivenessKind::Must; + return LivenessKind::Maybe; + }; + auto CombineLivenessInfo = [&](const LivenessInfo *L1, + const LivenessInfo *L2) -> LivenessInfo { + assert((L1 || L2) && "unexpectedly merging 2 empty sets"); + if (!L1) + return LivenessInfo(L2->CausingUseFact, LivenessKind::Maybe); + if (!L2) + return LivenessInfo(L1->CausingUseFact, LivenessKind::Maybe); + return LivenessInfo( + CombineUseFact(*L1->CausingUseFact, *L2->CausingUseFact), + CombineLivenessKind(L1->Kind, L2->Kind)); + }; + return Lattice(utils::join( + L1.LiveOrigins, L2.LiveOrigins, Factory, CombineLivenessInfo, + // A symmetric join is required here. 
If an origin is live on one + // branch but not the other, its confidence must be demoted to `Maybe`. + utils::JoinKind::Symmetric)); + } + + /// A read operation makes the origin live with definite confidence, as it + /// dominates this program point. A write operation kills the liveness of + /// the origin since it overwrites the value. + Lattice transfer(Lattice In, const UseFact &UF) { + OriginID OID = UF.getUsedOrigin(FactMgr.getOriginMgr()); + // Write kills liveness. + if (UF.isWritten()) + return Lattice(Factory.remove(In.LiveOrigins, OID)); + // Read makes origin live with definite confidence (dominates this point). + return Lattice(Factory.add(In.LiveOrigins, OID, + LivenessInfo(&UF, LivenessKind::Must))); + } + + /// Issuing a new loan to an origin kills its liveness. + Lattice transfer(Lattice In, const IssueFact &IF) { + return Lattice(Factory.remove(In.LiveOrigins, IF.getOriginID())); + } + + /// An OriginFlow kills the liveness of the destination origin if `KillDest` + /// is true. Otherwise, the incoming liveness state is returned unchanged. + Lattice transfer(Lattice In, const OriginFlowFact &OF) { + if (!OF.getKillDest()) + return In; + return Lattice(Factory.remove(In.LiveOrigins, OF.getDestOriginID())); + } + + LivenessMap getLiveOriginsAt(ProgramPoint P) const { + return getState(P).LiveOrigins; + } + + // Dump liveness values on all test points in the program.
+ void dump(llvm::raw_ostream &OS, + llvm::StringMap<ProgramPoint> TestPoints) const { + OS << "==========================================\n"; + OS << getAnalysisName() << " results:\n"; + OS << "==========================================\n"; + for (const auto &Entry : TestPoints) { + OS << "TestPoint: " << Entry.getKey() << "\n"; + getState(Entry.getValue()).dump(OS, FactMgr.getOriginMgr()); + } + } + +private: + FactManager &FactMgr; + LivenessMap::Factory &Factory; +}; +} // namespace + +// PImpl wrapper implementation +class LiveOriginsAnalysis::Impl : public AnalysisImpl { + using AnalysisImpl::AnalysisImpl; +}; + +LiveOriginsAnalysis::LiveOriginsAnalysis(const CFG &C, AnalysisDeclContext &AC, + FactManager &F, + LivenessMap::Factory &SF) + : PImpl(std::make_unique<Impl>(C, AC, F, SF)) { + PImpl->run(); +} + +LiveOriginsAnalysis::~LiveOriginsAnalysis() = default; + +LivenessMap LiveOriginsAnalysis::getLiveOriginsAt(ProgramPoint P) const { + return PImpl->getLiveOriginsAt(P); +} + +void LiveOriginsAnalysis::dump(llvm::raw_ostream &OS, + llvm::StringMap<ProgramPoint> TestPoints) const { + PImpl->dump(OS, TestPoints); +} +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeSafety/LoanPropagation.cpp b/clang/lib/Analysis/LifetimeSafety/LoanPropagation.cpp new file mode 100644 index 0000000..387097e --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/LoanPropagation.cpp @@ -0,0 +1,138 @@ +//===- LoanPropagation.cpp - Loan Propagation Analysis ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "clang/Analysis/Analyses/LifetimeSafety/LoanPropagation.h" +#include "Dataflow.h" +#include <memory> + +namespace clang::lifetimes::internal { +namespace { +/// Represents the dataflow lattice for loan propagation. +/// +/// This lattice tracks which loans each origin may hold at a given program +/// point. The lattice has a finite height: An origin's loan set is bounded by +/// the total number of loans in the function. +/// TODO(opt): To reduce the lattice size, propagate origins of declarations, +/// not expressions, because expressions are not visible across blocks. +struct Lattice { + /// The map from an origin to the set of loans it contains. + OriginLoanMap Origins = OriginLoanMap(nullptr); + + explicit Lattice(const OriginLoanMap &S) : Origins(S) {} + Lattice() = default; + + bool operator==(const Lattice &Other) const { + return Origins == Other.Origins; + } + bool operator!=(const Lattice &Other) const { return !(*this == Other); } + + void dump(llvm::raw_ostream &OS) const { + OS << "LoanPropagationLattice State:\n"; + if (Origins.isEmpty()) + OS << " <empty>\n"; + for (const auto &Entry : Origins) { + if (Entry.second.isEmpty()) + OS << " Origin " << Entry.first << " contains no loans\n"; + for (const LoanID &LID : Entry.second) + OS << " Origin " << Entry.first << " contains Loan " << LID << "\n"; + } + } +}; + +class AnalysisImpl + : public DataflowAnalysis<AnalysisImpl, Lattice, Direction::Forward> { +public: + AnalysisImpl(const CFG &C, AnalysisDeclContext &AC, FactManager &F, + OriginLoanMap::Factory &OriginLoanMapFactory, + LoanSet::Factory &LoanSetFactory) + : DataflowAnalysis(C, AC, F), OriginLoanMapFactory(OriginLoanMapFactory), + LoanSetFactory(LoanSetFactory) {} + + using Base::transfer; + + StringRef getAnalysisName() const { return "LoanPropagation"; } + + Lattice getInitialState() {
return Lattice{}; } + + /// Merges two lattices by taking the union of loans for each origin. + // TODO(opt): Keep the state small by removing origins which become dead. + Lattice join(Lattice A, Lattice B) { + OriginLoanMap JoinedOrigins = utils::join( + A.Origins, B.Origins, OriginLoanMapFactory, + [&](const LoanSet *S1, const LoanSet *S2) { + assert((S1 || S2) && "unexpectedly merging 2 empty sets"); + if (!S1) + return *S2; + if (!S2) + return *S1; + return utils::join(*S1, *S2, LoanSetFactory); + }, + // Asymmetric join is a performance win. For origins present only on one + // branch, the loan set can be carried over as-is. + utils::JoinKind::Asymmetric); + return Lattice(JoinedOrigins); + } + + /// A new loan is issued to the origin. Old loans are erased. + Lattice transfer(Lattice In, const IssueFact &F) { + OriginID OID = F.getOriginID(); + LoanID LID = F.getLoanID(); + return Lattice(OriginLoanMapFactory.add( + In.Origins, OID, + LoanSetFactory.add(LoanSetFactory.getEmptySet(), LID))); + } + + /// A flow from source to destination. If `KillDest` is true, this replaces + /// the destination's loans with the source's. Otherwise, the source's loans + /// are merged into the destination's. + Lattice transfer(Lattice In, const OriginFlowFact &F) { + OriginID DestOID = F.getDestOriginID(); + OriginID SrcOID = F.getSrcOriginID(); + + LoanSet DestLoans = + F.getKillDest() ? 
LoanSetFactory.getEmptySet() : getLoans(In, DestOID); + LoanSet SrcLoans = getLoans(In, SrcOID); + LoanSet MergedLoans = utils::join(DestLoans, SrcLoans, LoanSetFactory); + + return Lattice(OriginLoanMapFactory.add(In.Origins, DestOID, MergedLoans)); + } + + LoanSet getLoans(OriginID OID, ProgramPoint P) const { + return getLoans(getState(P), OID); + } + +private: + LoanSet getLoans(Lattice L, OriginID OID) const { + if (auto *Loans = L.Origins.lookup(OID)) + return *Loans; + return LoanSetFactory.getEmptySet(); + } + + OriginLoanMap::Factory &OriginLoanMapFactory; + LoanSet::Factory &LoanSetFactory; +}; +} // namespace + +class LoanPropagationAnalysis::Impl final : public AnalysisImpl { + using AnalysisImpl::AnalysisImpl; +}; + +LoanPropagationAnalysis::LoanPropagationAnalysis( + const CFG &C, AnalysisDeclContext &AC, FactManager &F, + OriginLoanMap::Factory &OriginLoanMapFactory, + LoanSet::Factory &LoanSetFactory) + : PImpl(std::make_unique<Impl>(C, AC, F, OriginLoanMapFactory, + LoanSetFactory)) { + PImpl->run(); +} + +LoanPropagationAnalysis::~LoanPropagationAnalysis() = default; + +LoanSet LoanPropagationAnalysis::getLoans(OriginID OID, ProgramPoint P) const { + return PImpl->getLoans(OID, P); +} +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeSafety/Loans.cpp b/clang/lib/Analysis/LifetimeSafety/Loans.cpp new file mode 100644 index 0000000..2c85a3c --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/Loans.cpp @@ -0,0 +1,18 @@ +//===- Loans.cpp - Loan Implementation --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/Loans.h" + +namespace clang::lifetimes::internal { + +void Loan::dump(llvm::raw_ostream &OS) const { + OS << ID << " (Path: "; + OS << Path.D->getNameAsString() << ")"; +} + +} // namespace clang::lifetimes::internal diff --git a/clang/lib/Analysis/LifetimeSafety/Origins.cpp b/clang/lib/Analysis/LifetimeSafety/Origins.cpp new file mode 100644 index 0000000..ea51a75 --- /dev/null +++ b/clang/lib/Analysis/LifetimeSafety/Origins.cpp @@ -0,0 +1,89 @@ +//===- Origins.cpp - Origin Implementation -----------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Analysis/Analyses/LifetimeSafety/Origins.h" + +namespace clang::lifetimes::internal { + +void OriginManager::dump(OriginID OID, llvm::raw_ostream &OS) const { + OS << OID << " ("; + Origin O = getOrigin(OID); + if (const ValueDecl *VD = O.getDecl()) + OS << "Decl: " << VD->getNameAsString(); + else if (const Expr *E = O.getExpr()) + OS << "Expr: " << E->getStmtClassName(); + else + OS << "Unknown"; + OS << ")"; +} + +Origin &OriginManager::addOrigin(OriginID ID, const clang::ValueDecl &D) { + AllOrigins.emplace_back(ID, &D); + return AllOrigins.back(); +} + +Origin &OriginManager::addOrigin(OriginID ID, const clang::Expr &E) { + AllOrigins.emplace_back(ID, &E); + return AllOrigins.back(); +} + +// TODO: Mark this method as const once we remove the call to getOrCreate. 
+OriginID OriginManager::get(const Expr &E) { + auto It = ExprToOriginID.find(&E); + if (It != ExprToOriginID.end()) + return It->second; + // If the expression itself has no specific origin, and it's a reference + // to a declaration, its origin is that of the declaration it refers to. + // This is the case for pointer types, where we don't pre-emptively create + // an origin for the DeclRefExpr itself. + if (const auto *DRE = dyn_cast<DeclRefExpr>(&E)) + return get(*DRE->getDecl()); + // TODO: This should be an assert(It != ExprToOriginID.end()). The current + // implementation falls back to getOrCreate to avoid crashing on + // yet-unhandled pointer expressions, creating an empty origin for them. + return getOrCreate(E); +} + +OriginID OriginManager::get(const ValueDecl &D) { + auto It = DeclToOriginID.find(&D); + // TODO: This should be an assert(It != DeclToOriginID.end()). The current + // implementation falls back to getOrCreate to avoid crashing on + // yet-unhandled pointer expressions, creating an empty origin for them.
+ if (It == DeclToOriginID.end()) + return getOrCreate(D); + + return It->second; +} + +OriginID OriginManager::getOrCreate(const Expr &E) { + auto It = ExprToOriginID.find(&E); + if (It != ExprToOriginID.end()) + return It->second; + + OriginID NewID = getNextOriginID(); + addOrigin(NewID, E); + ExprToOriginID[&E] = NewID; + return NewID; +} + +const Origin &OriginManager::getOrigin(OriginID ID) const { + assert(ID.Value < AllOrigins.size()); + return AllOrigins[ID.Value]; +} + +OriginID OriginManager::getOrCreate(const ValueDecl &D) { + auto It = DeclToOriginID.find(&D); + if (It != DeclToOriginID.end()) + return It->second; + OriginID NewID = getNextOriginID(); + addOrigin(NewID, D); + DeclToOriginID[&D] = NewID; + return NewID; +} + +} // namespace clang::lifetimes::internal diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp index 2d95982..f8e8086 100644 --- a/clang/lib/CodeGen/BackendUtil.cpp +++ b/clang/lib/CodeGen/BackendUtil.cpp @@ -473,6 +473,7 @@ static bool initTargetOptions(const CompilerInstance &CI, Options.StackUsageOutput = CodeGenOpts.StackUsageOutput; Options.EmitAddrsig = CodeGenOpts.Addrsig; Options.ForceDwarfFrameSection = CodeGenOpts.ForceDwarfFrameSection; + Options.EmitCallGraphSection = CodeGenOpts.CallGraphSection; Options.EmitCallSiteInfo = CodeGenOpts.EmitCallSiteInfo; Options.EnableAIXExtendedAltivecABI = LangOpts.EnableAIXExtendedAltivecABI; Options.XRayFunctionIndex = CodeGenOpts.XRayFunctionIndex; diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp index d326a81..bf75573 100644 --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -6442,6 +6442,10 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA, CmdArgs.push_back(A->getValue()); } + if (Args.hasFlag(options::OPT_fexperimental_call_graph_section, + options::OPT_fno_experimental_call_graph_section, false)) + CmdArgs.push_back("-fexperimental-call-graph-section"); 
+ Args.addOptInFlag(CmdArgs, options::OPT_fstack_size_section, options::OPT_fno_stack_size_section); diff --git a/clang/lib/Driver/ToolChains/CommonArgs.cpp b/clang/lib/Driver/ToolChains/CommonArgs.cpp index 16cc1db..99400ac 100644 --- a/clang/lib/Driver/ToolChains/CommonArgs.cpp +++ b/clang/lib/Driver/ToolChains/CommonArgs.cpp @@ -1272,6 +1272,11 @@ void tools::addLTOOptions(const ToolChain &ToolChain, const ArgList &Args, CmdArgs.push_back( Args.MakeArgString(Twine(PluginOptPrefix) + "-stack-size-section")); + if (Args.hasFlag(options::OPT_fexperimental_call_graph_section, + options::OPT_fno_experimental_call_graph_section, false)) + CmdArgs.push_back( + Args.MakeArgString(Twine(PluginOptPrefix) + "-call-graph-section")); + // Setup statistics file output. SmallString<128> StatsFile = getStatsFileName(Args, Output, *Input, D); if (!StatsFile.empty()) diff --git a/clang/lib/Format/FormatTokenLexer.cpp b/clang/lib/Format/FormatTokenLexer.cpp index 86a5185..ab32938 100644 --- a/clang/lib/Format/FormatTokenLexer.cpp +++ b/clang/lib/Format/FormatTokenLexer.cpp @@ -93,12 +93,6 @@ ArrayRef<FormatToken *> FormatTokenLexer::lex() { auto &Tok = *Tokens.back(); const auto NewlinesBefore = Tok.NewlinesBefore; switch (FormatOff) { - case FO_CurrentLine: - if (NewlinesBefore == 0) - Tok.Finalized = true; - else - FormatOff = FO_None; - break; case FO_NextLine: if (NewlinesBefore > 1) { FormatOff = FO_None; @@ -107,6 +101,13 @@ ArrayRef<FormatToken *> FormatTokenLexer::lex() { FormatOff = FO_CurrentLine; } break; + case FO_CurrentLine: + if (NewlinesBefore == 0) { + Tok.Finalized = true; + break; + } + FormatOff = FO_None; + [[fallthrough]]; default: if (!FormattingDisabled && FormatOffRegex.match(Tok.TokenText)) { if (Tok.is(tok::comment) && diff --git a/clang/lib/Sema/AnalysisBasedWarnings.cpp b/clang/lib/Sema/AnalysisBasedWarnings.cpp index e9ca8ce..9abaf79 100644 --- a/clang/lib/Sema/AnalysisBasedWarnings.cpp +++ b/clang/lib/Sema/AnalysisBasedWarnings.cpp @@ -29,7 +29,7 @@ 
#include "clang/Analysis/Analyses/CFGReachabilityAnalysis.h" #include "clang/Analysis/Analyses/CalledOnceCheck.h" #include "clang/Analysis/Analyses/Consumed.h" -#include "clang/Analysis/Analyses/LifetimeSafety.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h" #include "clang/Analysis/Analyses/ReachableCode.h" #include "clang/Analysis/Analyses/ThreadSafety.h" #include "clang/Analysis/Analyses/UninitializedValues.h" diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 51e0ee1..0ebf56e 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -111,6 +111,7 @@ add_clang_library(clangSema clangAPINotes clangAST clangAnalysis + clangAnalysisLifetimeSafety clangBasic clangEdit clangLex diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp index e8a7ad3..8aebf53 100644 --- a/clang/lib/Sema/CheckExprLifetime.cpp +++ b/clang/lib/Sema/CheckExprLifetime.cpp @@ -10,7 +10,7 @@ #include "clang/AST/Decl.h" #include "clang/AST/Expr.h" #include "clang/AST/Type.h" -#include "clang/Analysis/Analyses/LifetimeAnnotations.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h" #include "clang/Basic/DiagnosticSema.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Sema.h" diff --git a/clang/lib/Sema/SemaAPINotes.cpp b/clang/lib/Sema/SemaAPINotes.cpp index 35cdfbf..0d8d0fa 100644 --- a/clang/lib/Sema/SemaAPINotes.cpp +++ b/clang/lib/Sema/SemaAPINotes.cpp @@ -17,7 +17,7 @@ #include "clang/AST/DeclCXX.h" #include "clang/AST/DeclObjC.h" #include "clang/AST/TypeLoc.h" -#include "clang/Analysis/Analyses/LifetimeAnnotations.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeAnnotations.h" #include "clang/Basic/SourceLocation.h" #include "clang/Lex/Lexer.h" #include "clang/Sema/SemaObjC.h" diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp index f4df63c..9cbd1bd 100644 --- a/clang/lib/Sema/SemaConcept.cpp +++ 
b/clang/lib/Sema/SemaConcept.cpp @@ -604,6 +604,10 @@ ConstraintSatisfactionChecker::SubstitutionInTemplateArguments( return std::nullopt; const NormalizedConstraint::OccurenceList &Used = Constraint.mappingOccurenceList(); + // The empty MLTAL situation should only occur when evaluating non-dependent + // constraints. + if (!MLTAL.getNumSubstitutedLevels()) + MLTAL.addOuterTemplateArguments(TD, {}, /*Final=*/false); SubstitutedOuterMost = llvm::to_vector_of<TemplateArgument>(MLTAL.getOutermost()); unsigned Offset = 0; @@ -623,9 +627,7 @@ ConstraintSatisfactionChecker::SubstitutionInTemplateArguments( if (Offset < SubstitutedOuterMost.size()) SubstitutedOuterMost.erase(SubstitutedOuterMost.begin() + Offset); - MLTAL.replaceOutermostTemplateArguments( - const_cast<NamedDecl *>(Constraint.getConstraintDecl()), - SubstitutedOuterMost); + MLTAL.replaceOutermostTemplateArguments(TD, SubstitutedOuterMost); return std::move(MLTAL); } @@ -956,11 +958,20 @@ ExprResult ConstraintSatisfactionChecker::Evaluate( ? Constraint.getPackSubstitutionIndex() : PackSubstitutionIndex; - Sema::InstantiatingTemplate _(S, ConceptId->getBeginLoc(), - Sema::InstantiatingTemplate::ConstraintsCheck{}, - ConceptId->getNamedConcept(), - MLTAL.getInnermost(), - Constraint.getSourceRange()); + Sema::InstantiatingTemplate InstTemplate( + S, ConceptId->getBeginLoc(), + Sema::InstantiatingTemplate::ConstraintsCheck{}, + ConceptId->getNamedConcept(), + // We may have empty template arguments when checking non-dependent + // nested constraint expressions. + // In such cases, non-SFINAE errors would have already been diagnosed + // during parameter mapping substitution, so the instantiating template + // arguments are less useful here. + MLTAL.getNumSubstitutedLevels() ? 
MLTAL.getInnermost() + : ArrayRef<TemplateArgument>{}, + Constraint.getSourceRange()); + if (InstTemplate.isInvalid()) + return ExprError(); unsigned Size = Satisfaction.Details.size(); diff --git a/clang/test/Driver/call-graph-section.c b/clang/test/Driver/call-graph-section.c new file mode 100644 index 0000000..00fa896 --- /dev/null +++ b/clang/test/Driver/call-graph-section.c @@ -0,0 +1,5 @@ +// RUN: %clang -### -fexperimental-call-graph-section %s 2>&1 | FileCheck --check-prefix=CALL-GRAPH-SECTION %s +// RUN: %clang -### -fexperimental-call-graph-section -fno-experimental-call-graph-section %s 2>&1 | FileCheck --check-prefix=NO-CALL-GRAPH-SECTION %s + +// CALL-GRAPH-SECTION: "-fexperimental-call-graph-section" +// NO-CALL-GRAPH-SECTION-NOT: "-fexperimental-call-graph-section" diff --git a/clang/test/SemaTemplate/concepts.cpp b/clang/test/SemaTemplate/concepts.cpp index 1dbb989..3fbe7c0 100644 --- a/clang/test/SemaTemplate/concepts.cpp +++ b/clang/test/SemaTemplate/concepts.cpp @@ -1404,6 +1404,18 @@ static_assert(!std::is_constructible_v<span<4>, array<int, 3>>); } +namespace case7 { + +template <class _Tp, class _Up> +concept __same_as_impl = __is_same(_Tp, _Up); +template <class _Tp, class _Up> +concept same_as = __same_as_impl<_Tp, _Up>; +template <typename> +concept IsEntitySpec = + requires { requires same_as<void, void>; }; + +} + } namespace GH162125 { diff --git a/clang/unittests/Analysis/CMakeLists.txt b/clang/unittests/Analysis/CMakeLists.txt index 52e7d28..e0acf43 100644 --- a/clang/unittests/Analysis/CMakeLists.txt +++ b/clang/unittests/Analysis/CMakeLists.txt @@ -11,6 +11,7 @@ add_clang_unittest(ClangAnalysisTests clangAST clangASTMatchers clangAnalysis + clangAnalysisLifetimeSafety clangBasic clangFrontend clangLex diff --git a/clang/unittests/Analysis/LifetimeSafetyTest.cpp b/clang/unittests/Analysis/LifetimeSafetyTest.cpp index 169b2d2..0c05184 100644 --- a/clang/unittests/Analysis/LifetimeSafetyTest.cpp +++ 
b/clang/unittests/Analysis/LifetimeSafetyTest.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "clang/Analysis/Analyses/LifetimeSafety.h" +#include "clang/Analysis/Analyses/LifetimeSafety/LifetimeSafety.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Testing/TestAST.h" @@ -63,7 +63,7 @@ public: Analysis = std::make_unique<LifetimeSafetyAnalysis>(*AnalysisCtx, nullptr); Analysis->run(); - AnnotationToPointMap = Analysis->getTestPoints(); + AnnotationToPointMap = Analysis->getFactManager().getTestPoints(); } LifetimeSafetyAnalysis &getAnalysis() { return *Analysis; } @@ -98,10 +98,11 @@ public: auto *VD = findDecl<ValueDecl>(VarName); if (!VD) return std::nullopt; - auto OID = Analysis.getOriginIDForDecl(VD); - if (!OID) - ADD_FAILURE() << "Origin for '" << VarName << "' not found."; - return OID; + // This assumes the OriginManager's `get` can find an existing origin. + // We might need a `find` method on OriginManager to avoid `getOrCreate` + // logic in a const-query context if that becomes an issue. 
+ return const_cast<OriginManager &>(Analysis.getFactManager().getOriginMgr()) + .get(*VD); } std::vector<LoanID> getLoansForVar(llvm::StringRef VarName) { @@ -110,7 +111,10 @@ public: ADD_FAILURE() << "Failed to find VarDecl for '" << VarName << "'"; return {}; } - std::vector<LoanID> LID = Analysis.getLoanIDForVar(VD); + std::vector<LoanID> LID; + for (const Loan &L : Analysis.getFactManager().getLoanMgr().getLoans()) + if (L.Path.D == VD) + LID.push_back(L.ID); if (LID.empty()) { ADD_FAILURE() << "Loan for '" << VarName << "' not found."; return {}; @@ -123,7 +127,7 @@ public: ProgramPoint PP = Runner.getProgramPoint(Annotation); if (!PP) return std::nullopt; - return Analysis.getLoansAtPoint(OID, PP); + return Analysis.getLoanPropagation().getLoans(OID, PP); } std::optional<std::vector<std::pair<OriginID, LivenessKind>>> @@ -131,7 +135,10 @@ public: ProgramPoint PP = Runner.getProgramPoint(Annotation); if (!PP) return std::nullopt; - return Analysis.getLiveOriginsAtPoint(PP); + std::vector<std::pair<OriginID, LivenessKind>> Result; + for (auto &[OID, Info] : Analysis.getLiveOrigins().getLiveOriginsAt(PP)) + Result.push_back({OID, Info.Kind}); + return Result; } private: diff --git a/clang/unittests/Format/FormatTest.cpp b/clang/unittests/Format/FormatTest.cpp index fef7036..450c34f 100644 --- a/clang/unittests/Format/FormatTest.cpp +++ b/clang/unittests/Format/FormatTest.cpp @@ -24843,6 +24843,11 @@ TEST_F(FormatTest, OneLineFormatOffRegex) { " } while (0 )", Style); + Style.OneLineFormatOffRegex = "MACRO_TEST"; + verifyNoChange(" MACRO_TEST1 ( ) ;\n" + " MACRO_TEST2( );", + Style); + Style.ColumnLimit = 50; Style.OneLineFormatOffRegex = "^LogErrorPrint$"; verifyFormat(" myproject::LogErrorPrint(logger, \"Don't split me!\");\n" diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index 695221c..0e3c9aa2 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ 
b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -208,6 +208,8 @@ struct IntrinsicLibrary { fir::ExtendedValue genAssociated(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); mlir::Value genAtand(mlir::Type, llvm::ArrayRef<mlir::Value>); + mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>); + mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>); void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>); fir::ExtendedValue genBesselJn(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); @@ -272,6 +274,7 @@ struct IntrinsicLibrary { llvm::ArrayRef<fir::ExtendedValue>); template <Extremum, ExtremumBehavior> mlir::Value genExtremum(mlir::Type, llvm::ArrayRef<mlir::Value>); + void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>); mlir::Value genFloor(mlir::Type, llvm::ArrayRef<mlir::Value>); mlir::Value genFraction(mlir::Type resultType, mlir::ArrayRef<mlir::Value> args); @@ -454,6 +457,8 @@ struct IntrinsicLibrary { mlir::Value genTand(mlir::Type, llvm::ArrayRef<mlir::Value>); mlir::Value genTanpi(mlir::Type, llvm::ArrayRef<mlir::Value>); mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>); + void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>); + void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>); mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>); fir::ExtendedValue genTransfer(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>); diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index bd94651..444f274 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -3383,7 +3383,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter, } } - switch (llvm::omp::Directive dir = item->id) { + llvm::omp::Directive dir = item->id; + switch (dir) { case llvm::omp::Directive::OMPD_barrier: newOp = genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item); break; diff --git 
a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 2c21868..7c5c5fb 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -346,6 +346,14 @@ static constexpr IntrinsicHandler handlers[]{ &I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>, {{{"mask", asValue}, {"pred", asValue}}}, /*isElemental=*/false}, + {"barrier_arrive", + &I::genBarrierArrive, + {{{"barrier", asAddr}}}, + /*isElemental=*/false}, + {"barrier_arrive_cnt", + &I::genBarrierArriveCnt, + {{{"barrier", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, {"barrier_init", &I::genBarrierInit, {{{"barrier", asAddr}, {"count", asValue}}}, @@ -494,6 +502,10 @@ static constexpr IntrinsicHandler handlers[]{ &I::genExtendsTypeOf, {{{"a", asBox}, {"mold", asBox}}}, /*isElemental=*/false}, + {"fence_proxy_async", + &I::genFenceProxyAsync, + {}, + /*isElemental=*/false}, {"findloc", &I::genFindloc, {{{"array", asBox}, @@ -1004,6 +1016,14 @@ static constexpr IntrinsicHandler handlers[]{ {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false}, {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false}, {"time", &I::genTime, {}, /*isElemental=*/false}, + {"tma_bulk_commit_group", + &I::genTMABulkCommitGroup, + {{}}, + /*isElemental=*/false}, + {"tma_bulk_wait_group", + &I::genTMABulkWaitGroup, + {{}}, + /*isElemental=*/false}, {"trailz", &I::genTrailz}, {"transfer", &I::genTransfer, @@ -3180,20 +3200,59 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType, return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox); } -// BARRIER_INIT (CUDA) -void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 2); - auto llvmPtr = fir::ConvertOp::create( +static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value barrier) { + mlir::Value llvmPtr = fir::ConvertOp::create( 
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), - fir::getBase(args[0])); - auto addrCast = mlir::LLVM::AddrSpaceCastOp::create( + barrier); + mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create( builder, loc, mlir::LLVM::LLVMPointerType::get( builder.getContext(), static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)), llvmPtr); - mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, addrCast, + return addrCast; +} + +// BARRIER_ARRIVE (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierArrive(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 1); + mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]); + return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType, + barrier) + .getResult(); +} + +// BARRIER_ARRIVE_CNT (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]); + mlir::Value token = fir::AllocaOp::create(builder, loc, resultType); + // TODO: the MBarrierArriveExpectTxOp is not taking the state argument and + // currently just the sink symbol `_`. 
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive + mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier, args[1], + {}); + return fir::LoadOp::create(builder, loc, token); +} + +// BARRIER_INIT (CUDA) +void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value barrier = + convertBarrierToLLVM(builder, loc, fir::getBase(args[0])); + mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier, fir::getBase(args[1]), {}); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); } // BESSEL_JN @@ -4312,6 +4371,17 @@ IntrinsicLibrary::genExtendsTypeOf(mlir::Type resultType, fir::getBase(args[1]))); } +// FENCE_PROXY_ASYNC (CUDA) +void IntrinsicLibrary::genFenceProxyAsync( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); +} + // FINDLOC fir::ExtendedValue IntrinsicLibrary::genFindloc(mlir::Type resultType, @@ -9127,6 +9197,21 @@ mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType, fir::runtime::genTime(builder, loc)); } +// TMA_BULK_COMMIT_GROUP (CUDA) +void IntrinsicLibrary::genTMABulkCommitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc); +} + +// TMA_BULK_WAIT_GROUP (CUDA) +void IntrinsicLibrary::genTMABulkWaitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + 
assert(args.size() == 0); + auto group = builder.getIntegerAttr(builder.getI32Type(), 0); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {}); +} + // TRIM fir::ExtendedValue IntrinsicLibrary::genTrim(mlir::Type resultType, diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index 4f552dcf..106f3e2 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -1987,13 +1987,42 @@ implicit none end function end interface + ! TMA Operations + interface attributes(device) subroutine barrier_init(barrier, count) - integer(8) :: barrier + integer(8), shared :: barrier integer(4) :: count end subroutine end interface + interface barrier_arrive + attributes(device) function barrier_arrive(barrier) result(token) + integer(8), shared :: barrier + integer(8) :: token + end function + attributes(device) function barrier_arrive_cnt(barrier, count) result(token) + integer(8), shared :: barrier + integer(4) :: count + integer(8) :: token + end function + end interface + + interface + attributes(device) subroutine fence_proxy_async() + end subroutine + end interface + + interface + attributes(device) subroutine tma_bulk_commit_group() + end subroutine + end interface + + interface + attributes(device) subroutine tma_bulk_wait_group() + end subroutine + end interface + contains attributes(device) subroutine syncthreads() diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index cdb337b..697b17b 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -394,9 +394,14 @@ end subroutine attributes(global) subroutine test_barrier() integer(8), shared :: barrier + integer(8) :: token + integer :: count call barrier_init(barrier, 256) -end subroutine + token = barrier_arrive(barrier) + + token = barrier_arrive(barrier, count) +end subroutine ! CHECK-LABEL: func.func @_QPtest_barrier() @@ -406,3 +411,29 @@ end subroutine ! 
CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr ! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3> ! CHECK: nvvm.mbarrier.init.shared %[[SHARED_PTR]], %[[COUNT]] : !llvm.ptr<3>, i32 +! CHECK: nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} + +! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr +! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3> +! CHECK: %{{.*}} = nvvm.mbarrier.arrive.shared %[[SHARED_PTR]] : !llvm.ptr<3> -> i64 + +! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr +! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3> +! CHECK: nvvm.mbarrier.arrive.expect_tx %[[SHARED_PTR]], %{{.*}} : !llvm.ptr<3>, i32 + + +attributes(global) subroutine test_fence() + call fence_proxy_async() +end subroutine + +! CHECK-LABEL: func.func @_QPtest_fence() +! CHECK: nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} + +attributes(global) subroutine test_tma() + call tma_bulk_commit_group() + call tma_bulk_wait_group() +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma() +! CHECK: nvvm.cp.async.bulk.commit.group +! 
CHECK: nvvm.cp.async.bulk.wait_group 0 diff --git a/libc/shared/math.h b/libc/shared/math.h index 82b9250..e3f7965 100644 --- a/libc/shared/math.h +++ b/libc/shared/math.h @@ -49,6 +49,7 @@ #include "math/exp10m1f16.h" #include "math/exp2.h" #include "math/exp2f.h" +#include "math/exp2f16.h" #include "math/expf.h" #include "math/expf16.h" #include "math/frexpf.h" diff --git a/libc/shared/math/exp2f16.h b/libc/shared/math/exp2f16.h new file mode 100644 index 0000000..f799511 --- /dev/null +++ b/libc/shared/math/exp2f16.h @@ -0,0 +1,29 @@ +//===-- Shared exp2f16 function ---------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SHARED_MATH_EXP2F16_H +#define LLVM_LIBC_SHARED_MATH_EXP2F16_H + +#include "include/llvm-libc-macros/float16-macros.h" +#include "shared/libc_common.h" + +#ifdef LIBC_TYPES_HAS_FLOAT16 + +#include "src/__support/math/exp2f16.h" + +namespace LIBC_NAMESPACE_DECL { +namespace shared { + +using math::exp2f16; + +} // namespace shared +} // namespace LIBC_NAMESPACE_DECL + +#endif // LIBC_TYPES_HAS_FLOAT16 + +#endif // LLVM_LIBC_SHARED_MATH_EXP2F16_H diff --git a/libc/src/__support/math/CMakeLists.txt b/libc/src/__support/math/CMakeLists.txt index 61253de..9685496 100644 --- a/libc/src/__support/math/CMakeLists.txt +++ b/libc/src/__support/math/CMakeLists.txt @@ -738,6 +738,20 @@ add_header_library( ) add_header_library( + exp2f16 + HDRS + exp2f16.h + DEPENDS + .expxf16_utils + libc.src.__support.FPUtil.cast + libc.src.__support.FPUtil.except_value_utils + libc.src.__support.FPUtil.fenv_impl + libc.src.__support.FPUtil.fp_bits + libc.src.__support.FPUtil.rounding_mode + libc.src.__support.macros.optimization +) + +add_header_library( exp10 HDRS 
exp10.h diff --git a/libc/src/__support/math/exp2f16.h b/libc/src/__support/math/exp2f16.h new file mode 100644 index 0000000..599ba0f --- /dev/null +++ b/libc/src/__support/math/exp2f16.h @@ -0,0 +1,111 @@ +//===-- Implementation header for exp2f16 -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_EXP2F16_H +#define LLVM_LIBC_SRC___SUPPORT_MATH_EXP2F16_H + +#include "include/llvm-libc-macros/float16-macros.h" + +#ifdef LIBC_TYPES_HAS_FLOAT16 + +#include "expxf16_utils.h" +#include "src/__support/FPUtil/FEnvImpl.h" +#include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/cast.h" +#include "src/__support/FPUtil/except_value_utils.h" +#include "src/__support/FPUtil/rounding_mode.h" +#include "src/__support/common.h" +#include "src/__support/macros/config.h" +#include "src/__support/macros/optimization.h" + +namespace LIBC_NAMESPACE_DECL { + +namespace math { + +LIBC_INLINE static constexpr float16 exp2f16(float16 x) { + +#ifndef LIBC_MATH_HAS_SKIP_ACCURATE_PASS + constexpr fputil::ExceptValues<float16, 3> EXP2F16_EXCEPTS = {{ + // (input, RZ output, RU offset, RD offset, RN offset) + // x = 0x1.714p-11, exp2f16(x) = 0x1p+0 (RZ) + {0x11c5U, 0x3c00U, 1U, 0U, 1U}, + // x = -0x1.558p-4, exp2f16(x) = 0x1.e34p-1 (RZ) + {0xad56U, 0x3b8dU, 1U, 0U, 0U}, + // x = -0x1.d5cp-4, exp2f16(x) = 0x1.d8cp-1 (RZ) + {0xaf57U, 0x3b63U, 1U, 0U, 0U}, + }}; +#endif // !LIBC_MATH_HAS_SKIP_ACCURATE_PASS + + using namespace math::expxf16_internal; + using FPBits = fputil::FPBits<float16>; + FPBits x_bits(x); + + uint16_t x_u = x_bits.uintval(); + uint16_t x_abs = x_u & 0x7fffU; + + // When |x| >= 16, or x is NaN. 
+ if (LIBC_UNLIKELY(x_abs >= 0x4c00U)) { + // exp2(NaN) = NaN + if (x_bits.is_nan()) { + if (x_bits.is_signaling_nan()) { + fputil::raise_except_if_required(FE_INVALID); + return FPBits::quiet_nan().get_val(); + } + + return x; + } + + // When x >= 16. + if (x_bits.is_pos()) { + // exp2(+inf) = +inf + if (x_bits.is_inf()) + return FPBits::inf().get_val(); + + switch (fputil::quick_get_round()) { + case FE_TONEAREST: + case FE_UPWARD: + fputil::set_errno_if_required(ERANGE); + fputil::raise_except_if_required(FE_OVERFLOW); + return FPBits::inf().get_val(); + default: + return FPBits::max_normal().get_val(); + } + } + + // When x <= -25. + if (x_u >= 0xce40U) { + // exp2(-inf) = +0 + if (x_bits.is_inf()) + return FPBits::zero().get_val(); + + fputil::set_errno_if_required(ERANGE); + fputil::raise_except_if_required(FE_UNDERFLOW | FE_INEXACT); + + if (fputil::fenv_is_round_up()) + return FPBits::min_subnormal().get_val(); + return FPBits::zero().get_val(); + } + } + +#ifndef LIBC_MATH_HAS_SKIP_ACCURATE_PASS + if (auto r = EXP2F16_EXCEPTS.lookup(x_u); LIBC_UNLIKELY(r.has_value())) + return r.value(); +#endif // !LIBC_MATH_HAS_SKIP_ACCURATE_PASS + + // exp2(x) = exp2(hi + mid) * exp2(lo) + auto [exp2_hi_mid, exp2_lo] = exp2_range_reduction(x); + return fputil::cast<float16>(exp2_hi_mid * exp2_lo); +} + +} // namespace math + +} // namespace LIBC_NAMESPACE_DECL + +#endif // LIBC_TYPES_HAS_FLOAT16 + +#endif // LLVM_LIBC_SRC___SUPPORT_MATH_EXP2F16_H diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt index 55f4aaf..0754b5e 100644 --- a/libc/src/math/generic/CMakeLists.txt +++ b/libc/src/math/generic/CMakeLists.txt @@ -1478,15 +1478,7 @@ add_entrypoint_object( HDRS ../exp2f16.h DEPENDS - libc.hdr.errno_macros - libc.hdr.fenv_macros - libc.src.__support.FPUtil.cast - libc.src.__support.FPUtil.except_value_utils - libc.src.__support.FPUtil.fenv_impl - libc.src.__support.FPUtil.fp_bits - libc.src.__support.FPUtil.rounding_mode - 
libc.src.__support.macros.optimization - libc.src.__support.math.expxf16_utils + libc.src.__support.math.exp2f16 ) add_entrypoint_object( diff --git a/libc/src/math/generic/exp2f16.cpp b/libc/src/math/generic/exp2f16.cpp index 5db0c3a..80799d4 100644 --- a/libc/src/math/generic/exp2f16.cpp +++ b/libc/src/math/generic/exp2f16.cpp @@ -7,92 +7,10 @@ //===----------------------------------------------------------------------===// #include "src/math/exp2f16.h" -#include "hdr/errno_macros.h" -#include "hdr/fenv_macros.h" -#include "src/__support/FPUtil/FEnvImpl.h" -#include "src/__support/FPUtil/FPBits.h" -#include "src/__support/FPUtil/cast.h" -#include "src/__support/FPUtil/except_value_utils.h" -#include "src/__support/FPUtil/rounding_mode.h" -#include "src/__support/common.h" -#include "src/__support/macros/config.h" -#include "src/__support/macros/optimization.h" -#include "src/__support/math/expxf16_utils.h" +#include "src/__support/math/exp2f16.h" namespace LIBC_NAMESPACE_DECL { -#ifndef LIBC_MATH_HAS_SKIP_ACCURATE_PASS -static constexpr fputil::ExceptValues<float16, 3> EXP2F16_EXCEPTS = {{ - // (input, RZ output, RU offset, RD offset, RN offset) - // x = 0x1.714p-11, exp2f16(x) = 0x1p+0 (RZ) - {0x11c5U, 0x3c00U, 1U, 0U, 1U}, - // x = -0x1.558p-4, exp2f16(x) = 0x1.e34p-1 (RZ) - {0xad56U, 0x3b8dU, 1U, 0U, 0U}, - // x = -0x1.d5cp-4, exp2f16(x) = 0x1.d8cp-1 (RZ) - {0xaf57U, 0x3b63U, 1U, 0U, 0U}, -}}; -#endif // !LIBC_MATH_HAS_SKIP_ACCURATE_PASS - -LLVM_LIBC_FUNCTION(float16, exp2f16, (float16 x)) { - using namespace math::expxf16_internal; - using FPBits = fputil::FPBits<float16>; - FPBits x_bits(x); - - uint16_t x_u = x_bits.uintval(); - uint16_t x_abs = x_u & 0x7fffU; - - // When |x| >= 16, or x is NaN. - if (LIBC_UNLIKELY(x_abs >= 0x4c00U)) { - // exp2(NaN) = NaN - if (x_bits.is_nan()) { - if (x_bits.is_signaling_nan()) { - fputil::raise_except_if_required(FE_INVALID); - return FPBits::quiet_nan().get_val(); - } - - return x; - } - - // When x >= 16. 
- if (x_bits.is_pos()) { - // exp2(+inf) = +inf - if (x_bits.is_inf()) - return FPBits::inf().get_val(); - - switch (fputil::quick_get_round()) { - case FE_TONEAREST: - case FE_UPWARD: - fputil::set_errno_if_required(ERANGE); - fputil::raise_except_if_required(FE_OVERFLOW); - return FPBits::inf().get_val(); - default: - return FPBits::max_normal().get_val(); - } - } - - // When x <= -25. - if (x_u >= 0xce40U) { - // exp2(-inf) = +0 - if (x_bits.is_inf()) - return FPBits::zero().get_val(); - - fputil::set_errno_if_required(ERANGE); - fputil::raise_except_if_required(FE_UNDERFLOW | FE_INEXACT); - - if (fputil::fenv_is_round_up()) - return FPBits::min_subnormal().get_val(); - return FPBits::zero().get_val(); - } - } - -#ifndef LIBC_MATH_HAS_SKIP_ACCURATE_PASS - if (auto r = EXP2F16_EXCEPTS.lookup(x_u); LIBC_UNLIKELY(r.has_value())) - return r.value(); -#endif // !LIBC_MATH_HAS_SKIP_ACCURATE_PASS - - // exp2(x) = exp2(hi + mid) * exp2(lo) - auto [exp2_hi_mid, exp2_lo] = exp2_range_reduction(x); - return fputil::cast<float16>(exp2_hi_mid * exp2_lo); -} +LLVM_LIBC_FUNCTION(float16, exp2f16, (float16 x)) { return math::exp2f16(x); } } // namespace LIBC_NAMESPACE_DECL diff --git a/libc/test/shared/CMakeLists.txt b/libc/test/shared/CMakeLists.txt index f341d3f..8d81199 100644 --- a/libc/test/shared/CMakeLists.txt +++ b/libc/test/shared/CMakeLists.txt @@ -42,6 +42,7 @@ add_fp_unittest( libc.src.__support.math.exp libc.src.__support.math.exp2 libc.src.__support.math.exp2f + libc.src.__support.math.exp2f16 libc.src.__support.math.exp10 libc.src.__support.math.exp10f libc.src.__support.math.exp10f16 diff --git a/libc/test/shared/shared_math_test.cpp b/libc/test/shared/shared_math_test.cpp index 477b7ec..84787d5 100644 --- a/libc/test/shared/shared_math_test.cpp +++ b/libc/test/shared/shared_math_test.cpp @@ -28,7 +28,7 @@ TEST(LlvmLibcSharedMathTest, AllFloat16) { EXPECT_FP_EQ(0x1p+0f16, LIBC_NAMESPACE::shared::cospif16(0.0f16)); EXPECT_FP_EQ(0x1p+0f16, 
LIBC_NAMESPACE::shared::exp10f16(0.0f16)); EXPECT_FP_EQ(0x0p+0f16, LIBC_NAMESPACE::shared::exp10m1f16(0.0f16)); - + EXPECT_FP_EQ(0x1p+0f16, LIBC_NAMESPACE::shared::exp2f16(0.0f16)); EXPECT_FP_EQ(0x1p+0f16, LIBC_NAMESPACE::shared::expf16(0.0f16)); ASSERT_FP_EQ(float16(8 << 5), LIBC_NAMESPACE::shared::ldexpf16(8.0f16, 5)); diff --git a/lldb/test/API/commands/expression/diagnostics/TestExprDiagnostics.py b/lldb/test/API/commands/expression/diagnostics/TestExprDiagnostics.py index ec208f2..759b620 100644 --- a/lldb/test/API/commands/expression/diagnostics/TestExprDiagnostics.py +++ b/lldb/test/API/commands/expression/diagnostics/TestExprDiagnostics.py @@ -218,11 +218,9 @@ note: candidate function not viable: requires single argument 'x', but 2 argumen # Detail 1/3: note: requested expression language diag = details.GetItemAtIndex(0) self.assertEqual(str(diag.GetValueForKey("severity")), "note") - self.assertEqual( - str(diag.GetValueForKey("message")), "Ran expression as 'C++11'." - ) - self.assertEqual( - str(diag.GetValueForKey("rendered")), "Ran expression as 'C++11'." + self.assertIn("Ran expression as 'C++", str(diag.GetValueForKey("message"))) + self.assertIn( + "Ran expression as 'C++", str(diag.GetValueForKey("rendered")) ) self.assertEqual(str(diag.GetValueForKey("source_location")), "") self.assertEqual(str(diag.GetValueForKey("file")), "") diff --git a/lldb/test/Shell/Expr/TestExprLanguageNote.test b/lldb/test/Shell/Expr/TestExprLanguageNote.test index f3dc592..b4387bf 100644 --- a/lldb/test/Shell/Expr/TestExprLanguageNote.test +++ b/lldb/test/Shell/Expr/TestExprLanguageNote.test @@ -26,7 +26,7 @@ run expr blah # CHECK-TARGET: (lldb) expr -# CHECK-TARGET: note: Ran expression as 'C++14'. 
+# CHECK-TARGET: note: Ran expression as 'C++{{.*}}' expr -l objc -- blah diff --git a/llvm/docs/QualGroup.rst b/llvm/docs/QualGroup.rst index b45f569..5c05e4e 100644 --- a/llvm/docs/QualGroup.rst +++ b/llvm/docs/QualGroup.rst @@ -75,6 +75,16 @@ They meet the criteria for inclusion below. Knowing their handles help us keep t - capitan-davide - capitan_davide - capitan-davide + * - Jorge Pinto Sousa + - Critical Techworks + - sousajo-cc + - sousajo-cc + - sousajo-cc + * - José Rui Simões + - Critical Software + - jr-simoes + - jr_simoes + - iznogoud-zz * - Oscar Slotosch - Validas - slotosch @@ -100,6 +110,11 @@ They meet the criteria for inclusion below. Knowing their handles help us keep t - YoungJunLee - YoungJunLee - IamYJLee + * - Zaky Hermawan + - No Affiliation + - ZakyHermawan + - quarkz99 + - zakyHermawan Organizations are limited to three representatives within the group to maintain diversity. diff --git a/llvm/include/llvm/ADT/ImmutableSet.h b/llvm/include/llvm/ADT/ImmutableSet.h index 017585a4..310539f 100644 --- a/llvm/include/llvm/ADT/ImmutableSet.h +++ b/llvm/include/llvm/ADT/ImmutableSet.h @@ -21,7 +21,9 @@ #include "llvm/ADT/iterator.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Signals.h" #include <cassert> #include <cstdint> #include <functional> diff --git a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h index ed6ea96..2dd5abe 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h @@ -14,6 +14,7 @@ #define LLVM_EXECUTIONENGINE_ORC_SYMBOLSTRINGPOOL_H #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Compiler.h" #include <atomic> @@ -71,6 +72,7 @@ private: /// from nullptr to enable comparison with these values. 
class SymbolStringPtrBase { friend class SymbolStringPool; + friend class SymbolStringPoolEntryUnsafe; friend struct DenseMapInfo<SymbolStringPtr>; friend struct DenseMapInfo<NonOwningSymbolStringPtr>; @@ -204,7 +206,7 @@ public: SymbolStringPoolEntryUnsafe(PoolEntry *E) : E(E) {} /// Create an unsafe pool entry ref without changing the ref-count. - static SymbolStringPoolEntryUnsafe from(const SymbolStringPtr &S) { + static SymbolStringPoolEntryUnsafe from(const SymbolStringPtrBase &S) { return S.S; } @@ -318,6 +320,10 @@ SymbolStringPool::getRefCount(const SymbolStringPtrBase &S) const { LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPtrBase &Sym); +inline hash_code hash_value(const orc::SymbolStringPtrBase &S) { + return hash_value(orc::SymbolStringPoolEntryUnsafe::from(S).rawPtr()); +} + } // end namespace orc template <> diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp index aa078f3..e40fb76 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp @@ -704,9 +704,17 @@ void DwarfUnit::addType(DIE &Entity, const DIType *Ty, addDIEEntry(Entity, Attribute, DIEEntry(*getOrCreateTypeDIE(Ty))); } +// FIXME: change callsites to use the new DW_LNAME_ language codes. 
llvm::dwarf::SourceLanguage DwarfUnit::getSourceLanguage() const { - return static_cast<llvm::dwarf::SourceLanguage>( - getLanguage().getUnversionedName()); + const auto &Lang = getLanguage(); + + if (!Lang.hasVersionedName()) + return static_cast<llvm::dwarf::SourceLanguage>(Lang.getName()); + + return llvm::dwarf::toDW_LANG( + static_cast<llvm::dwarf::SourceLanguageName>(Lang.getName()), + Lang.getVersion()) + .value_or(llvm::dwarf::DW_LANG_hi_user); } std::string DwarfUnit::getParentContextString(const DIScope *Context) const { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index b9e01c3..66717b9 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -1367,9 +1367,8 @@ def : InstAlias<".insn_s $opcode, $funct3, $rs2, (${rs1})", class PatGpr<SDPatternOperator OpNode, RVInst Inst, ValueType vt = XLenVT> : Pat<(vt (OpNode (vt GPR:$rs1))), (Inst GPR:$rs1)>; -class PatGprGpr<SDPatternOperator OpNode, RVInst Inst, ValueType vt1 = XLenVT, - ValueType vt2 = XLenVT> - : Pat<(vt1 (OpNode (vt1 GPR:$rs1), (vt2 GPR:$rs2))), (Inst GPR:$rs1, GPR:$rs2)>; +class PatGprGpr<SDPatternOperator OpNode, RVInst Inst, ValueType vt = XLenVT> + : Pat<(vt (OpNode (vt GPR:$rs1), (vt GPR:$rs2))), (Inst GPR:$rs1, GPR:$rs2)>; class PatGprImm<SDPatternOperator OpNode, RVInst Inst, ImmLeaf ImmType, ValueType vt = XLenVT> @@ -1973,8 +1972,9 @@ def PseudoZEXT_W : Pseudo<(outs GPR:$rd), (ins GPR:$rs), [], "zext.w", "$rd, $rs /// Loads -class LdPat<PatFrag LoadOp, RVInst Inst, ValueType vt = XLenVT> - : Pat<(vt (LoadOp (AddrRegImm (XLenVT GPRMem:$rs1), simm12_lo:$imm12))), +class LdPat<PatFrag LoadOp, RVInst Inst, ValueType vt = XLenVT, + ValueType PtrVT = XLenVT> + : Pat<(vt (LoadOp (AddrRegImm (PtrVT GPRMem:$rs1), simm12_lo:$imm12))), (Inst GPRMem:$rs1, simm12_lo:$imm12)>; def : LdPat<sextloadi8, LB>; @@ -1988,8 +1988,8 @@ def : LdPat<zextloadi16, LHU>; /// Stores class StPat<PatFrag StoreOp, RVInst 
Inst, RegisterClass StTy, - ValueType vt> - : Pat<(StoreOp (vt StTy:$rs2), (AddrRegImm (XLenVT GPRMem:$rs1), + ValueType vt, ValueType PtrVT = XLenVT> + : Pat<(StoreOp (vt StTy:$rs2), (AddrRegImm (PtrVT GPRMem:$rs1), simm12_lo:$imm12)), (Inst StTy:$rs2, GPRMem:$rs1, simm12_lo:$imm12)>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index 65e7e3b..afac37d 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -544,8 +544,8 @@ def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12_l } // Predicates = [HasStdExtZdinx, IsRV32] let Predicates = [HasStdExtZdinx, HasStdExtZilsd, IsRV32] in { -def : LdPat<load, LD_RV32, f64>; -def : StPat<store, SD_RV32, GPRPair, f64>; +def : LdPat<load, LD_RV32, f64, i32>; +def : StPat<store, SD_RV32, GPRPair, f64, i32>; } let Predicates = [HasStdExtD, IsRV32] in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXCV.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXCV.td index d8f5d3e..aa8f1a1 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXCV.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXCV.td @@ -669,19 +669,19 @@ let Predicates = [HasVendorXCValu, IsRV32] in { // Patterns for load & store operations //===----------------------------------------------------------------------===// class CVLdrrPat<PatFrag LoadOp, RVInst Inst> - : Pat<(XLenVT (LoadOp CVrr:$regreg)), + : Pat<(i32 (LoadOp CVrr:$regreg)), (Inst CVrr:$regreg)>; class CVStriPat<PatFrag StoreOp, RVInst Inst> - : Pat<(StoreOp (XLenVT GPR:$rs2), GPR:$rs1, simm12_lo:$imm12), + : Pat<(StoreOp (i32 GPR:$rs2), GPR:$rs1, simm12_lo:$imm12), (Inst GPR:$rs2, GPR:$rs1, simm12_lo:$imm12)>; class CVStrriPat<PatFrag StoreOp, RVInst Inst> - : Pat<(StoreOp (XLenVT GPR:$rs2), GPR:$rs1, GPR:$rs3), + : Pat<(StoreOp (i32 GPR:$rs2), GPR:$rs1, GPR:$rs3), (Inst GPR:$rs2, GPR:$rs1, GPR:$rs3)>; class CVStrrPat<PatFrag StoreOp, RVInst Inst> - : Pat<(StoreOp (XLenVT GPR:$rs2), 
CVrr:$regreg), + : Pat<(StoreOp (i32 GPR:$rs2), CVrr:$regreg), (Inst GPR:$rs2, CVrr:$regreg)>; let Predicates = [HasVendorXCVmem, IsRV32], AddedComplexity = 1 in { @@ -725,17 +725,17 @@ let Predicates = [HasVendorXCVbitmanip, IsRV32] in { (CV_INSERT GPR:$rd, GPR:$rs1, (CV_HI5 cv_uimm10:$imm), (CV_LO5 cv_uimm10:$imm))>; - def : PatGpr<cttz, CV_FF1>; - def : PatGpr<ctlz, CV_FL1>; + def : PatGpr<cttz, CV_FF1, i32>; + def : PatGpr<ctlz, CV_FL1, i32>; def : PatGpr<int_riscv_cv_bitmanip_clb, CV_CLB>; - def : PatGpr<ctpop, CV_CNT>; + def : PatGpr<ctpop, CV_CNT, i32>; - def : PatGprGpr<rotr, CV_ROR>; + def : PatGprGpr<rotr, CV_ROR, i32>; def : Pat<(int_riscv_cv_bitmanip_bitrev GPR:$rs1, cv_tuimm5:$pts, cv_tuimm2:$radix), (CV_BITREV GPR:$rs1, cv_tuimm2:$radix, cv_tuimm5:$pts)>; - def : Pat<(bitreverse (XLenVT GPR:$rs)), (CV_BITREV GPR:$rs, 0, 0)>; + def : Pat<(bitreverse (i32 GPR:$rs)), (CV_BITREV GPR:$rs, 0, 0)>; } class PatCoreVAluGpr<string intr, string asm> : @@ -760,18 +760,18 @@ multiclass PatCoreVAluGprGprImm<Intrinsic intr> { } let Predicates = [HasVendorXCValu, IsRV32], AddedComplexity = 1 in { - def : PatGpr<abs, CV_ABS>; - def : PatGprGpr<setle, CV_SLE>; - def : PatGprGpr<setule, CV_SLEU>; - def : PatGprGpr<smin, CV_MIN>; - def : PatGprGpr<umin, CV_MINU>; - def : PatGprGpr<smax, CV_MAX>; - def : PatGprGpr<umax, CV_MAXU>; - - def : Pat<(sext_inreg (XLenVT GPR:$rs1), i16), (CV_EXTHS GPR:$rs1)>; - def : Pat<(sext_inreg (XLenVT GPR:$rs1), i8), (CV_EXTBS GPR:$rs1)>; - def : Pat<(and (XLenVT GPR:$rs1), 0xffff), (CV_EXTHZ GPR:$rs1)>; - def : Pat<(and (XLenVT GPR:$rs1), 0xff), (CV_EXTBZ GPR:$rs1)>; + def : PatGpr<abs, CV_ABS, i32>; + def : PatGprGpr<setle, CV_SLE, i32>; + def : PatGprGpr<setule, CV_SLEU, i32>; + def : PatGprGpr<smin, CV_MIN, i32>; + def : PatGprGpr<umin, CV_MINU, i32>; + def : PatGprGpr<smax, CV_MAX, i32>; + def : PatGprGpr<umax, CV_MAXU, i32>; + + def : Pat<(sext_inreg (i32 GPR:$rs1), i16), (CV_EXTHS GPR:$rs1)>; + def : Pat<(sext_inreg (i32 GPR:$rs1), 
i8), (CV_EXTBS GPR:$rs1)>; + def : Pat<(and (i32 GPR:$rs1), 0xffff), (CV_EXTHZ GPR:$rs1)>; + def : Pat<(and (i32 GPR:$rs1), 0xff), (CV_EXTBZ GPR:$rs1)>; defm CLIP : PatCoreVAluGprImm<int_riscv_cv_alu_clip>; defm CLIPU : PatCoreVAluGprImm<int_riscv_cv_alu_clipu>; @@ -790,9 +790,9 @@ let Predicates = [HasVendorXCValu, IsRV32], AddedComplexity = 1 in { //===----------------------------------------------------------------------===// let Predicates = [HasVendorXCVbi, IsRV32], AddedComplexity = 2 in { - def : Pat<(riscv_brcc GPR:$rs1, simm5:$imm5, SETEQ, bb:$imm12), + def : Pat<(riscv_brcc (i32 GPR:$rs1), simm5:$imm5, SETEQ, bb:$imm12), (CV_BEQIMM GPR:$rs1, simm5:$imm5, bare_simm13_lsb0_bb:$imm12)>; - def : Pat<(riscv_brcc GPR:$rs1, simm5:$imm5, SETNE, bb:$imm12), + def : Pat<(riscv_brcc (i32 GPR:$rs1), simm5:$imm5, SETNE, bb:$imm12), (CV_BNEIMM GPR:$rs1, simm5:$imm5, bare_simm13_lsb0_bb:$imm12)>; defm CC_SImm5_CV : SelectCC_GPR_riirr<GPR, simm5>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index 5e1d07a..4537bfe 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -1648,10 +1648,10 @@ def : Pat<(qc_setwmi (i32 GPR:$rs3), GPR:$rs1, tuimm5nonzero:$uimm5, tuimm7_lsb0 } // Predicates = [HasVendorXqcilsm, IsRV32] let Predicates = [HasVendorXqcili, IsRV32] in { -def: Pat<(qc_e_li tglobaladdr:$A), (QC_E_LI bare_simm32:$A)>; -def: Pat<(qc_e_li tblockaddress:$A), (QC_E_LI bare_simm32:$A)>; -def: Pat<(qc_e_li tjumptable:$A), (QC_E_LI bare_simm32:$A)>; -def: Pat<(qc_e_li tconstpool:$A), (QC_E_LI bare_simm32:$A)>; +def: Pat<(i32 (qc_e_li tglobaladdr:$A)), (QC_E_LI bare_simm32:$A)>; +def: Pat<(i32 (qc_e_li tblockaddress:$A)), (QC_E_LI bare_simm32:$A)>; +def: Pat<(i32 (qc_e_li tjumptable:$A)), (QC_E_LI bare_simm32:$A)>; +def: Pat<(i32 (qc_e_li tconstpool:$A)), (QC_E_LI bare_simm32:$A)>; } // Predicates = [HasVendorXqcili, IsRV32] 
//===----------------------------------------------------------------------===/i diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td index 52a2b29..c31713e 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td @@ -607,13 +607,16 @@ def : Pat<(fcopysign FPR64:$rs1, (f16 FPR16:$rs2)), (FSGNJ_D $rs1, (FCVT_D_H $rs let Predicates = [HasStdExtZhinxmin, HasStdExtZdinx, IsRV32] in { /// Float conversion operations // f64 -> f16, f16 -> f64 -def : Pat<(any_fpround FPR64IN32X:$rs1), (FCVT_H_D_IN32X FPR64IN32X:$rs1, FRM_DYN)>; -def : Pat<(any_fpextend FPR16INX:$rs1), (FCVT_D_H_IN32X FPR16INX:$rs1, FRM_RNE)>; +def : Pat<(any_fpround FPR64IN32X:$rs1), + (FCVT_H_D_IN32X FPR64IN32X:$rs1, (i32 FRM_DYN))>; +def : Pat<(any_fpextend FPR16INX:$rs1), + (FCVT_D_H_IN32X FPR16INX:$rs1, (i32 FRM_RNE))>; /// Float arithmetic operations def : Pat<(fcopysign FPR16INX:$rs1, FPR64IN32X:$rs2), - (FSGNJ_H_INX $rs1, (FCVT_H_D_IN32X $rs2, 0b111))>; -def : Pat<(fcopysign FPR64IN32X:$rs1, FPR16INX:$rs2), (FSGNJ_D_IN32X $rs1, (FCVT_D_H_IN32X $rs2, FRM_RNE))>; + (FSGNJ_H_INX $rs1, (FCVT_H_D_IN32X $rs2, (i32 FRM_DYN)))>; +def : Pat<(fcopysign FPR64IN32X:$rs1, FPR16INX:$rs2), + (FSGNJ_D_IN32X $rs1, (FCVT_D_H_IN32X $rs2, (i32 FRM_RNE)))>; } // Predicates = [HasStdExtZhinxmin, HasStdExtZdinx, IsRV32] let Predicates = [HasStdExtZhinxmin, HasStdExtZdinx, IsRV64] in { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 989950f..a466ab2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -316,6 +316,9 @@ private: bool selectImageWriteIntrinsic(MachineInstr &I) const; bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectResourceNonUniformIndex(Register &ResVReg, + const SPIRVType *ResType, + MachineInstr &I) 
const; bool selectModf(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType, @@ -347,7 +350,7 @@ private: SPIRV::StorageClass::StorageClass SC, uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg, - bool IsNonUniform, StringRef Name, + StringRef Name, MachineIRBuilder MIRBuilder) const; SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const; bool extractSubvector(Register &ResVReg, const SPIRVType *ResType, @@ -364,6 +367,7 @@ private: MachineInstr &I) const; bool loadHandleBeforePosition(Register &HandleReg, const SPIRVType *ResType, GIntrinsic &HandleDef, MachineInstr &Pos) const; + void decorateUsesAsNonUniform(Register &NonUniformReg) const; }; bool sampledTypeIsSignedInteger(const llvm::Type *HandleType) { @@ -3465,6 +3469,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_discard: { return selectDiscard(ResVReg, ResType, I); } + case Intrinsic::spv_resource_nonuniformindex: { + return selectResourceNonUniformIndex(ResVReg, ResType, I); + } default: { std::string DiagMsg; raw_string_ostream OS(DiagMsg); @@ -3504,7 +3511,6 @@ bool SPIRVInstructionSelector::selectCounterHandleFromBinding( uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI); uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI); Register IndexReg = MainHandleDef->getOperand(5).getReg(); - const bool IsNonUniform = false; std::string CounterName = getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) + ".counter"; @@ -3513,7 +3519,7 @@ bool SPIRVInstructionSelector::selectCounterHandleFromBinding( MachineIRBuilder MIRBuilder(I); Register CounterVarReg = buildPointerToResource( GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set, - Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder); + Binding, ArraySize, IndexReg, CounterName, MIRBuilder); return 
BuildCOPY(ResVReg, CounterVarReg, I); } @@ -3713,6 +3719,55 @@ bool SPIRVInstructionSelector::selectResourceGetPointer( .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectResourceNonUniformIndex( + Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { + Register ObjReg = I.getOperand(2).getReg(); + if (!BuildCOPY(ResVReg, ObjReg, I)) + return false; + + buildOpDecorate(ResVReg, I, TII, SPIRV::Decoration::NonUniformEXT, {}); + // Check for the registers that use the index marked as non-uniform + // and recursively mark them as non-uniform. + // Per the spec, it's necessary that the final argument used for + // load/store/sample/atomic must be decorated, so we need to propagate the + // decoration through access chains and copies. + // https://docs.vulkan.org/samples/latest/samples/extensions/descriptor_indexing/README.html#_when_to_use_non_uniform_indexing_qualifier + decorateUsesAsNonUniform(ResVReg); + return true; +} + +void SPIRVInstructionSelector::decorateUsesAsNonUniform( + Register &NonUniformReg) const { + llvm::SmallVector<Register> WorkList = {NonUniformReg}; + while (WorkList.size() > 0) { + Register CurrentReg = WorkList.back(); + WorkList.pop_back(); + + bool IsDecorated = false; + for (MachineInstr &Use : MRI->use_instructions(CurrentReg)) { + if (Use.getOpcode() == SPIRV::OpDecorate && + Use.getOperand(1).getImm() == SPIRV::Decoration::NonUniformEXT) { + IsDecorated = true; + continue; + } + // Check if the instruction has the result register and add it to the + // worklist. 
+ if (Use.getOperand(0).isReg() && Use.getOperand(0).isDef()) { + Register ResultReg = Use.getOperand(0).getReg(); + if (ResultReg == CurrentReg) + continue; + WorkList.push_back(ResultReg); + } + } + + if (!IsDecorated) { + buildOpDecorate(CurrentReg, *MRI->getVRegDef(CurrentReg), TII, + SPIRV::Decoration::NonUniformEXT, {}); + } + } + return; +} + bool SPIRVInstructionSelector::extractSubvector( Register &ResVReg, const SPIRVType *ResType, Register &ReadReg, MachineInstr &InsertionPoint) const { @@ -3784,7 +3839,7 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic( Register SPIRVInstructionSelector::buildPointerToResource( const SPIRVType *SpirvResType, SPIRV::StorageClass::StorageClass SC, uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg, - bool IsNonUniform, StringRef Name, MachineIRBuilder MIRBuilder) const { + StringRef Name, MachineIRBuilder MIRBuilder) const { const Type *ResType = GR.getTypeForSPIRVType(SpirvResType); if (ArraySize == 1) { SPIRVType *PtrType = @@ -3803,14 +3858,7 @@ Register SPIRVInstructionSelector::buildPointerToResource( SPIRVType *ResPointerType = GR.getOrCreateSPIRVPointerType(ResType, MIRBuilder, SC); - Register AcReg = MRI->createVirtualRegister(GR.getRegClass(ResPointerType)); - if (IsNonUniform) { - // It is unclear which value needs to be marked an non-uniform, so both - // the index and the access changed are decorated as non-uniform. - buildOpDecorate(IndexReg, MIRBuilder, SPIRV::Decoration::NonUniformEXT, {}); - buildOpDecorate(AcReg, MIRBuilder, SPIRV::Decoration::NonUniformEXT, {}); - } MIRBuilder.buildInstr(SPIRV::OpAccessChain) .addDef(AcReg) @@ -4560,9 +4608,6 @@ bool SPIRVInstructionSelector::loadHandleBeforePosition( uint32_t Binding = foldImm(HandleDef.getOperand(3), MRI); uint32_t ArraySize = foldImm(HandleDef.getOperand(4), MRI); Register IndexReg = HandleDef.getOperand(5).getReg(); - // FIXME: The IsNonUniform flag needs to be set based on resource analysis. 
- // https://github.com/llvm/llvm-project/issues/155701 - bool IsNonUniform = false; std::string Name = getStringValueFromReg(HandleDef.getOperand(6).getReg(), *MRI); @@ -4576,13 +4621,8 @@ bool SPIRVInstructionSelector::loadHandleBeforePosition( SC = GR.getPointerStorageClass(ResType); } - Register VarReg = - buildPointerToResource(VarType, SC, Set, Binding, ArraySize, IndexReg, - IsNonUniform, Name, MIRBuilder); - - if (IsNonUniform) - buildOpDecorate(HandleReg, HandleDef, TII, SPIRV::Decoration::NonUniformEXT, - {}); + Register VarReg = buildPointerToResource(VarType, SC, Set, Binding, ArraySize, + IndexReg, Name, MIRBuilder); // The handle for the buffer is the pointer to the resource. For an image, the // handle is the image object. So images get an extra load. diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index e62d57e..50136a8 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -9348,13 +9348,12 @@ static SmallVector<Instruction *> preparePlanForEpilogueVectorLoop( VPBasicBlock *Header = VectorLoop->getEntryBasicBlock(); Header->setName("vec.epilog.vector.body"); - // Ensure that the start values for all header phi recipes are updated before - // vectorizing the epilogue loop. VPCanonicalIVPHIRecipe *IV = Plan.getCanonicalIV(); - // When vectorizing the epilogue loop, the canonical induction start - // value needs to be changed from zero to the value after the main - // vector loop. Find the resume value created during execution of the main - // VPlan. It must be the first phi in the loop preheader. + // When vectorizing the epilogue loop, the canonical induction needs to be + // adjusted by the value after the main vector loop. Find the resume value + // created during execution of the main VPlan. It must be the first phi in the + // loop preheader. 
Use the value to increment the canonical IV, and update all + // users in the loop region to use the adjusted value. // FIXME: Improve modeling for canonical IV start values in the epilogue // loop. using namespace llvm::PatternMatch; @@ -9389,10 +9388,16 @@ static SmallVector<Instruction *> preparePlanForEpilogueVectorLoop( }) && "the canonical IV should only be used by its increment or " "ScalarIVSteps when resetting the start value"); - IV->setOperand(0, VPV); + VPBuilder Builder(Header, Header->getFirstNonPhi()); + VPInstruction *Add = Builder.createNaryOp(Instruction::Add, {IV, VPV}); + IV->replaceAllUsesWith(Add); + Add->setOperand(0, IV); DenseMap<Value *, Value *> ToFrozen; SmallVector<Instruction *> InstsToMove; + // Ensure that the start values for all header phi recipes are updated before + // vectorizing the epilogue loop. Skip the canonical IV, which has been + // handled above. for (VPRecipeBase &R : drop_begin(Header->phis())) { Value *ResumeV = nullptr; // TODO: Move setting of resume values to prepareToExecute. diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 2555ebe..1fea068 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1777,6 +1777,9 @@ InstructionCost VPCostContext::getScalarizationOverhead( if (VF.isScalar()) return 0; + assert(!VF.isScalable() && + "Scalarization overhead not supported for scalable vectors"); + InstructionCost ScalarizationCost = 0; // Compute the cost of scalarizing the result if needed. if (!ResultTy->isVoidTy()) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 600ff8a..8e916772 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -3174,6 +3174,9 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, // transform, avoid computing their cost multiple times for now. 
Ctx.SkipCostComputation.insert(UI); + if (VF.isScalable() && !isSingleScalar()) + return InstructionCost::getInvalid(); + switch (UI->getOpcode()) { case Instruction::GetElementPtr: // We mark this instruction as zero-cost because the cost of GEPs in @@ -3221,9 +3224,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, return ScalarCallCost; } - if (VF.isScalable()) - return InstructionCost::getInvalid(); - return ScalarCallCost * VF.getFixedValue() + Ctx.getScalarizationOverhead(ResultTy, ArgOps, VF); } @@ -3274,9 +3274,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, } case Instruction::Load: case Instruction::Store: { - if (VF.isScalable() && !isSingleScalar()) - return InstructionCost::getInvalid(); - // TODO: See getMemInstScalarizationCost for how to handle replicating and // predicated cases. const VPRegionBlock *ParentRegion = getParent()->getParent(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index c8a2d84..7563cd7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1234,6 +1234,18 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { if (!Plan->isUnrolled()) return; + // Hoist an invariant increment Y of a phi X, by having X start at Y. + if (match(Def, m_c_Add(m_VPValue(X), m_VPValue(Y))) && Y->isLiveIn() && + isa<VPPhi>(X)) { + auto *Phi = cast<VPPhi>(X); + if (Phi->getOperand(1) != Def && match(Phi->getOperand(0), m_ZeroInt()) && + Phi->getNumUsers() == 1 && (*Phi->user_begin() == &R)) { + Phi->setOperand(0, Y); + Def->replaceAllUsesWith(Phi); + return; + } + } + // VPVectorPointer for part 0 can be replaced by their start pointer. 
if (auto *VecPtr = dyn_cast<VPVectorPointerRecipe>(&R)) { if (VecPtr->isFirstPart()) { diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/RWStructuredBufferNonUniformIdx.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/RWStructuredBufferNonUniformIdx.ll new file mode 100644 index 0000000..2a12baf --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/RWStructuredBufferNonUniformIdx.ll @@ -0,0 +1,26 @@ +; RUN: llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s --match-full-lines +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability Shader +; CHECK-DAG: OpCapability ShaderNonUniformEXT +; CHECK-DAG: OpDecorate {{%[0-9]+}} NonUniformEXT +; CHECK-DAG: OpDecorate {{%[0-9]+}} NonUniformEXT +; CHECK-DAG: OpDecorate {{%[0-9]+}} NonUniformEXT +; CHECK-DAG: OpDecorate {{%[0-9]+}} NonUniformEXT +; CHECK-DAG: OpDecorate %[[#access1:]] NonUniformEXT +@ReadWriteStructuredBuf.str = private unnamed_addr constant [23 x i8] c"ReadWriteStructuredBuf\00", align 1 + +define void @main() local_unnamed_addr #0 { +entry: + %0 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) + %add.i = add i32 %0, 1 + %1 = tail call noundef i32 @llvm.spv.resource.nonuniformindex(i32 %add.i) + %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefromimplicitbinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 0, i32 64, i32 %1, ptr nonnull @ReadWriteStructuredBuf.str) + %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4i32_12_1t(target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) %2, i32 98) + %4 = load <4 x i32>, ptr addrspace(11) %3, align 16 + %vecins.i = insertelement <4 x i32> %4, i32 99, i64 0 +; CHECK: %[[#access1]] = OpAccessChain {{.*}} +; CHECK: OpStore %[[#access1]] {{%[0-9]+}} Aligned 16 + store <4 x i32> %vecins.i, 
ptr addrspace(11) %3, align 16 + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/StructuredBufferNonUniformIdx.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/StructuredBufferNonUniformIdx.ll new file mode 100644 index 0000000..92efad9 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/NonUniformIdx/StructuredBufferNonUniformIdx.ll @@ -0,0 +1,24 @@ +; RUN: llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s --match-full-lines +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability Shader +; CHECK-DAG: OpCapability ShaderNonUniformEXT +; CHECK-DAG: OpCapability StorageTexelBufferArrayNonUniformIndexingEXT +; CHECK-DAG: OpDecorate {{%[0-9]+}} NonUniformEXT +; CHECK-DAG: OpDecorate %[[#access:]] NonUniformEXT +; CHECK-DAG: OpDecorate %[[#load:]] NonUniformEXT +@ReadWriteBuf.str = private unnamed_addr constant [13 x i8] c"ReadWriteBuf\00", align 1 + +define void @main() local_unnamed_addr #0 { +entry: + %0 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) + %1 = tail call noundef i32 @llvm.spv.resource.nonuniformindex(i32 %0) + %2 = tail call target("spirv.Image", i32, 5, 2, 0, 0, 2, 33) @llvm.spv.resource.handlefromimplicitbinding.tspirv.Image_i32_5_2_0_0_2_33t(i32 0, i32 0, i32 64, i32 %1, ptr nonnull @ReadWriteBuf.str) + %3 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_i32_5_2_0_0_2_33t(target("spirv.Image", i32, 5, 2, 0, 0, 2, 33) %2, i32 96) +; CHECK: {{%[0-9]+}} = OpCompositeExtract {{.*}} +; CHECK: %[[#access]] = OpAccessChain {{.*}} +; CHECK: %[[#load]] = OpLoad {{%[0-9]+}} %[[#access]] +; CHECK: OpImageWrite %[[#load]] {{%[0-9]+}} {{%[0-9]+}} + store i32 95, ptr addrspace(11) %3, align 4 + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/StorageImageNonUniformIdx.ll 
b/llvm/test/CodeGen/SPIRV/hlsl-resources/StorageImageNonUniformIdx.ll deleted file mode 100644 index 5e15aab..0000000 --- a/llvm/test/CodeGen/SPIRV/hlsl-resources/StorageImageNonUniformIdx.ll +++ /dev/null @@ -1,56 +0,0 @@ -; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.5-vulkan-library %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-vulkan-library %s -o - -filetype=obj | spirv-val %} - -; This test depends on llvm.svp.resource.nonuniformindex support (not yet implemented) -; https://github.com/llvm/llvm-project/issues/160231 -; XFAIL: * - -@.str.b0 = private unnamed_addr constant [3 x i8] c"B0\00", align 1 - -; CHECK-DAG: OpCapability Shader -; CHECK-DAG: OpCapability ShaderNonUniformEXT -; CHECK-DAG: OpCapability StorageImageArrayNonUniformIndexing -; CHECK-DAG: OpCapability Image1D -; CHECK-NOT: OpCapability - -; CHECK-DAG: OpDecorate [[Var:%[0-9]+]] DescriptorSet 3 -; CHECK-DAG: OpDecorate [[Var]] Binding 4 -; CHECK: OpDecorate [[Zero:%[0-9]+]] NonUniform -; CHECK: OpDecorate [[ac0:%[0-9]+]] NonUniform -; CHECK: OpDecorate [[ld0:%[0-9]+]] NonUniform -; CHECK: OpDecorate [[One:%[0-9]+]] NonUniform -; CHECK: OpDecorate [[ac1:%[0-9]+]] NonUniform -; CHECK: OpDecorate [[ld1:%[0-9]+]] NonUniform - -; CHECK-DAG: [[int:%[0-9]+]] = OpTypeInt 32 0 -; CHECK-DAG: [[BufferType:%[0-9]+]] = OpTypeImage [[int]] 1D 2 0 0 2 R32i {{$}} -; CHECK-DAG: [[BufferPtrType:%[0-9]+]] = OpTypePointer UniformConstant [[BufferType]] -; CHECK-DAG: [[ArraySize:%[0-9]+]] = OpConstant [[int]] 3 -; CHECK-DAG: [[One]] = OpConstant [[int]] 1 -; CHECK-DAG: [[Zero]] = OpConstant [[int]] 0{{$}} -; CHECK-DAG: [[BufferArrayType:%[0-9]+]] = OpTypeArray [[BufferType]] [[ArraySize]] -; CHECK-DAG: [[ArrayPtrType:%[0-9]+]] = OpTypePointer UniformConstant [[BufferArrayType]] -; CHECK-DAG: [[Var]] = OpVariable [[ArrayPtrType]] UniformConstant - -; CHECK: {{%[0-9]+}} = OpFunction {{%[0-9]+}} DontInline {{%[0-9]+}} -; CHECK-NEXT: OpLabel -define void @main() #0 { -; CHECK: 
[[ac0]] = OpAccessChain [[BufferPtrType]] [[Var]] [[Zero]] -; CHECK: [[ld0]] = OpLoad [[BufferType]] [[ac0]] - %buffer0 = call target("spirv.Image", i32, 0, 2, 0, 0, 2, 24) - @llvm.spv.resource.handlefrombinding.tspirv.Image_f32_0_2_0_0_2_24( - i32 3, i32 4, i32 3, i32 0, ptr nonnull @.str.b0) - %ptr0 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", i32, 0, 2, 0, 0, 2, 24) %buffer0, i32 0) - store i32 0, ptr %ptr0, align 4 - -; CHECK: [[ac1:%[0-9]+]] = OpAccessChain [[BufferPtrType]] [[Var]] [[One]] -; CHECK: [[ld1]] = OpLoad [[BufferType]] [[ac1]] - %buffer1 = call target("spirv.Image", i32, 0, 2, 0, 0, 2, 24) - @llvm.spv.resource.handlefrombinding.tspirv.Image_f32_0_2_0_0_2_24( - i32 3, i32 4, i32 3, i32 1, ptr nonnull @.str.b0) - %ptr1 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", i32, 0, 2, 0, 0, 2, 24) %buffer1, i32 0) - store i32 0, ptr %ptr1, align 4 - ret void -} - -attributes #0 = { convergent noinline norecurse "frame-pointer"="all" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } diff --git a/llvm/test/DebugInfo/Generic/compileunit-source-language-name.ll b/llvm/test/DebugInfo/Generic/compileunit-source-language-name.ll index 211a7bc..e2b6167 100644 --- a/llvm/test/DebugInfo/Generic/compileunit-source-language-name.ll +++ b/llvm/test/DebugInfo/Generic/compileunit-source-language-name.ll @@ -4,6 +4,11 @@ @x = global i32 0, align 4, !dbg !0 +; Function Attrs: mustprogress noinline nounwind optnone ssp uwtable(sync) +define void @_Z4funcv() !dbg !8 { + ret void, !dbg !11 +} + !llvm.dbg.cu = !{!2} !llvm.module.flags = !{!6, !7} @@ -15,3 +20,7 @@ !5 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed) !6 = !{i32 7, !"Dwarf Version", i32 5} !7 = !{i32 2, !"Debug Info Version", i32 3} +!8 = distinct 
!DISubprogram(name: "func", linkageName: "_Z4funcv", scope: !3, file: !3, line: 2, type: !9, scopeLine: 2, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition, unit: !2) +!9 = !DISubroutineType(types: !10) +!10 = !{null} +!11 = !DILocation(line: 2, column: 14, scope: !8) diff --git a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll index b946bbf..14ee00d 100644 --- a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll +++ b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll @@ -1,6 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ; RUN: opt -p indvars -S %s | FileCheck %s +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + declare void @foo() define void @narrow_iv_precondition_order_1(ptr %start, i32 %base, i8 %n) { @@ -96,3 +98,202 @@ loop: exit: ret void } + +define i32 @urem_order1(i32 %n) { +; CHECK-LABEL: define i32 @urem_order1( +; CHECK-SAME: i32 [[N:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*]]: +; CHECK-NEXT: [[UREM:%.*]] = urem i32 [[N]], 3 +; CHECK-NEXT: [[UREM_ZERO:%.*]] = icmp eq i32 [[UREM]], 0 +; CHECK-NEXT: br i1 [[UREM_ZERO]], label %[[PH:.*]], label %[[EXIT:.*]] +; CHECK: [[PH]]: +; CHECK-NEXT: [[N_NON_ZERO:%.*]] = icmp ne i32 [[N]], 0 +; CHECK-NEXT: br i1 [[N_NON_ZERO]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] +; CHECK: [[LOOP_PREHEADER]]: +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], 3 +; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], [[N]] +; CHECK-NEXT: br i1 [[EC]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP]] +; CHECK: [[EXIT_LOOPEXIT]]: +; CHECK-NEXT: br label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: [[RES:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ 2, %[[PH]] ], [ 3, 
%[[EXIT_LOOPEXIT]] ] +; CHECK-NEXT: ret i32 [[RES]] +; +entry: + %urem = urem i32 %n, 3 + %urem.zero = icmp eq i32 %urem, 0 + br i1 %urem.zero, label %ph, label %exit + +ph: + %n.non.zero = icmp ne i32 %n, 0 + br i1 %n.non.zero, label %loop, label %exit + +loop: + %iv = phi i32 [ 0, %ph ], [ %iv.next, %loop ] + call void @foo() + %iv.next = add i32 %iv, 3 + %ec = icmp eq i32 %iv.next, %n + br i1 %ec, label %exit, label %loop + +exit: + %res = phi i32 [ 1, %entry ], [ 2, %ph ], [ 3, %loop ] + ret i32 %res +} + +define i32 @urem_order2(i32 %n) { +; CHECK-LABEL: define i32 @urem_order2( +; CHECK-SAME: i32 [[N:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*]]: +; CHECK-NEXT: [[N_NON_ZERO:%.*]] = icmp ne i32 [[N]], 0 +; CHECK-NEXT: br i1 [[N_NON_ZERO]], label %[[PH:.*]], label %[[EXIT:.*]] +; CHECK: [[PH]]: +; CHECK-NEXT: [[UREM:%.*]] = urem i32 [[N]], 3 +; CHECK-NEXT: [[UREM_ZERO:%.*]] = icmp eq i32 [[UREM]], 0 +; CHECK-NEXT: br i1 [[UREM_ZERO]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] +; CHECK: [[LOOP_PREHEADER]]: +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: [[IV_NEXT]] = add nuw i32 [[IV]], 3 +; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], [[N]] +; CHECK-NEXT: br i1 [[EC]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP]] +; CHECK: [[EXIT_LOOPEXIT]]: +; CHECK-NEXT: br label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: [[RES:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ 2, %[[PH]] ], [ 3, %[[EXIT_LOOPEXIT]] ] +; CHECK-NEXT: ret i32 [[RES]] +; +entry: + %n.non.zero = icmp ne i32 %n, 0 + br i1 %n.non.zero, label %ph, label %exit + +ph: + %urem = urem i32 %n, 3 + %urem.zero = icmp eq i32 %urem, 0 + br i1 %urem.zero, label %loop, label %exit + +loop: + %iv = phi i32 [ 0, %ph ], [ %iv.next, %loop ] + call void @foo() + %iv.next = add i32 %iv, 3 + %ec = icmp eq i32 %iv.next, %n + br i1 %ec, label %exit, label %loop + +exit: + %res = phi 
i32 [ 1, %entry ], [ 2, %ph ], [ 3, %loop ] + ret i32 %res +} + +define i64 @test_loop_with_div_order_1(i64 %n) { +; CHECK-LABEL: define i64 @test_loop_with_div_order_1( +; CHECK-SAME: i64 [[N:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[IS_ZERO:%.*]] = icmp eq i64 [[N]], 0 +; CHECK-NEXT: br i1 [[IS_ZERO]], label %[[EXIT:.*]], label %[[CHECK_BOUNDS:.*]] +; CHECK: [[CHECK_BOUNDS]]: +; CHECK-NEXT: [[N_PLUS_63:%.*]] = add i64 [[N]], 63 +; CHECK-NEXT: [[UPPER_BOUND:%.*]] = lshr i64 [[N_PLUS_63]], 6 +; CHECK-NEXT: [[BOUNDS_CHECK:%.*]] = icmp ult i64 [[N_PLUS_63]], 64 +; CHECK-NEXT: br i1 [[BOUNDS_CHECK]], label %[[EXIT]], label %[[CHECK_PARITY:.*]] +; CHECK: [[CHECK_PARITY]]: +; CHECK-NEXT: [[IS_ODD:%.*]] = and i64 [[N]], 1 +; CHECK-NEXT: [[PARITY_CHECK:%.*]] = icmp eq i64 [[IS_ODD]], 0 +; CHECK-NEXT: br i1 [[PARITY_CHECK]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] +; CHECK: [[LOOP_PREHEADER]]: +; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[UPPER_BOUND]], i64 1) +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] +; CHECK-NEXT: [[DUMMY:%.*]] = load volatile i64, ptr null, align 8 +; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UMAX]] +; CHECK-NEXT: br i1 [[EXITCOND]], label %[[LOOP]], label %[[EXIT_LOOPEXIT:.*]] +; CHECK: [[EXIT_LOOPEXIT]]: +; CHECK-NEXT: br label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: ret i64 0 +; +entry: + %is_zero = icmp eq i64 %n, 0 + br i1 %is_zero, label %exit, label %check_bounds + +check_bounds: + %n_plus_63 = add i64 %n, 63 + %upper_bound = lshr i64 %n_plus_63, 6 + %bounds_check = icmp ult i64 %n_plus_63, 64 + br i1 %bounds_check, label %exit, label %check_parity + +check_parity: + %is_odd = and i64 %n, 1 + %parity_check = icmp eq i64 %is_odd, 0 + br i1 %parity_check, label %loop, label %exit + +loop: + %iv = phi i64 [ %iv_next, %loop ], [ 0, 
%check_parity ] + %dummy = load volatile i64, ptr null, align 8 + %iv_next = add i64 %iv, 1 + %exit_cond = icmp ult i64 %iv_next, %upper_bound + br i1 %exit_cond, label %loop, label %exit + +exit: + ret i64 0 +} + +define i64 @test_loop_with_div_order_2(i64 %n) { +; CHECK-LABEL: define i64 @test_loop_with_div_order_2( +; CHECK-SAME: i64 [[N:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[N_PLUS_63:%.*]] = add i64 [[N]], 63 +; CHECK-NEXT: [[UPPER_BOUND:%.*]] = lshr i64 [[N_PLUS_63]], 6 +; CHECK-NEXT: [[BOUNDS_CHECK:%.*]] = icmp ult i64 [[N_PLUS_63]], 64 +; CHECK-NEXT: br i1 [[BOUNDS_CHECK]], label %[[EXIT:.*]], label %[[CHECK_BOUNDS:.*]] +; CHECK: [[CHECK_BOUNDS]]: +; CHECK-NEXT: [[IS_ZERO:%.*]] = icmp eq i64 [[N]], 0 +; CHECK-NEXT: br i1 [[IS_ZERO]], label %[[EXIT]], label %[[CHECK_PARITY:.*]] +; CHECK: [[CHECK_PARITY]]: +; CHECK-NEXT: [[IS_ODD:%.*]] = and i64 [[N]], 1 +; CHECK-NEXT: [[PARITY_CHECK:%.*]] = icmp eq i64 [[IS_ODD]], 0 +; CHECK-NEXT: br i1 [[PARITY_CHECK]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] +; CHECK: [[LOOP_PREHEADER]]: +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] +; CHECK-NEXT: [[DUMMY:%.*]] = load volatile i64, ptr null, align 8 +; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UPPER_BOUND]] +; CHECK-NEXT: br i1 [[EXITCOND]], label %[[LOOP]], label %[[EXIT_LOOPEXIT:.*]] +; CHECK: [[EXIT_LOOPEXIT]]: +; CHECK-NEXT: br label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: ret i64 0 +; +entry: + %n_plus_63 = add i64 %n, 63 + %upper_bound = lshr i64 %n_plus_63, 6 + %bounds_check = icmp ult i64 %n_plus_63, 64 + br i1 %bounds_check, label %exit, label %check_bounds + +check_bounds: + %is_zero = icmp eq i64 %n, 0 + br i1 %is_zero, label %exit, label %check_parity + +check_parity: + %is_odd = and i64 %n, 1 + %parity_check = icmp eq i64 %is_odd, 0 + br i1 %parity_check, 
label %loop, label %exit + +loop: + %iv = phi i64 [ %iv_next, %loop ], [ 0, %check_parity ] + %dummy = load volatile i64, ptr null, align 8 + %iv_next = add i64 %iv, 1 + %exit_cond = icmp ult i64 %iv_next, %upper_bound + br i1 %exit_cond, label %loop, label %exit + +exit: + ret i64 0 +} diff --git a/llvm/test/Transforms/LoopSimplifyCFG/pr117537.ll b/llvm/test/Transforms/LoopSimplifyCFG/pr117537.ll index df1399d..a8db6a0 100644 --- a/llvm/test/Transforms/LoopSimplifyCFG/pr117537.ll +++ b/llvm/test/Transforms/LoopSimplifyCFG/pr117537.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -; RUN: opt -S -passes='print<scalar-evolution>,loop-mssa(licm,loop-simplifycfg,loop-predication)' -verify-scev < %s 2>/dev/null | FileCheck %s +; RUN: opt -S -passes='print<scalar-evolution>,loop-mssa(licm,loop-simplifycfg,loop-predication)' -verify-scev < %s | FileCheck %s ; Make sure we don't assert due to insufficient SCEV invalidation. diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/invalid-costs.ll b/llvm/test/Transforms/LoopVectorize/AArch64/invalid-costs.ll index 757d9e7..803ffa8 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/invalid-costs.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/invalid-costs.ll @@ -1,42 +1,81 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --filter-out-after "scalar.ph:" --version 6 ; RUN: opt -passes="loop-vectorize" -pass-remarks-output=%t.yaml -S %s | FileCheck %s ; RUN: FileCheck --input-file=%t.yaml --check-prefix=REMARKS %s -; REMARKS: the cost-model indicates that vectorization is not beneficial +target triple = "arm64-apple-macosx" -; Test for https://github.com/llvm/llvm-project/issues/116375. 
-define void @test_i24_load_for(ptr noalias %src, ptr %dst) { -; CHECK-LABEL: define void @test_i24_load_for( -; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr [[DST:%.*]]) { -; CHECK-NEXT: [[ENTRY:.*]]: -; CHECK-NEXT: br label %[[LOOP:.*]] -; CHECK: [[LOOP]]: -; CHECK-NEXT: [[IV:%.*]] = phi i16 [ 0, %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -; CHECK-NEXT: [[FOR:%.*]] = phi i24 [ 0, %[[ENTRY]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ] -; CHECK-NEXT: [[IV_NEXT]] = add i16 [[IV]], 1 -; CHECK-NEXT: [[GEP_SRC:%.*]] = getelementptr inbounds i24, ptr [[SRC]], i16 [[IV]] -; CHECK-NEXT: [[FOR_NEXT]] = load i24, ptr [[GEP_SRC]], align 1 -; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr inbounds i24, ptr [[DST]], i16 [[IV]] -; CHECK-NEXT: store i24 [[FOR]], ptr [[GEP_DST]], align 4 -; CHECK-NEXT: [[EC:%.*]] = icmp eq i16 [[IV_NEXT]], 1000 -; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]] -; CHECK: [[EXIT]]: -; CHECK-NEXT: ret void +; REMARKS: Recipe with invalid costs prevented vectorization at VF=(vscale x 1): load +; Test case for https://github.com/llvm/llvm-project/issues/160792. 
+define void @replicate_sdiv_conditional(ptr noalias %a, ptr noalias %b, ptr noalias %c) #0 { +; CHECK-LABEL: define void @replicate_sdiv_conditional( +; CHECK-SAME: ptr noalias [[A:%.*]], ptr noalias [[B:%.*]], ptr noalias [[C:%.*]]) #[[ATTR0:[0-9]+]] { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2 +; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 64, [[TMP1]] +; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]] +; CHECK: [[VECTOR_PH]]: +; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4 +; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 64, [[TMP3]] +; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 64, [[N_MOD_VF]] +; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] +; CHECK: [[VECTOR_BODY]]: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i32>, ptr [[TMP4]], align 4 +; CHECK-NEXT: [[TMP5:%.*]] = icmp slt <vscale x 4 x i32> [[WIDE_LOAD]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i32, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr [[TMP6]], i32 4, <vscale x 4 x i1> [[TMP5]], <vscale x 4 x i32> poison) +; CHECK-NEXT: [[TMP7:%.*]] = sext <vscale x 4 x i32> [[WIDE_MASKED_LOAD]] to <vscale x 4 x i64> +; CHECK-NEXT: [[TMP8:%.*]] = ashr <vscale x 4 x i32> [[WIDE_MASKED_LOAD]], splat (i32 1) +; CHECK-NEXT: [[TMP9:%.*]] = add <vscale x 4 x i32> [[TMP8]], [[WIDE_LOAD]] +; CHECK-NEXT: [[TMP10:%.*]] = sext <vscale x 4 x i32> [[TMP9]] to <vscale x 4 x i64> +; CHECK-NEXT: [[TMP11:%.*]] = select <vscale x 4 x i1> [[TMP5]], <vscale x 4 x i64> [[TMP7]], <vscale x 4 x i64> splat (i64 1) +; CHECK-NEXT: [[TMP12:%.*]] = sdiv <vscale x 
4 x i64> [[TMP10]], [[TMP11]] +; CHECK-NEXT: [[TMP13:%.*]] = trunc <vscale x 4 x i64> [[TMP12]] to <vscale x 4 x i32> +; CHECK-NEXT: [[PREDPHI:%.*]] = select <vscale x 4 x i1> [[TMP5]], <vscale x 4 x i32> [[TMP13]], <vscale x 4 x i32> [[WIDE_LOAD]] +; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: store <vscale x 4 x i32> [[PREDPHI]], ptr [[TMP14]], align 4 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]] +; CHECK-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP15]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] +; CHECK: [[MIDDLE_BLOCK]]: +; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 64, [[N_VEC]] +; CHECK-NEXT: br i1 [[CMP_N]], [[FOR_END:label %.*]], label %[[SCALAR_PH]] +; CHECK: [[SCALAR_PH]]: ; entry: - br label %loop + br label %loop.header -loop: - %iv = phi i16 [ 0, %entry ], [ %iv.next, %loop ] - %for = phi i24 [ 0, %entry ], [ %for.next, %loop ] - %iv.next = add i16 %iv, 1 - %gep.src = getelementptr inbounds i24, ptr %src, i16 %iv - %for.next = load i24, ptr %gep.src, align 1 - %gep.dst = getelementptr inbounds i24, ptr %dst, i16 %iv - store i24 %for, ptr %gep.dst - %ec = icmp eq i16 %iv.next, 1000 - br i1 %ec, label %exit, label %loop +loop.header: + %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop.latch ] + %gep.c = getelementptr inbounds i32, ptr %c, i64 %iv + %val.c = load i32, ptr %gep.c, align 4 + %cmp = icmp slt i32 %val.c, 0 + br i1 %cmp, label %if.then, label %loop.latch -exit: +if.then: + %gep.b = getelementptr inbounds i32, ptr %b, i64 %iv + %val.b = load i32, ptr %gep.b, align 4 + %sext = sext i32 %val.b to i64 + %shr = ashr i32 %val.b, 1 + %add = add i32 %shr, %val.c + %conv = sext i32 %add to i64 + %div = sdiv i64 %conv, %sext + %trunc = trunc i64 %div to i32 + br label %loop.latch + +loop.latch: + %result = phi i32 [ %trunc, %if.then ], [ %val.c, %loop.header ] + %gep.a = getelementptr inbounds i32, ptr %a, i64 
%iv + store i32 %result, ptr %gep.a, align 4 + %iv.next = add nuw nsw i64 %iv, 1 + %exit = icmp eq i64 %iv.next, 64 + br i1 %exit, label %for.end, label %loop.header + +for.end: ret void } + +attributes #0 = { "target-cpu"="neoverse-512tvb" } diff --git a/llvm/test/Transforms/LoopVectorize/invalid-costs.ll b/llvm/test/Transforms/LoopVectorize/invalid-costs.ll new file mode 100644 index 0000000..757d9e7 --- /dev/null +++ b/llvm/test/Transforms/LoopVectorize/invalid-costs.ll @@ -0,0 +1,42 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes="loop-vectorize" -pass-remarks-output=%t.yaml -S %s | FileCheck %s +; RUN: FileCheck --input-file=%t.yaml --check-prefix=REMARKS %s + +; REMARKS: the cost-model indicates that vectorization is not beneficial + +; Test for https://github.com/llvm/llvm-project/issues/116375. +define void @test_i24_load_for(ptr noalias %src, ptr %dst) { +; CHECK-LABEL: define void @test_i24_load_for( +; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr [[DST:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*]]: +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[IV:%.*]] = phi i16 [ 0, %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] +; CHECK-NEXT: [[FOR:%.*]] = phi i24 [ 0, %[[ENTRY]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ] +; CHECK-NEXT: [[IV_NEXT]] = add i16 [[IV]], 1 +; CHECK-NEXT: [[GEP_SRC:%.*]] = getelementptr inbounds i24, ptr [[SRC]], i16 [[IV]] +; CHECK-NEXT: [[FOR_NEXT]] = load i24, ptr [[GEP_SRC]], align 1 +; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr inbounds i24, ptr [[DST]], i16 [[IV]] +; CHECK-NEXT: store i24 [[FOR]], ptr [[GEP_DST]], align 4 +; CHECK-NEXT: [[EC:%.*]] = icmp eq i16 [[IV_NEXT]], 1000 +; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %iv = phi i16 [ 0, %entry ], [ %iv.next, %loop ] + %for = phi i24 [ 0, %entry ], [ %for.next, %loop ] + %iv.next = add i16 %iv, 1 + 
%gep.src = getelementptr inbounds i24, ptr %src, i16 %iv + %for.next = load i24, ptr %gep.src, align 1 + %gep.dst = getelementptr inbounds i24, ptr %dst, i16 %iv + store i24 %for, ptr %gep.dst + %ec = icmp eq i16 %iv.next, 1000 + br i1 %ec, label %exit, label %loop + +exit: + ret void +} diff --git a/llvm/test/tools/llvm-remarkutil/filter.test b/llvm/test/tools/llvm-remarkutil/filter.test index 8304b9f..9fd2e94 100644 --- a/llvm/test/tools/llvm-remarkutil/filter.test +++ b/llvm/test/tools/llvm-remarkutil/filter.test @@ -18,9 +18,19 @@ RUN: llvm-remarkutil filter --remark-type=analysis %p/Inputs/filter.yaml | FileC RUN: llvm-remarkutil yaml2bitstream -o %t.opt.bitstream %p/Inputs/filter.yaml RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream | FileCheck %s --strict-whitespace --check-prefix=REMARK1 +RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream -o %t.r1.yamL +RUN: cat %t.r1.yamL | FileCheck %s --strict-whitespace --check-prefix=REMARK1 +RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream -o %t.r1.yMl +RUN: cat %t.r1.yMl | FileCheck %s --strict-whitespace --check-prefix=REMARK1 +RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream --serializer=yaml -o %t.r1.fake.opt.bitstream +RUN: cat %t.r1.fake.opt.bitstream | FileCheck %s --strict-whitespace --check-prefix=REMARK1 RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream -o %t.r1.opt.bitstream RUN: llvm-remarkutil bitstream2yaml %t.r1.opt.bitstream | FileCheck %s --strict-whitespace --check-prefix=REMARK1 +RUN: llvm-remarkutil filter --function=func1 %t.opt.bitstream -o %t.r1 +RUN: llvm-remarkutil bitstream2yaml %t.r1 | FileCheck %s --strict-whitespace --check-prefix=REMARK1 +RUN: llvm-remarkutil filter --function=func1 %p/Inputs/filter.yaml --serializer=bitstream -o %t.r1.fake.yaml +RUN: llvm-remarkutil bitstream2yaml %t.r1.fake.yaml | FileCheck %s --strict-whitespace --check-prefix=REMARK1 RUN: llvm-remarkutil filter --function=func %p/Inputs/filter.yaml | 
FileCheck %s --allow-empty --strict-whitespace --check-prefix=EMPTY diff --git a/llvm/tools/llvm-remarkutil/RemarkFilter.cpp b/llvm/tools/llvm-remarkutil/RemarkFilter.cpp index 507ae36..9b521b4 100644 --- a/llvm/tools/llvm-remarkutil/RemarkFilter.cpp +++ b/llvm/tools/llvm-remarkutil/RemarkFilter.cpp @@ -48,12 +48,8 @@ static Error tryFilter() { return MaybeParser.takeError(); auto &Parser = **MaybeParser; - Format SerializerFormat = OutputFormat; - if (SerializerFormat == Format::Auto) { - SerializerFormat = Parser.ParserFormat; - if (OutputFileName.empty() || OutputFileName == "-") - SerializerFormat = Format::YAML; - } + Format SerializerFormat = + getSerializerFormat(OutputFileName, OutputFormat, Parser.ParserFormat); auto MaybeOF = getOutputFileForRemarks(OutputFileName, SerializerFormat); if (!MaybeOF) diff --git a/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.cpp b/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.cpp index be52948..b6204d0 100644 --- a/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.cpp +++ b/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.cpp @@ -54,6 +54,20 @@ getOutputFileForRemarks(StringRef OutputFileName, Format OutputFormat) { : sys::fs::OF_None); } +Format getSerializerFormat(StringRef OutputFileName, Format SelectedFormat, + Format DefaultFormat) { + if (SelectedFormat != Format::Auto) + return SelectedFormat; + SelectedFormat = DefaultFormat; + if (OutputFileName.empty() || OutputFileName == "-" || + OutputFileName.ends_with_insensitive(".yaml") || + OutputFileName.ends_with_insensitive(".yml")) + SelectedFormat = Format::YAML; + if (OutputFileName.ends_with_insensitive(".bitstream")) + SelectedFormat = Format::Bitstream; + return SelectedFormat; +} + Expected<FilterMatcher> FilterMatcher::createRE(const llvm::cl::opt<std::string> &Arg) { return createRE(Arg.ArgStr, Arg); diff --git a/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.h b/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.h index 0dd550765..73867fe 100644 --- 
a/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.h +++ b/llvm/tools/llvm-remarkutil/RemarkUtilHelpers.h @@ -47,7 +47,8 @@ "serializer", cl::init(Format::Auto), \ cl::desc("Output remark format to serialize"), \ cl::values(clEnumValN(Format::Auto, "auto", \ - "Follow the parser format (default)"), \ + "Automatic detection based on output file " \ + "extension or parser format (default)"), \ clEnumValN(Format::YAML, "yaml", "YAML"), \ clEnumValN(Format::Bitstream, "bitstream", "Bitstream")), \ cl::sub(SUBOPT)); @@ -151,6 +152,12 @@ getOutputFileWithFlags(StringRef OutputFileName, sys::fs::OpenFlags Flags); Expected<std::unique_ptr<ToolOutputFile>> getOutputFileForRemarks(StringRef OutputFileName, Format OutputFormat); +/// Choose the serializer format. If \p SelectedFormat is Format::Auto, try to +/// detect the format based on the extension of \p OutputFileName or fall back +/// to \p DefaultFormat. +Format getSerializerFormat(StringRef OutputFileName, Format SelectedFormat, + Format DefaultFormat); + /// Filter object which can be either a string or a regex to match with the /// remark properties. 
class FilterMatcher { diff --git a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp index cd1cecd..698dda1 100644 --- a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp @@ -180,4 +180,14 @@ TEST_F(SymbolStringPoolTest, SymbolStringPoolEntryUnsafe) { EXPECT_EQ(getRefCount(A), 1U); } +TEST_F(SymbolStringPoolTest, Hashing) { + auto A = SP.intern("a"); + auto B = NonOwningSymbolStringPtr(A); + + hash_code AHash = hash_value(A); + hash_code BHash = hash_value(B); + + EXPECT_EQ(AHash, BHash); +} + } // namespace diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt index 39ff476..bdcb8a3 100644 --- a/llvm/utils/profcheck-xfail.txt +++ b/llvm/utils/profcheck-xfail.txt @@ -73,9 +73,7 @@ CodeGen/Hexagon/loop-idiom/hexagon-memmove2.ll CodeGen/Hexagon/loop-idiom/memmove-rt-check.ll CodeGen/NVPTX/lower-ctor-dtor.ll CodeGen/RISCV/zmmul.ll -CodeGen/SPIRV/hlsl-resources/UniqueImplicitBindingNumber.ll CodeGen/WebAssembly/memory-interleave.ll -CodeGen/X86/global-variable-partition-with-dap.ll CodeGen/X86/masked_gather_scatter.ll CodeGen/X86/nocfivalue.ll DebugInfo/AArch64/ir-outliner.ll diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 39ae6a0..a9592bc 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -264,8 +264,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. 
std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 0573af6..8c21951 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -264,8 +264,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 7d676f1..6b7ab40 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -268,8 +268,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 7d676f1..6b7ab40 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -268,8 +268,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. 
std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index 7d676f1..6b7ab40 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -268,8 +268,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 75dbc91..7313324 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -405,8 +405,7 @@ private: // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. 
std::vector<double> data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies<int>())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 8cd51ed..9c96d35 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -18,9 +18,19 @@ #include "mlir-c/Rewrite.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern) DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet, + mlir::FrozenRewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter) + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule) +DEFINE_C_API_PTR_METHODS(MlirPDLResultList, mlir::PDLResultList) +DEFINE_C_API_PTR_METHODS(MlirPDLValue, const mlir::PDLValue) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index d506b7f..47685567 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -261,7 +261,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- - nb::class_<MlirRewritePattern>(m, "RewritePattern"); nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") .def( "__init__", diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp 
b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 70dee59..46c329d 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,17 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::FrozenRewritePatternSet * -unwrap(MlirFrozenRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); -} - -static inline MlirFrozenRewritePatternSet -wrap(mlir::FrozenRewritePatternSet *module) { - return {module}; -} - MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set) { auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); @@ -311,15 +300,6 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, /// PatternRewriter API //===----------------------------------------------------------------------===// -inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { - assert(rewriter.ptr && "unexpected null rewriter"); - return static_cast<mlir::PatternRewriter *>(rewriter.ptr); -} - -inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { - return {rewriter}; -} - MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); } @@ -400,15 +380,6 @@ void mlirRewritePatternSetAdd(MlirRewritePatternSet set, //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::PDLPatternModule *>(module.ptr); -} - -static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { - return {module}; -} - MlirPDLPatternModule 
mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); @@ -426,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { return wrap(m); } -inline const mlir::PDLValue *unwrap(MlirPDLValue value) { - assert(value.ptr && "unexpected null PDL value"); - return static_cast<const mlir::PDLValue *>(value.ptr); -} - -inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } - -inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { - assert(results.ptr && "unexpected null PDL results"); - return static_cast<mlir::PDLResultList *>(results.ptr); -} - -inline MlirPDLResultList wrap(mlir::PDLResultList *results) { - return {results}; -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast<mlir::Value>()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 2b7bdc9..11f866c 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <cstdint> #include <numeric> @@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, {TypeAttr::get(memrefType.getElementType())})); IndexType indexType = builder.getIndexType(); - int64_t numElements = std::accumulate(memrefType.getShape().begin(), - memrefType.getShape().end(), int64_t{1}, - std::multiplies<int64_t>()); + int64_t numElements = llvm::product_of(memrefType.getShape()); emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( builder, loc, indexType, builder.getIndexAttr(numElements)); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 802691c..9bf9ca3 
100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> @@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Calculate the product of all elements in 'newShape' except for the -1 // placeholder, which we discard by negating the result. - int64_t totalSizeNoPlaceholder = -std::accumulate( - newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>()); + int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape); // If there is a 0 component in 'newShape', resolve the placeholder as // 0. diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp index 79c2f23..245a3ef 100644 --- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -20,6 +20,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/DebugLog.h" #include <numeric> @@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter, if (isPacked) src = collapseLastDim(rewriter, src); int64_t rows = vecShape[0]; - int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(vecShape.drop_front()); auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); @@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, ArrayRef<int64_t> shape = vecTy.getShape(); int64_t rows = shape[0]; - int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(shape.drop_front()); 
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); return amx::TileLoadOp::create(rewriter, loc, tileType, buf, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index c45c45e..c9eba69 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOSCF @@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { if (vectorType.getRank() != 1) { // Flatten n-D vectors to 1D. This is done to allow indexing with a // non-constant value. - auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies<int64_t>()); + int64_t flatLength = llvm::product_of(shape); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9ead1d8..71687b1 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" @@ -774,9 +775,7 @@ struct ConvertXeGPUToXeVMPass if (rank < 1 || type.getNumElements() == 1) return elemType; // Otherwise, convert the vector to a flat vector type. 
- int64_t sum = - std::accumulate(type.getShape().begin(), type.getShape().end(), - int64_t{1}, std::multiplies<int64_t>()); + int64_t sum = llvm::product_of(type.getShape()); return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index b1fc9aa..f54baff 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -351,9 +351,9 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, Value one = ConstantOp::create(builder, loc, resultType, builder.getOneAttr(resultType)); ArithBuilder arithBuilder(builder, loc); - return std::accumulate( - values.begin(), values.end(), one, - [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); + return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) { + return arithBuilder.mul(acc, v); + }); } /// Map strings to float types. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index a50ddbe..624519f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { return returnOp; } -/// Return the func::FuncOp called by `callOp`. 
-static func::FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null<func::FuncOp>( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); -} - LogicalResult mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { IRRewriter rewriter(module.getContext()); @@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; // Collect the mapping of functions to their call sites. module.walk([&](func::CallOp callOp) { - if (func::FuncOp calledFunc = getCalledFunction(callOp)) { + if (func::FuncOp calledFunc = + dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { callerMap[calledFunc].insert(callOp); } }); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 19eba6b..b5f8dda 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2460,8 +2460,7 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed, << dDim << ")"; scales[i] = eDim / dDim; } - if (std::accumulate(scales.begin(), scales.end(), 1, - std::multiplies<int64_t>()) != warpSize) + if (llvm::product_of(scales) != warpSize) return op->emitOpError() << "incompatible distribution dimensions from " << expandedVecType << " to " << distributedVecType << " with warp size = " << warpSize; diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 88f531f..572b746 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> @@ -118,8 +119,7 @@ bool 
WarpDistributionPattern::delinearizeLaneId( return false; sizes.push_back(large / small); } - if (std::accumulate(sizes.begin(), sizes.end(), 1, - std::multiplies<int64_t>()) != warpSize) + if (llvm::product_of(sizes) != warpSize) return false; AffineExpr s0, s1; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp index f277c5f..0ae2a9c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp @@ -266,9 +266,8 @@ struct StructuredOpShardingInterface LinalgOp linalgOp = llvm::cast<LinalgOp>(op); SmallVector<utils::IteratorType> iteratorTypes = linalgOp.getIteratorTypesArray(); - unsigned reductionItersCount = std::accumulate( - iteratorTypes.begin(), iteratorTypes.end(), 0, - [](unsigned count, utils::IteratorType iter) { + unsigned reductionItersCount = llvm::accumulate( + iteratorTypes, 0u, [](unsigned count, utils::IteratorType iter) { return count + (iter == utils::IteratorType::reduction); }); shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index b663908..8c4f80f 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Quant/Utils/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> using namespace mlir; @@ -76,9 +77,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) { // using the right quantization parameters. 
int64_t flattenIndex = 0; auto shape = type.getShape(); - int64_t chunkSize = - std::accumulate(std::next(shape.begin(), quantizationDim + 1), - shape.end(), 1, std::multiplies<int64_t>()); + int64_t chunkSize = llvm::product_of(shape.drop_front(quantizationDim + 1)); Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); return attr.mapValues(newElementType, [&](const APFloat &old) { int chunkIndex = (flattenIndex++) / chunkSize; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 5511998..fe50865 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -400,7 +400,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { return emitOpError("operand element type mismatch: expected to be ") << resultType.getElementType() << ", but provided " << elementType; } - unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); + unsigned totalCount = llvm::sum_of(sizes); if (totalCount != cType.getNumElements()) return emitOpError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << totalCount; diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 08fccfa..135c033 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -1010,18 +1010,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, return success(); } -template <typename It> -static auto product(It begin, It end) { - using ElementType = std::decay_t<decltype(*begin)>; - return std::accumulate(begin, end, static_cast<ElementType>(1), - std::multiplies<ElementType>()); -} - -template <typename R> -static auto product(R &&range) { - return product(adl_begin(range), adl_end(range)); -} - static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, diff --git 
a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c51b5e9..00f84bc 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2368,9 +2368,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { } } - int64_t newShapeElementsNum = std::accumulate( - shapeValues.begin(), shapeValues.end(), 1LL, - [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); + int64_t newShapeElementsNum = + llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) { + return (dim > 0) ? acc * dim : acc; + }); bool isStaticNewShape = llvm::all_of(shapeValues, [](int64_t s) { return s > 0; }); if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index d33ebe3..5786f53 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/Matchers.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" using namespace mlir; @@ -375,8 +376,7 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr, for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis]; ++reductionAxisVal) { - int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1, - oldShape.end(), 1, std::multiplies<int>()); + int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1)); int64_t index = indexAtOldTensor + stride * reductionAxisVal; reducedValue = OperationType::calcOneElement(reducedValue, oldTensor[index]); @@ -424,8 +424,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> { auto oldShape = shapedOldElementsValues.getShape(); auto newShape = resultType.getShape(); - auto newNumOfElements = std::accumulate(newShape.begin(), 
newShape.end(), 1, - std::multiplies<int>()); + int64_t newNumOfElements = llvm::product_of(newShape); llvm::SmallVector<APInt> newReducedTensor(newNumOfElements); for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements; diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index e1648ab9..305b06eb 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -81,21 +81,10 @@ SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1, return computeElementwiseMulImpl(v1, v2); } -int64_t mlir::computeSum(ArrayRef<int64_t> basis) { - assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && - "basis must be nonnegative"); - if (basis.empty()) - return 0; - return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>()); -} - int64_t mlir::computeProduct(ArrayRef<int64_t> basis) { assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && "basis must be nonnegative"); - if (basis.empty()) - return 1; - return std::accumulate(basis.begin(), basis.end(), 1, - std::multiplies<int64_t>()); + return llvm::product_of(basis); } int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) { @@ -158,19 +147,11 @@ SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1, } AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { - if (basis.empty()) - return getAffineConstantExpr(0, ctx); - return std::accumulate(basis.begin(), basis.end(), - getAffineConstantExpr(0, ctx), - std::plus<AffineExpr>()); + return llvm::sum_of(basis, getAffineConstantExpr(0, ctx)); } AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { - if (basis.empty()) - return getAffineConstantExpr(1, ctx); - return std::accumulate(basis.begin(), basis.end(), - getAffineConstantExpr(1, ctx), - std::multiplies<AffineExpr>()); + return llvm::product_of(basis, getAffineConstantExpr(1, ctx)); } AffineExpr 
mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets, diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 7b2734d..6e9118e 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -374,11 +374,11 @@ mlir::composeReassociationIndices( if (consumerReassociations.empty()) return composedIndices; - size_t consumerDims = std::accumulate( - consumerReassociations.begin(), consumerReassociations.end(), 0, - [](size_t all, ReassociationIndicesRef indices) { - return all + indices.size(); - }); + size_t consumerDims = + llvm::accumulate(consumerReassociations, size_t(0), + [](size_t all, ReassociationIndicesRef indices) { + return all + indices.size(); + }); if (producerReassociations.size() != consumerDims) return std::nullopt; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a7e3ba8..58256b0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2496,8 +2496,7 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> { auto srcElems = vector::ToElementsOp::create( rewriter, toElementsOp.getLoc(), bcastOp.getSource()); - int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1, - std::multiplies<int64_t>()); + int64_t dstCount = llvm::product_of(dstShape); SmallVector<Value> replacements; replacements.reserve(dstCount); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index c5f22b2..0eba0b1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> #define DEBUG_TYPE "vector-shape-cast-lowering" @@ 
-166,10 +167,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { const VectorType resultType = shapeCast.getResultVectorType(); const ArrayRef<int64_t> resultShape = resultType.getShape(); - const int64_t nSlices = - std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1, - std::multiplies<int64_t>()); - + const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim)); SmallVector<int64_t> extractIndex(sourceDim, 0); SmallVector<int64_t> insertIndex(resultDim, 0); Value result = ub::PoisonOp::create(rewriter, loc, resultType); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 963b2c8..aa2dd89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" #define DEBUG_TYPE "vector-drop-unit-dim" @@ -557,8 +558,7 @@ struct CastAwayConstantMaskLeadingOneDim // If any of the dropped unit dims has a size of `0`, the entire mask is a // zero mask, else the unit dim has no effect on the mask. 
int64_t flatLeadingSize = - std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, - static_cast<int64_t>(1), std::multiplies<int64_t>()); + llvm::product_of(dimSizes.take_front(dropDim + 1)); SmallVector<int64_t> newDimSizes = {flatLeadingSize}; newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index b72d564..2c56a43 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -52,8 +52,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { // compute sgSize by multiply elements of laneLayout // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1] // e.g. for 1D layout, sgSize = laneLayout[0] - auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1, - std::multiplies<int64_t>()); + int64_t sgSize = llvm::product_of(laneLayout); // Case 1: regular loads/stores auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>(); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 8bcfa46..ce421f4 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FoldInterfaces.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" #include <numeric> @@ -1274,10 +1275,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, return op->emitOpError("'") << attrName << "' attribute cannot have negative elements"; - size_t totalCount = - std::accumulate(sizes.begin(), sizes.end(), 0, - [](unsigned all, int32_t one) { return all + one; }); - + size_t totalCount = llvm::sum_of(sizes, size_t(0)); if (totalCount != expectedCount) return op->emitOpError() << valueGroupName << " count (" << expectedCount diff --git a/mlir/lib/IR/OperationSupport.cpp 
b/mlir/lib/IR/OperationSupport.cpp index 394ac77..2a37f38 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -406,15 +406,13 @@ OperandRangeRange::OperandRangeRange(OperandRange operands, OperandRange OperandRangeRange::join() const { const OwnerT &owner = getBase(); ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second); - return OperandRange(owner.first, - std::accumulate(sizeData.begin(), sizeData.end(), 0)); + return OperandRange(owner.first, llvm::sum_of(sizeData)); } OperandRange OperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second); - uint32_t startIndex = - std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + uint32_t startIndex = llvm::sum_of(sizeData.take_front(index)); return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); } @@ -565,8 +563,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second.getValue()); - uint32_t startIndex = - std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + uint32_t startIndex = llvm::sum_of(sizeData.take_front(index)); return object.first.slice( startIndex, *(sizeData.begin() + index), MutableOperandRange::OperandSegment(index, object.second)); diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index d2d115e..e438631 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -104,8 +104,8 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { if (dims.empty()) return success(); - auto staticDim = std::accumulate( - dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { + auto staticDim = + llvm::accumulate(dims, dims.front(), [](auto fold, auto dim) { return 
ShapedType::isDynamic(dim) ? fold : dim; }); return success(llvm::all_of(dims, [&](auto dim) { diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 33fbd2a..42843ea 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1835,8 +1835,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index, return nullptr; ArrayRef<int32_t> segments = segmentAttr; - unsigned startIndex = - std::accumulate(segments.begin(), segments.begin() + index, 0); + unsigned startIndex = llvm::sum_of(segments.take_front(index)); values = values.slice(startIndex, *std::next(segments.begin(), index)); LDBG() << " * Extracting range[" << startIndex << ", " diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 5a3eb20..845a14f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -922,8 +922,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( assert(opBundleSizes.size() == opBundleTagsAttr.size() && "operand bundles and tags do not match"); - numOpBundleOperands = - std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0)); + numOpBundleOperands = llvm::sum_of(opBundleSizes); assert(numOpBundleOperands <= intrOp->getNumOperands() && "operand bundle operands is more than the number of operands"); diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy index 5ff9271..c8d188a 100644 --- a/mlir/test/Examples/standalone/test.wheel.toy +++ b/mlir/test/Examples/standalone/test.wheel.toy @@ -2,6 +2,7 @@ # than 255 chars when combined with the fact that pip wants to install into a tmp directory buried under # C/Users/ContainerAdministrator/AppData/Local/Temp. 
# UNSUPPORTED: target={{.*(windows).*}} +# REQUIRES: expensive_checks # REQUIRES: non-shared-libs-build # REQUIRES: bindings-python diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index f99c24d..6ff12d6 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -348,6 +348,9 @@ if config.enable_assertions: else: config.available_features.add("noasserts") +if config.expensive_checks: + config.available_features.add("expensive_checks") + def have_host_jit_feature_support(feature_name): mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 1aaf798..91a71af 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -11,6 +11,7 @@ config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@")) config.python_executable = "@Python3_EXECUTABLE@" config.enable_assertions = @ENABLE_ASSERTIONS@ +config.expensive_checks = "@EXPENSIVE_CHECKS@" config.native_target = "@LLVM_NATIVE_ARCH@" config.host_os = "@HOST_OS@" config.host_cc = "@HOST_CC@" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 9690115..daae3c7 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3513,9 +3513,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( body << "(" << operandName << " ? 
1 : 0)"; } else if (operand.isVariadicOfVariadic()) { body << llvm::formatv( - "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, " + "llvm::accumulate({0}, int32_t(0), " "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " - "static_cast<int32_t>(range.size()); }))", + "static_cast<int32_t>(range.size()); })", operandName); } else { body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())"; diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h index dc68822..46c08a0 100644 --- a/orc-rt/include/orc-rt/SPSWrapperFunction.h +++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h @@ -42,12 +42,6 @@ private: static T &&from(T &&Arg) noexcept { return std::forward<T>(Arg); } }; - template <typename T> struct Serializable<T *> { - typedef ExecutorAddr serializable_type; - static ExecutorAddr to(T *Arg) { return ExecutorAddr::fromPtr(Arg); } - static T *from(ExecutorAddr A) { return A.toPtr<T *>(); } - }; - template <> struct Serializable<Error> { typedef SPSSerializableError serializable_type; static SPSSerializableError to(Error Err) { @@ -66,21 +60,6 @@ private: } }; - template <typename T> struct Serializable<Expected<T *>> { - typedef SPSSerializableExpected<ExecutorAddr> serializable_type; - static SPSSerializableExpected<ExecutorAddr> to(Expected<T *> Val) { - return SPSSerializableExpected<ExecutorAddr>( - Val ? Expected<ExecutorAddr>(ExecutorAddr::fromPtr(*Val)) - : Expected<ExecutorAddr>(Val.takeError())); - } - static Expected<T *> from(SPSSerializableExpected<ExecutorAddr> Val) { - if (auto Tmp = Val.toExpected()) - return Tmp->toPtr<T *>(); - else - return Tmp.takeError(); - } - }; - template <typename... Ts> struct DeserializableTuple; template <typename... 
Ts> struct DeserializableTuple<std::tuple<Ts...>> { diff --git a/orc-rt/include/orc-rt/SimplePackedSerialization.h b/orc-rt/include/orc-rt/SimplePackedSerialization.h index f60ccad..0f291c4 100644 --- a/orc-rt/include/orc-rt/SimplePackedSerialization.h +++ b/orc-rt/include/orc-rt/SimplePackedSerialization.h @@ -556,6 +556,26 @@ public: } }; +/// Allow SPSExectorAddr serialization to/from T*. +template <typename T> class SPSSerializationTraits<SPSExecutorAddr, T *> { +public: + static size_t size(T *const &P) { + return SPSArgList<SPSExecutorAddr>::size(ExecutorAddr::fromPtr(P)); + } + + static bool serialize(SPSOutputBuffer &OB, T *const &P) { + return SPSArgList<SPSExecutorAddr>::serialize(OB, ExecutorAddr::fromPtr(P)); + } + + static bool deserialize(SPSInputBuffer &IB, T *&P) { + ExecutorAddr Value; + if (!SPSArgList<SPSExecutorAddr>::deserialize(IB, Value)) + return false; + P = Value.toPtr<T *>(); + return true; + } +}; + /// Helper type for serializing Errors. /// /// llvm::Errors are move-only, and not inspectable except by consuming them. 
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp index ed085f2..81e5755 100644 --- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -192,62 +192,6 @@ TEST(SPSWrapperFunctionUtilsTest, TransparentConversionExpectedFailureCase) { EXPECT_EQ(ErrMsg, "N is not a multiple of 2"); } -static void -round_trip_int_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, - orc_rt_WrapperFunctionReturn Return, - orc_rt_WrapperFunctionBuffer ArgBytes) { - SPSWrapperFunction<SPSExecutorAddr(SPSExecutorAddr)>::handle( - Session, CallCtx, Return, ArgBytes, - [](move_only_function<void(int32_t *)> Return, int32_t *P) { - Return(P); - }); -} - -TEST(SPSWrapperFunctionUtilsTest, TransparentConversionPointers) { - int X = 42; - int *P = nullptr; - SPSWrapperFunction<SPSExecutorAddr(SPSExecutorAddr)>::call( - DirectCaller(nullptr, round_trip_int_pointer_sps_wrapper), - [&](Expected<int32_t *> R) { P = cantFail(std::move(R)); }, &X); - - EXPECT_EQ(P, &X); -} - -TEST(SPSWrapperFunctionUtilsTest, TransparentConversionReferenceArguments) { - int X = 42; - int *P = nullptr; - SPSWrapperFunction<SPSExecutorAddr(SPSExecutorAddr)>::call( - DirectCaller(nullptr, round_trip_int_pointer_sps_wrapper), - [&](Expected<int32_t *> R) { P = cantFail(std::move(R)); }, - static_cast<int *const &>(&X)); - - EXPECT_EQ(P, &X); -} - -static void -expected_int_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, - orc_rt_WrapperFunctionReturn Return, - orc_rt_WrapperFunctionBuffer ArgBytes) { - SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::handle( - Session, CallCtx, Return, ArgBytes, - [](move_only_function<void(Expected<int32_t *>)> Return, int32_t *P) { - Return(P); - }); -} - -TEST(SPSWrapperFunctionUtilsTest, TransparentConversionExpectedPointers) { - int X = 42; - int *P = nullptr; - SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::call( - 
DirectCaller(nullptr, expected_int_pointer_sps_wrapper), - [&](Expected<Expected<int32_t *>> R) { - P = cantFail(cantFail(std::move(R))); - }, - &X); - - EXPECT_EQ(P, &X); -} - template <size_t N> struct SPSOpCounter {}; namespace orc_rt { diff --git a/orc-rt/unittests/SimplePackedSerializationTest.cpp b/orc-rt/unittests/SimplePackedSerializationTest.cpp index c3df499..17f0e9c 100644 --- a/orc-rt/unittests/SimplePackedSerializationTest.cpp +++ b/orc-rt/unittests/SimplePackedSerializationTest.cpp @@ -169,6 +169,12 @@ TEST(SimplePackedSerializationTest, StdOptionalValueSerialization) { blobSerializationRoundTrip<SPSOptional<int64_t>>(Value); } +TEST(SimplePackedSerializationTest, Pointers) { + int X = 42; + int *P = &X; + blobSerializationRoundTrip<SPSExecutorAddr>(P); +} + TEST(SimplePackedSerializationTest, ArgListSerialization) { using BAL = SPSArgList<bool, int32_t, SPSString>; diff --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel index 640fa03..936bc12 100644 --- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel @@ -2917,6 +2917,22 @@ libc_support_library( ) libc_support_library( + name = "__support_math_exp2f16", + hdrs = ["src/__support/math/exp2f16.h"], + deps = [ + ":__support_fputil_except_value_utils", + ":__support_fputil_fma", + ":__support_fputil_multiply_add", + ":__support_fputil_nearest_integer", + ":__support_fputil_polyeval", + ":__support_fputil_rounding_mode", + ":__support_macros_optimization", + ":__support_math_common_constants", + ":__support_math_expxf16_utils", + ], +) + +libc_support_library( name = "__support_math_exp10", hdrs = ["src/__support/math/exp10.h"], deps = [ @@ -3696,7 +3712,7 @@ libc_math_function( libc_math_function( name = "exp2f16", additional_deps = [ - ":__support_math_expxf16_utils", + ":__support_math_exp2f16", ], ) |