diff options
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantFold.cpp | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index cfe8793..2dbc678 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, return nullptr; case Instruction::ZExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().zext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth)); } return nullptr; case Instruction::SExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().sext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth)); } return nullptr; case Instruction::Trunc: { - if (V->getType()->isVectorTy()) - return nullptr; - - uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - return ConstantInt::get(V->getContext(), - CI->getValue().trunc(DestBitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth)); } return nullptr; @@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, default: break; case Instruction::Add: - return ConstantInt::get(CI1->getContext(), C1V + C2V); + return ConstantInt::get(C1->getType(), C1V + C2V); case Instruction::Sub: - return ConstantInt::get(CI1->getContext(), C1V - C2V); + return ConstantInt::get(C1->getType(), C1V - C2V); case Instruction::Mul: - return ConstantInt::get(CI1->getContext(), C1V * C2V); + return ConstantInt::get(C1->getType(), C1V * C2V); case Instruction::UDiv: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.udiv(C2V)); case Instruction::SDiv: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V)); case Instruction::URem: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.urem(C2V)); + return ConstantInt::get(C1->getType(), C1V.urem(C2V)); case Instruction::SRem: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) - return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.srem(C2V)); + return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison + return ConstantInt::get(C1->getType(), C1V.srem(C2V)); case Instruction::And: - return ConstantInt::get(CI1->getContext(), C1V & C2V); + return ConstantInt::get(C1->getType(), C1V & C2V); case Instruction::Or: - return ConstantInt::get(CI1->getContext(), C1V | C2V); + return ConstantInt::get(C1->getType(), C1V | C2V); case Instruction::Xor: - return ConstantInt::get(CI1->getContext(), C1V ^ C2V); + return ConstantInt::get(C1->getType(), C1V ^ C2V); case Instruction::Shl: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.shl(C2V)); + return ConstantInt::get(C1->getType(), C1V.shl(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::LShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V)); + return ConstantInt::get(C1->getType(), C1V.lshr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::AShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V)); + return ConstantInt::get(C1->getType(), C1V.ashr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison } } @@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, return ConstantFP::get(C1->getContext(), C3V); } } - } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) { + } + + if (auto *VTy = dyn_cast<VectorType>(C1->getType())) { // Fast path for splatted constants. if (Constant *C2Splat = C2->getSplatValue()) { if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) |