aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/BasicBlockUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/BasicBlockUtils.cpp21
1 files changed, 20 insertions, 1 deletions
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index c03fbe9..9876681 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -1545,8 +1545,14 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond,
void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
Instruction **ThenTerm,
Instruction **ElseTerm,
- MDNode *BranchWeights) {
+ MDNode *BranchWeights,
+ DomTreeUpdater *DTU) {
BasicBlock *Head = SplitBefore->getParent();
+
+ SmallPtrSet<BasicBlock *, 8> UniqueOrigSuccessors;
+ if (DTU)
+ UniqueOrigSuccessors.insert(succ_begin(Head), succ_end(Head));
+
BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
Instruction *HeadOldTerm = Head->getTerminator();
LLVMContext &C = Head->getContext();
@@ -1560,6 +1566,19 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/ElseBlock, Cond);
HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
+ if (DTU) {
+ SmallVector<DominatorTree::UpdateType, 8> Updates;
+ Updates.reserve(4 + 2 * UniqueOrigSuccessors.size());
+ for (BasicBlock *Succ : successors(Head)) {
+ Updates.push_back({DominatorTree::Insert, Head, Succ});
+ Updates.push_back({DominatorTree::Insert, Succ, Tail});
+ }
+ for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
+ Updates.push_back({DominatorTree::Insert, Tail, UniqueOrigSuccessor});
+ for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
+ Updates.push_back({DominatorTree::Delete, Head, UniqueOrigSuccessor});
+ DTU->applyUpdates(Updates);
+ }
}
BranchInst *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue,