From 9fa3971fac27fbe0a6e3b9745d201c16f5f98bc2 Mon Sep 17 00:00:00 2001 From: Piotr Fusik Date: Thu, 17 Jul 2025 16:37:59 +0200 Subject: [DAGCombiner] Fold vector subtraction if above threshold to `umin` (#148834) This extends #134235 and #135194 to vectors. --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 87 +++++++++++++++++---------- 1 file changed, 54 insertions(+), 33 deletions(-) (limited to 'llvm/lib/CodeGen') diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 0e8e4c9..40464e9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -609,6 +609,8 @@ namespace { SDValue foldABSToABD(SDNode *N, const SDLoc &DL); SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode CC, const SDLoc &DL); + SDValue foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True, + SDValue False, ISD::CondCode CC, const SDLoc &DL); SDValue unfoldMaskedMerge(SDNode *N); SDValue unfoldExtremeBitClearingToShifts(SDNode *N); SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond, @@ -859,7 +861,7 @@ namespace { auto LK = TLI.getTypeConversion(*DAG.getContext(), VT); return (LK.first == TargetLoweringBase::TypeLegal || LK.first == TargetLoweringBase::TypePromoteInteger) && - TLI.isOperationLegal(ISD::UMIN, LK.second); + TLI.isOperationLegalOrCustom(ISD::UMIN, LK.second); } public: @@ -4093,6 +4095,26 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return N0; } + // (sub x, ([v]select (ult x, y), 0, y)) -> (umin x, (sub x, y)) + // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y)) + if (N1.hasOneUse() && hasUMin(VT)) { + SDValue Y; + if (sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETULT)), + m_Zero(), m_Deferred(Y))) || + sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETUGE)), + m_Deferred(Y), m_Zero())) || + sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETULT)), + m_Zero(), m_Deferred(Y))) || + sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETUGE)), + m_Deferred(Y), m_Zero()))) + return DAG.getNode(ISD::UMIN, DL, VT, N0, + DAG.getNode(ISD::SUB, DL, VT, N0, Y)); + } + if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -4442,20 +4464,6 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { sd_match(N1, m_UMaxLike(m_Specific(A), m_Specific(B)))) return DAG.getNegative(DAG.getNode(ISD::ABDU, DL, VT, A, B), DL, VT); - // (sub x, (select (ult x, y), 0, y)) -> (umin x, (sub x, y)) - // (sub x, (select (uge x, y), y, 0)) -> (umin x, (sub x, y)) - if (hasUMin(VT)) { - SDValue Y; - if (sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y)))) || - sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero())))) - return DAG.getNode(ISD::UMIN, DL, VT, N0, - DAG.getNode(ISD::SUB, DL, VT, N0, Y)); - } - return SDValue(); } @@ -12173,6 +12181,30 @@ SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True, return SDValue(); } +// ([v]select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) +// ([v]select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) +SDValue DAGCombiner::foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True, + SDValue False, ISD::CondCode CC, + const SDLoc &DL) { + APInt C; + EVT VT = True.getValueType(); + if (sd_match(RHS, m_ConstInt(C)) && hasUMin(VT)) { + if (CC == ISD::SETUGT && LHS == False && + sd_match(True, m_Add(m_Specific(False), m_SpecificInt(~C)))) { + SDValue AddC = DAG.getConstant(~C, DL, VT); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, False, AddC); + return DAG.getNode(ISD::UMIN, DL, VT, Add, False); + } + if (CC == ISD::SETULT && LHS == True && + sd_match(False, m_Add(m_Specific(True), m_SpecificInt(-C)))) { + SDValue AddC = DAG.getConstant(-C, DL, VT); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, True, AddC); + return DAG.getNode(ISD::UMIN, DL, VT, True, Add); + } + } + return SDValue(); +} + SDValue DAGCombiner::visitSELECT(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12358,24 +12390,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) - APInt C; - if (sd_match(Cond1, m_ConstInt(C)) && hasUMin(VT)) { - if (CC == ISD::SETUGT && Cond0 == N2 && - sd_match(N1, m_Add(m_Specific(N2), m_SpecificInt(~C)))) { - // The resulting code relies on an unsigned wrap in ADD. - // Recreating ADD to drop possible nuw/nsw flags. - SDValue AddC = DAG.getConstant(~C, DL, VT); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N2, AddC); - return DAG.getNode(ISD::UMIN, DL, VT, Add, N2); - } - if (CC == ISD::SETULT && Cond0 == N1 && - sd_match(N2, m_Add(m_Specific(N1), m_SpecificInt(-C)))) { - // Ditto. - SDValue AddC = DAG.getConstant(-C, DL, VT); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, AddC); - return DAG.getNode(ISD::UMIN, DL, VT, N1, Add); - } - } + if (SDValue UMin = foldSelectToUMin(Cond0, Cond1, N1, N2, CC, DL)) + return UMin; } if (!VT.isVector()) @@ -13412,6 +13428,11 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { } } } + + // (vselect (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) + // (vselect (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) + if (SDValue UMin = foldSelectToUMin(LHS, RHS, N1, N2, CC, DL)) + return UMin; } if (SimplifySelectOps(N, N1, N2)) -- cgit v1.1