diff options
author | Benjamin Kramer <benny.kra@googlemail.com> | 2024-02-22 15:25:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-22 15:25:17 +0100 |
commit | d3f6dd6585f4866a38a794b80db55a62c1050c77 (patch) | |
tree | 87503c7f9414663231db4d1966fdda1dfc9a814e /llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | |
parent | 88e31f64a034ec6dead2106016ee5b797674edb0 (diff) | |
download | llvm-d3f6dd6585f4866a38a794b80db55a62c1050c77.zip llvm-d3f6dd6585f4866a38a794b80db55a62c1050c77.tar.gz llvm-d3f6dd6585f4866a38a794b80db55a62c1050c77.tar.bz2 |
[InstCombine] Pick bfloat over half when shrinking ops that started with an fpext from bfloat (#82493)
This fixes the case where we would shrink an frem to half and then
bitcast to bfloat, producing invalid results. The transformation was
written under the assumption that there is only one type with a given
bit width.
Also add a strategic assert to CastInst::CreateFPCast to turn this
miscompilation into a crash.
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index ed47de28..33ed1d5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1543,11 +1543,14 @@ static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { return !losesInfo; } -static Type *shrinkFPConstant(ConstantFP *CFP) { +static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) return nullptr; // No constant folding of this. + // See if the value can be truncated to bfloat and then reextended. + if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat())) + return Type::getBFloatTy(CFP->getContext()); // See if the value can be truncated to half and then reextended. - if (fitsInFPType(CFP, APFloat::IEEEhalf())) + if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf())) return Type::getHalfTy(CFP->getContext()); // See if the value can be truncated to float and then reextended. if (fitsInFPType(CFP, APFloat::IEEEsingle())) @@ -1562,7 +1565,7 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. -static Type *shrinkFPConstantVector(Value *V) { +static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) { auto *CV = dyn_cast<Constant>(V); auto *CVVTy = dyn_cast<FixedVectorType>(V->getType()); if (!CV || !CVVTy) @@ -1582,7 +1585,7 @@ static Type *shrinkFPConstantVector(Value *V) { if (!CFP) return nullptr; - Type *T = shrinkFPConstant(CFP); + Type *T = shrinkFPConstant(CFP, PreferBFloat); if (!T) return nullptr; @@ -1597,7 +1600,7 @@ static Type *shrinkFPConstantVector(Value *V) { } /// Find the minimum FP type we can safely truncate to. -static Type *getMinimumFPType(Value *V) { +static Type *getMinimumFPType(Value *V, bool PreferBFloat) { if (auto *FPExt = dyn_cast<FPExtInst>(V)) return FPExt->getOperand(0)->getType(); @@ -1605,7 +1608,7 @@ static Type *getMinimumFPType(Value *V) { // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. if (auto *CFP = dyn_cast<ConstantFP>(V)) - if (Type *T = shrinkFPConstant(CFP)) + if (Type *T = shrinkFPConstant(CFP, PreferBFloat)) return T; // We can only correctly find a minimum type for a scalable vector when it is @@ -1617,7 +1620,7 @@ static Type *getMinimumFPType(Value *V) { // Try to shrink a vector of FP constants. This returns nullptr on scalable // vectors - if (Type *T = shrinkFPConstantVector(V)) + if (Type *T = shrinkFPConstantVector(V, PreferBFloat)) return T; return V->getType(); @@ -1686,8 +1689,10 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Type *Ty = FPT.getType(); auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (BO && BO->hasOneUse()) { - Type *LHSMinType = getMinimumFPType(BO->getOperand(0)); - Type *RHSMinType = getMinimumFPType(BO->getOperand(1)); + Type *LHSMinType = + getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy()); + Type *RHSMinType = + getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy()); unsigned OpWidth = BO->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); |