diff options
-rw-r--r-- | llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h | 33 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/CallPromotionUtils.cpp | 32 | ||||
-rw-r--r-- | llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp | 88 |
3 files changed, 143 insertions, 10 deletions
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h index fcb384e..385831f 100644 --- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h +++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -15,9 +15,12 @@ #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H namespace llvm { +template <typename T> class ArrayRef; +class Constant; class CallBase; class CastInst; class Function; +class Instruction; class MDNode; class Value; @@ -41,7 +44,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee, CallBase &promoteCall(CallBase &CB, Function *Callee, CastInst **RetBitCast = nullptr); -/// Promote the given indirect call site to conditionally call \p Callee. +/// Promote the given indirect call site to conditionally call \p Callee. The +/// promoted direct call instruction is predicated on `CB.getCalledOperand() == +/// Callee`. /// /// This function creates an if-then-else structure at the location of the call /// site. The original call site is moved into the "else" block. A clone of the @@ -51,6 +56,22 @@ CallBase &promoteCall(CallBase &CB, Function *Callee, CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, MDNode *BranchWeights = nullptr); +/// This is similar to `promoteCallWithIfThenElse` except that the condition to +/// promote a virtual call is that \p VPtr is the same as any of \p +/// AddressPoints. +/// +/// This function is expected to be used on virtual calls (a subset of indirect +/// calls). \p VPtr is the virtual table address stored in the objects, and +/// \p AddressPoints contains vtable address points. A vtable address point is +/// a location inside the vtable that's referenced by vpointer in C++ objects. +/// +/// TODO: sink the address-calculation instructions of indirect callee to the +/// indirect call fallback after transformation. +CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, + Function *Callee, + ArrayRef<Constant *> AddressPoints, + MDNode *BranchWeights); + /// Try to promote (devirtualize) a virtual call on an Alloca. Return true on /// success. /// @@ -76,11 +97,11 @@ bool tryPromoteCall(CallBase &CB); /// Predicate and clone the given call site. /// -/// This function creates an if-then-else structure at the location of the call -/// site. The "if" condition compares the call site's called value to the given -/// callee. The original call site is moved into the "else" block, and a clone -/// of the call site is placed in the "then" block. The cloned instruction is -/// returned. +/// This function creates an if-then-else structure at the location of the +/// call site. The "if" condition compares the call site's called value to +/// the given callee. The original call site is moved into the "else" block, +/// and a clone of the call site is placed in the "then" block. The cloned +/// instruction is returned. CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights); } // end namespace llvm diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 9ca9aaf..dda80d4 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -12,9 +12,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -188,9 +190,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// Predicate and clone the given call site. /// /// This function creates an if-then-else structure at the location of the call -/// site. The "if" condition is specified by `Cond`. The original call site is -/// moved into the "else" block, and a clone of the call site is placed in the -/// "then" block. The cloned instruction is returned. +/// site. The "if" condition is specified by `Cond`. +/// The original call site is moved into the "else" block, and a clone of the +/// call site is placed in the "then" block. The cloned instruction is returned. /// /// For example, the call instruction below: /// @@ -518,7 +520,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee, Type *FormalTy = CalleeType->getParamType(ArgNo); Type *ActualTy = Arg->getType(); if (FormalTy != ActualTy) { - auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator()); + auto *Cast = + CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator()); CB.setArgOperand(ArgNo, Cast); // Remove any incompatible attributes for the argument. @@ -568,6 +571,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee, return promoteCall(NewInst, Callee); } +CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, + Function *Callee, + ArrayRef<Constant *> AddressPoints, + MDNode *BranchWeights) { + assert(!AddressPoints.empty() && "Caller should guarantee"); + IRBuilder<> Builder(&CB); + SmallVector<Value *, 2> ICmps; + for (auto &AddressPoint : AddressPoints) + ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint)); + + // TODO: Perform tree height reduction if the number of ICmps is high. + Value *Cond = Builder.CreateOr(ICmps); + + // Version the indirect call site. If Cond is true, 'NewInst' will be + // executed, otherwise the original call site will be executed. + CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights); + + // Promote 'NewInst' so that it directly calls the desired function. + return promoteCall(NewInst, Callee); +} + bool llvm::tryPromoteCall(CallBase &CB) { assert(!CB.getCalledFunction()); Module *M = CB.getCaller()->getParent(); diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp index 0e9641c..2d457eb 100644 --- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp @@ -8,9 +8,12 @@ #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/NoFolder.h" #include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" @@ -24,6 +27,21 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) { return Mod; } +// Returns a constant representing the vtable's address point specified by the +// offset. +static Constant *getVTableAddressPointOffset(GlobalVariable *VTable, + uint32_t AddressPointOffset) { + Module &M = *VTable->getParent(); + LLVMContext &Context = M.getContext(); + assert(AddressPointOffset < + M.getDataLayout().getTypeAllocSize(VTable->getValueType()) && + "Out-of-bound access"); + + return ConstantExpr::getInBoundsGetElementPtr( + Type::getInt8Ty(Context), VTable, + llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset)); +} + TEST(CallPromotionUtilsTest, TryPromoteCall) { LLVMContext C; std::unique_ptr<Module> M = parseIR(C, @@ -368,3 +386,73 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this) bool IsPromoted = tryPromoteCall(*CI); EXPECT_FALSE(IsPromoted); } + +TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) { + LLVMContext C; + std::unique_ptr<Module> M = parseIR(C, + R"IR( +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0 +@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2 +@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6 + +define i32 @testfunc(ptr %d) { +entry: + %vtable = load ptr, ptr %d, !prof !7 + %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1 + %0 = load ptr, ptr %vfn + %call = tail call i32 %0(ptr %d), !prof !8 + ret i32 %call +} + +define i32 @_ZN5Base15func1Ev(ptr %this) { +entry: + ret i32 2 +} + +declare i32 @_ZN5Base25func2Ev(ptr) +declare i32 @_ZN5Base15func0Ev(ptr) +declare void @_ZN5Base35func3Ev(ptr) + +!0 = !{i64 16, !"_ZTS5Base1"} +!1 = !{i64 48, !"_ZTS5Base2"} +!2 = !{i64 16, !"_ZTS8Derived1"} +!3 = !{i64 64, !"_ZTS5Base1"} +!4 = !{i64 40, !"_ZTS5Base2"} +!5 = !{i64 16, !"_ZTS5Base3"} +!6 = !{i64 16, !"_ZTS8Derived2"} +!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300} +!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR"); + + Function *F = M->getFunction("testfunc"); + CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin())); + ASSERT_TRUE(CI && CI->isIndirectCall()); + + // Create the constant and the branch weights + SmallVector<Constant *, 3> VTableAddressPoints; + + for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16}, + {"_ZTV8Derived1", 16}, + {"_ZTV8Derived2", 64}}) + VTableAddressPoints.push_back(getVTableAddressPointOffset( + M->getGlobalVariable(VTableName), AddressPointOffset)); + + MDBuilder MDB(C); + MDNode *BranchWeights = MDB.createBranchWeights(1600, 0); + + size_t OrigEntryBBSize = F->front().size(); + + LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin()); + + Function *Callee = M->getFunction("_ZN5Base15func1Ev"); + // Tests that promoted direct call is returned. + CallBase &DirectCB = promoteCallWithVTableCmp( + *CI, VPtr, Callee, VTableAddressPoints, BranchWeights); + EXPECT_EQ(DirectCB.getCalledOperand(), Callee); + + // Promotion inserts 3 icmp instructions and 2 or instructions, and removes + // 1 call instruction from the entry block. + EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4); +} |