diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index f0813f1..5878cda 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -14835,6 +14835,59 @@ static SDValue performAddUADDVCombine(SDNode *N, SelectionDAG &DAG) { DAG.getConstant(0, DL, MVT::i64)); } +/// Perform the scalar expression combine in the form of: +/// CSEL (c, 1, cc) + b => CSINC(b+c, b, cc) +static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD) + return SDValue(); + + SDValue CSel = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + // Handle commutivity. + if (CSel.getOpcode() != AArch64ISD::CSEL) { + std::swap(CSel, RHS); + if (CSel.getOpcode() != AArch64ISD::CSEL) { + return SDValue(); + } + } + + if (!CSel.hasOneUse()) + return SDValue(); + + AArch64CC::CondCode AArch64CC = + static_cast<AArch64CC::CondCode>(CSel.getConstantOperandVal(2)); + + // The CSEL should include a const one operand. + ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(CSel.getOperand(0)); + ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(CSel.getOperand(1)); + if (!CTVal || !CFVal || (!CTVal->isOne() && !CFVal->isOne())) + return SDValue(); + + // switch CSEL (1, c, cc) to CSEL (c, 1, !cc) + if (CTVal->isOne() && !CFVal->isOne()) { + std::swap(CTVal, CFVal); + AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC); + } + + // It might be neutral for larger constants, as the immediate need to be + // materialized in a register. + APInt ADDC = CTVal->getAPIntValue(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isLegalAddImmediate(ADDC.getSExtValue())) + return SDValue(); + + assert(CFVal->isOne() && "Unexpected constant value"); + + SDLoc DL(N); + SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0)); + SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32); + SDValue Cmp = CSel.getOperand(3); + + return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp); +} + // ADD(UDOT(zero, x, y), A) --> UDOT(A, x, y) static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); @@ -14919,6 +14972,8 @@ static SDValue performAddSubCombine(SDNode *N, return Val; if (SDValue Val = performAddDotCombine(N, DAG)) return Val; + if (SDValue Val = performAddCSelIntoCSinc(N, DAG)) + return Val; return performAddSubLongCombine(N, DCI, DAG); } |