diff options
| author | Peter Rong <PeterRong96@gmail.com> | 2023-01-12 10:58:38 -0800 |
|---|---|---|
| committer | Peter Rong <PeterRong96@gmail.com> | 2023-01-24 20:22:06 -0800 |
| commit | 9b70a28e0d767f99bdc778356e81b4d072f59819 (patch) | |
| tree | f211eef10e0e1de8c615dc4a89cdeefeec93e8ee /llvm/lib/Transforms/Utils/LowerSwitch.cpp | |
| parent | f9599bbc7a3f831e1793a549d8a7a19265f3e504 (diff) | |
| download | llvm-9b70a28e0d767f99bdc778356e81b4d072f59819.zip llvm-9b70a28e0d767f99bdc778356e81b4d072f59819.tar.gz llvm-9b70a28e0d767f99bdc778356e81b4d072f59819.tar.bz2 | |
[Transform] Rewrite LowerSwitch using APInt
This rewrite fixes https://github.com/llvm/llvm-project/issues/59316.
Previously LowerSwitch uses int64_t, which will crash on case branches using integers with more than 64 bits.
Using APInt fixes this problem. This patch also includes a test
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D140747
Diffstat (limited to 'llvm/lib/Transforms/Utils/LowerSwitch.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Utils/LowerSwitch.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 26aebdf..227de42 100644 --- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -370,7 +370,9 @@ void ProcessSwitchInst(SwitchInst *SI, const unsigned NumSimpleCases = Clusterify(Cases, SI); IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType()); const unsigned BitWidth = IT->getBitWidth(); - APInt SignedZero(BitWidth, 0); + // Explictly use higher precision to prevent unsigned overflow where + // `UnsignedMax - 0 + 1 == 0` + APInt UnsignedZero(BitWidth + 1, 0); APInt UnsignedMax = APInt::getMaxValue(BitWidth); LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() << ". Total non-default cases: " << NumSimpleCases @@ -431,7 +433,7 @@ void ProcessSwitchInst(SwitchInst *SI, if (DefaultIsUnreachableFromSwitch) { DenseMap<BasicBlock *, APInt> Popularity; - APInt MaxPop(SignedZero); + APInt MaxPop(UnsignedZero); BasicBlock *PopSucc = nullptr; APInt SignedMax = APInt::getSignedMaxValue(BitWidth); @@ -457,11 +459,11 @@ void ProcessSwitchInst(SwitchInst *SI, } // Count popularity. - APInt N = High - Low + 1; - assert(N.sge(SignedZero) && "Popularity shouldn't be negative."); + assert(High.sge(Low) && "Popularity shouldn't be negative."); + APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1; // Explict insert to make sure the bitwidth of APInts match - APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second; - if ((Pop += N).sgt(MaxPop)) { + APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second; + if ((Pop += N).ugt(MaxPop)) { MaxPop = Pop; PopSucc = I.BB; } @@ -486,8 +488,6 @@ void ProcessSwitchInst(SwitchInst *SI, // Use the most popular block as the new default, reducing the number of // cases. - assert(MaxPop.sgt(SignedZero) && PopSucc && - "Max populartion shouldn't be negative."); Default = PopSucc; llvm::erase_if(Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }); @@ -498,8 +498,9 @@ void ProcessSwitchInst(SwitchInst *SI, SI->eraseFromParent(); // As all the cases have been replaced with a single branch, only keep // one entry in the PHI nodes. - for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I) - PopSucc->removePredecessor(OrigBlock); + if (!MaxPop.isZero()) + for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I) + PopSucc->removePredecessor(OrigBlock); return; } |
