about | summary | refs | log | tree | commit | diff
path: root/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp')
-rw-r--r-- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp 79
1 files changed, 52 insertions, 27 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 531297b..3672a91 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7175,6 +7175,45 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
}
}
+ // Handle fshl/fshr special cases.
+ if (Opcode == ISD::FSHL || Opcode == ISD::FSHR) {
+ auto *C1 = dyn_cast<ConstantSDNode>(Ops[0]);
+ auto *C2 = dyn_cast<ConstantSDNode>(Ops[1]);
+ auto *C3 = dyn_cast<ConstantSDNode>(Ops[2]);
+
+ if (C1 && C2 && C3) {
+ if (C1->isOpaque() || C2->isOpaque() || C3->isOpaque())
+ return SDValue();
+ const APInt &V1 = C1->getAPIntValue(), &V2 = C2->getAPIntValue(),
+ &V3 = C3->getAPIntValue();
+
+ APInt FoldedVal = Opcode == ISD::FSHL ? APIntOps::fshl(V1, V2, V3)
+ : APIntOps::fshr(V1, V2, V3);
+ return getConstant(FoldedVal, DL, VT);
+ }
+ }
+
+ // Handle fma/fmad special cases.
+ if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
+ assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
+ assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
+ Ops[2].getValueType() == VT && "FMA types must match!");
+ ConstantFPSDNode *C1 = dyn_cast<ConstantFPSDNode>(Ops[0]);
+ ConstantFPSDNode *C2 = dyn_cast<ConstantFPSDNode>(Ops[1]);
+ ConstantFPSDNode *C3 = dyn_cast<ConstantFPSDNode>(Ops[2]);
+ if (C1 && C2 && C3) {
+ APFloat V1 = C1->getValueAPF();
+ const APFloat &V2 = C2->getValueAPF();
+ const APFloat &V3 = C3->getValueAPF();
+ if (Opcode == ISD::FMAD) {
+ V1.multiply(V2, APFloat::rmNearestTiesToEven);
+ V1.add(V3, APFloat::rmNearestTiesToEven);
+ } else
+ V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
+ return getConstantFP(V1, DL, VT);
+ }
+ }
+
// This is for vector folding only from here on.
if (!VT.isVector())
return SDValue();
@@ -8137,27 +8176,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
"Operand is DELETED_NODE!");
// Perform various simplifications.
switch (Opcode) {
- case ISD::FMA:
- case ISD::FMAD: {
- assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
- assert(N1.getValueType() == VT && N2.getValueType() == VT &&
- N3.getValueType() == VT && "FMA types must match!");
- ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
- ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
- ConstantFPSDNode *N3CFP = dyn_cast<ConstantFPSDNode>(N3);
- if (N1CFP && N2CFP && N3CFP) {
- APFloat V1 = N1CFP->getValueAPF();
- const APFloat &V2 = N2CFP->getValueAPF();
- const APFloat &V3 = N3CFP->getValueAPF();
- if (Opcode == ISD::FMAD) {
- V1.multiply(V2, APFloat::rmNearestTiesToEven);
- V1.add(V3, APFloat::rmNearestTiesToEven);
- } else
- V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
- return getConstantFP(V1, DL, VT);
- }
- break;
- }
case ISD::BUILD_VECTOR: {
// Attempt to simplify BUILD_VECTOR.
SDValue Ops[] = {N1, N2, N3};
@@ -8183,12 +8201,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
// Use FoldSetCC to simplify SETCC's.
if (SDValue V = FoldSetCC(VT, N1, N2, cast<CondCodeSDNode>(N3)->get(), DL))
return V;
- // Vector constant folding.
- SDValue Ops[] = {N1, N2, N3};
- if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, Ops)) {
- NewSDValueDbgMsg(V, "New node vector constant folding: ", this);
- return V;
- }
break;
}
case ISD::SELECT:
@@ -8324,6 +8336,19 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
}
+ // Perform trivial constant folding for arithmetic operators.
+ switch (Opcode) {
+ case ISD::FMA:
+ case ISD::FMAD:
+ case ISD::SETCC:
+ case ISD::FSHL:
+ case ISD::FSHR:
+ if (SDValue SV =
+ FoldConstantArithmetic(Opcode, DL, VT, {N1, N2, N3}, Flags))
+ return SV;
+ break;
+ }
+
// Memoize node if it doesn't produce a glue result.
SDNode *N;
SDVTList VTs = getVTList(VT);