diff options
author | Paul Walker <paul.walker@arm.com> | 2024-11-13 12:33:56 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-13 12:33:56 +0000 |
commit | 97298853b4de70dbce9c0a140ac38e3ac179e02e (patch) | |
tree | 34f75e1e64d0b2282b82bf14d67697a4a6c8c4bb /llvm/lib/IR/ConstantFold.cpp | |
parent | 8ae2a18736c15e0d0d9d0893b21bce4f3bf581c9 (diff) | |
download | llvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.zip llvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.tar.gz llvm-97298853b4de70dbce9c0a140ac38e3ac179e02e.tar.bz2 |
[LLVM][IR] Teach constant integer binop folds about vector ConstantInts. (#115739)
The existing logic mostly works with the main changes being:
* Use getScalarSizeInBits instead of IntegerType::getBitWidth
* Use ConstantInt::get(Type* instead of ConstantInt::get(LLVMContext
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantFold.cpp | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index cfe8793..2dbc678 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, return nullptr; case Instruction::ZExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().zext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth)); } return nullptr; case Instruction::SExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().sext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth)); } return nullptr; case Instruction::Trunc: { - if (V->getType()->isVectorTy()) - return nullptr; - - uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - return ConstantInt::get(V->getContext(), - CI->getValue().trunc(DestBitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth)); } return nullptr; @@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, default: break; case Instruction::Add: - return ConstantInt::get(CI1->getContext(), C1V + C2V); + return ConstantInt::get(C1->getType(), C1V + C2V); case Instruction::Sub: - return ConstantInt::get(CI1->getContext(), C1V - C2V); + return ConstantInt::get(C1->getType(), C1V - C2V); case Instruction::Mul: - return ConstantInt::get(CI1->getContext(), C1V * C2V); + return ConstantInt::get(C1->getType(), C1V * C2V); case Instruction::UDiv: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.udiv(C2V)); case Instruction::SDiv: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V)); case Instruction::URem: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.urem(C2V)); + return ConstantInt::get(C1->getType(), C1V.urem(C2V)); case Instruction::SRem: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) - return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.srem(C2V)); + return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison + return ConstantInt::get(C1->getType(), C1V.srem(C2V)); case Instruction::And: - return ConstantInt::get(CI1->getContext(), C1V & C2V); + return ConstantInt::get(C1->getType(), C1V & C2V); case Instruction::Or: - return ConstantInt::get(CI1->getContext(), C1V | C2V); + return ConstantInt::get(C1->getType(), C1V | C2V); case Instruction::Xor: - return ConstantInt::get(CI1->getContext(), C1V ^ C2V); + return ConstantInt::get(C1->getType(), C1V ^ C2V); case Instruction::Shl: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.shl(C2V)); + return ConstantInt::get(C1->getType(), C1V.shl(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::LShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V)); + return ConstantInt::get(C1->getType(), C1V.lshr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::AShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V)); + return ConstantInt::get(C1->getType(), C1V.ashr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison } } @@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, return ConstantFP::get(C1->getContext(), C3V); } } - } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) { + } + + if (auto *VTy = dyn_cast<VectorType>(C1->getType())) { // Fast path for splatted constants. if (Constant *C2Splat = C2->getSplatValue()) { if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) |