aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/ConstantFold.cpp
diff options
context:
space:
mode:
authorPaul Walker <paul.walker@arm.com>2024-11-13 12:33:56 +0000
committerGitHub <noreply@github.com>2024-11-13 12:33:56 +0000
commit97298853b4de70dbce9c0a140ac38e3ac179e02e (patch)
tree34f75e1e64d0b2282b82bf14d67697a4a6c8c4bb /llvm/lib/IR/ConstantFold.cpp
parent8ae2a18736c15e0d0d9d0893b21bce4f3bf581c9 (diff)
downloadllvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.zip
llvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.tar.gz
llvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.tar.bz2
[LLVM][IR] Teach constant integer binop folds about vector ConstantInts. (#115739)
The existing logic mostly works with the main changes being: * Use getScalarSizeInBits instead of IntegerType::getBitWidth * Use ConstantInt::get(Type* instead of ConstantInt::get(LLVMContext
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r--llvm/lib/IR/ConstantFold.cpp50
1 files changed, 23 insertions, 27 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index cfe8793..2dbc678 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return nullptr;
case Instruction::ZExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().zext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth));
}
return nullptr;
case Instruction::SExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().sext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth));
}
return nullptr;
case Instruction::Trunc: {
- if (V->getType()->isVectorTy())
- return nullptr;
-
- uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- return ConstantInt::get(V->getContext(),
- CI->getValue().trunc(DestBitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth));
}
return nullptr;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
default:
break;
case Instruction::Add:
- return ConstantInt::get(CI1->getContext(), C1V + C2V);
+ return ConstantInt::get(C1->getType(), C1V + C2V);
case Instruction::Sub:
- return ConstantInt::get(CI1->getContext(), C1V - C2V);
+ return ConstantInt::get(C1->getType(), C1V - C2V);
case Instruction::Mul:
- return ConstantInt::get(CI1->getContext(), C1V * C2V);
+ return ConstantInt::get(C1->getType(), C1V * C2V);
case Instruction::UDiv:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.udiv(C2V));
case Instruction::SDiv:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V));
case Instruction::URem:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.urem(C2V));
+ return ConstantInt::get(C1->getType(), C1V.urem(C2V));
case Instruction::SRem:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
- return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.srem(C2V));
+ return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison
+ return ConstantInt::get(C1->getType(), C1V.srem(C2V));
case Instruction::And:
- return ConstantInt::get(CI1->getContext(), C1V & C2V);
+ return ConstantInt::get(C1->getType(), C1V & C2V);
case Instruction::Or:
- return ConstantInt::get(CI1->getContext(), C1V | C2V);
+ return ConstantInt::get(C1->getType(), C1V | C2V);
case Instruction::Xor:
- return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
+ return ConstantInt::get(C1->getType(), C1V ^ C2V);
case Instruction::Shl:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+ return ConstantInt::get(C1->getType(), C1V.shl(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::LShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.lshr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::AShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.ashr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
}
}
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
return ConstantFP::get(C1->getContext(), C3V);
}
}
- } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
+ }
+
+ if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
// Fast path for splatted constants.
if (Constant *C2Splat = C2->getSplatValue()) {
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())