about summary refs log tree commit diff
path: root/llvm/lib/CodeGen/ExpandFp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/ExpandFp.cpp')
-rw-r--r-- llvm/lib/CodeGen/ExpandFp.cpp 135
1 files changed, 55 insertions, 80 deletions
diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp
index 9cc6c6a..c500357 100644
--- a/llvm/lib/CodeGen/ExpandFp.cpp
+++ b/llvm/lib/CodeGen/ExpandFp.cpp
@@ -82,7 +82,7 @@ public:
}
static FRemExpander create(IRBuilder<> &B, Type *Ty) {
- assert(canExpandType(Ty));
+ assert(canExpandType(Ty) && "Expected supported floating point type");
// The type to use for the computation of the remainder. This may be
// wider than the input/result type which affects the ...
@@ -356,8 +356,9 @@ Value *FRemExpander::buildFRem(Value *X, Value *Y,
static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
LLVM_DEBUG(dbgs() << "Expanding instruction: " << I << '\n');
- Type *ReturnTy = I.getType();
- assert(FRemExpander::canExpandType(ReturnTy->getScalarType()));
+ Type *Ty = I.getType();
+ assert(FRemExpander::canExpandType(Ty) &&
+ "Expected supported floating point type");
FastMathFlags FMF = I.getFastMathFlags();
// TODO Make use of those flags for optimization?
@@ -368,32 +369,10 @@ static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
B.setFastMathFlags(FMF);
B.SetCurrentDebugLocation(I.getDebugLoc());
- Type *ElemTy = ReturnTy->getScalarType();
- const FRemExpander Expander = FRemExpander::create(B, ElemTy);
-
- Value *Ret;
- if (ReturnTy->isFloatingPointTy())
- Ret = FMF.approxFunc()
- ? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
- : Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
- else {
- auto *VecTy = cast<FixedVectorType>(ReturnTy);
-
- // This could use SplitBlockAndInsertForEachLane but the interface
- // is a bit awkward for a constant number of elements and it will
- // boil down to the same code.
- // TODO Expand the FRem instruction only once and reuse the code.
- Value *Nums = I.getOperand(0);
- Value *Denums = I.getOperand(1);
- Ret = PoisonValue::get(I.getType());
- for (int I = 0, E = VecTy->getNumElements(); I != E; ++I) {
- Value *Num = B.CreateExtractElement(Nums, I);
- Value *Denum = B.CreateExtractElement(Denums, I);
- Value *Rem = FMF.approxFunc() ? Expander.buildApproxFRem(Num, Denum)
- : Expander.buildFRem(Num, Denum, SQ);
- Ret = B.CreateInsertElement(Ret, Rem, I);
- }
- }
+ const FRemExpander Expander = FRemExpander::create(B, Ty);
+ Value *Ret = FMF.approxFunc()
+ ? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
+ : Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
I.replaceAllUsesWith(Ret);
Ret->takeName(&I);
@@ -939,7 +918,8 @@ static void expandIToFP(Instruction *IToFP) {
IToFP->eraseFromParent();
}
-static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
+static void scalarize(Instruction *I,
+ SmallVectorImpl<Instruction *> &Worklist) {
VectorType *VTy = cast<FixedVectorType>(I->getType());
IRBuilder<> Builder(I);
@@ -948,12 +928,25 @@ static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
Value *Result = PoisonValue::get(VTy);
for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
Value *Ext = Builder.CreateExtractElement(I->getOperand(0), Idx);
- Value *Cast = Builder.CreateCast(cast<CastInst>(I)->getOpcode(), Ext,
- I->getType()->getScalarType());
- Result = Builder.CreateInsertElement(Result, Cast, Idx);
- if (isa<Instruction>(Cast))
- Replace.push_back(cast<Instruction>(Cast));
+
+ Value *NewOp = nullptr;
+ if (auto *BinOp = dyn_cast<BinaryOperator>(I))
+ NewOp = Builder.CreateBinOp(
+ BinOp->getOpcode(), Ext,
+ Builder.CreateExtractElement(I->getOperand(1), Idx));
+ else if (auto *CastI = dyn_cast<CastInst>(I))
+ NewOp = Builder.CreateCast(CastI->getOpcode(), Ext,
+ I->getType()->getScalarType());
+ else
+ llvm_unreachable("Unsupported instruction type");
+
+ Result = Builder.CreateInsertElement(Result, NewOp, Idx);
+ if (auto *ScalarizedI = dyn_cast<Instruction>(NewOp)) {
+ ScalarizedI->copyIRFlags(I, true);
+ Worklist.push_back(ScalarizedI);
+ }
}
+
I->replaceAllUsesWith(Result);
I->dropAllReferences();
I->eraseFromParent();
@@ -989,10 +982,17 @@ static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) {
return TLI.getLibcallName(fremToLibcall(Ty->getScalarType()));
}
+/// Queue \p I for expansion.
+///
+/// If the first operand of \p I has a vector type, \p I is handed to
+/// scalarize(), which replaces and erases \p I and pushes the per-element
+/// instructions it creates onto \p Worklist; \p I must not be used after
+/// this call in that case. Otherwise \p I itself is appended to \p Worklist.
+///
+/// \param I        Instruction that needs to be expanded.
+/// \param Worklist Receives the scalar instructions to expand. Taken as
+///                 SmallVectorImpl, matching scalarize(), so the interface
+///                 does not depend on the caller's inline element count.
+static void addToWorklist(Instruction &I,
+                          SmallVectorImpl<Instruction *> &Worklist) {
+  // The operand type (not the result type) decides whether to scalarize.
+  if (I.getOperand(0)->getType()->isVectorTy())
+    scalarize(&I, Worklist);
+  else
+    Worklist.push_back(&I);
+}
+
static bool runImpl(Function &F, const TargetLowering &TLI,
AssumptionCache *AC) {
- SmallVector<Instruction *, 4> Replace;
- SmallVector<Instruction *, 4> ReplaceVector;
+ SmallVector<Instruction *, 4> Worklist;
bool Modified = false;
unsigned MaxLegalFpConvertBitWidth =
@@ -1003,56 +1003,39 @@ static bool runImpl(Function &F, const TargetLowering &TLI,
if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS)
return false;
- for (auto &I : instructions(F)) {
- switch (I.getOpcode()) {
- case Instruction::FRem: {
- Type *Ty = I.getType();
- // TODO: This pass doesn't handle scalable vectors.
- if (Ty->isScalableTy())
- continue;
-
- if (targetSupportsFrem(TLI, Ty) ||
- !FRemExpander::canExpandType(Ty->getScalarType()))
- continue;
-
- Replace.push_back(&I);
- Modified = true;
+ for (auto It = inst_begin(&F), End = inst_end(F); It != End;) {
+ Instruction &I = *It++;
+ Type *Ty = I.getType();
+ // TODO: This pass doesn't handle scalable vectors.
+ if (Ty->isScalableTy())
+ continue;
+ switch (I.getOpcode()) {
+ case Instruction::FRem:
+ if (!targetSupportsFrem(TLI, Ty) &&
+ FRemExpander::canExpandType(Ty->getScalarType())) {
+ addToWorklist(I, Worklist);
+ Modified = true;
+ }
break;
- }
case Instruction::FPToUI:
case Instruction::FPToSI: {
- // TODO: This pass doesn't handle scalable vectors.
- if (I.getOperand(0)->getType()->isScalableTy())
- continue;
-
- auto *IntTy = cast<IntegerType>(I.getType()->getScalarType());
+ auto *IntTy = cast<IntegerType>(Ty->getScalarType());
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
continue;
- if (I.getOperand(0)->getType()->isVectorTy())
- ReplaceVector.push_back(&I);
- else
- Replace.push_back(&I);
+ addToWorklist(I, Worklist);
Modified = true;
break;
}
case Instruction::UIToFP:
case Instruction::SIToFP: {
- // TODO: This pass doesn't handle scalable vectors.
- if (I.getOperand(0)->getType()->isScalableTy())
- continue;
-
auto *IntTy =
cast<IntegerType>(I.getOperand(0)->getType()->getScalarType());
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
continue;
- if (I.getOperand(0)->getType()->isVectorTy())
- ReplaceVector.push_back(&I);
- else
- Replace.push_back(&I);
- Modified = true;
+ addToWorklist(I, Worklist);
break;
}
default:
@@ -1060,16 +1043,8 @@ static bool runImpl(Function &F, const TargetLowering &TLI,
}
}
- while (!ReplaceVector.empty()) {
- Instruction *I = ReplaceVector.pop_back_val();
- scalarize(I, Replace);
- }
-
- if (Replace.empty())
- return false;
-
- while (!Replace.empty()) {
- Instruction *I = Replace.pop_back_val();
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
if (I->getOpcode() == Instruction::FRem) {
auto SQ = [&]() -> std::optional<SimplifyQuery> {
if (AC) {