diff options
Diffstat (limited to 'llvm/lib/Target/AMDGPU/SIISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 8d51ec6..9017f4f 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -15896,6 +15896,78 @@ SDValue SITargetLowering::performClampCombine(SDNode *N, return SDValue(CSrc, 0); } +SDValue SITargetLowering::performSelectCombine(SDNode *N, + DAGCombinerInfo &DCI) const { + + // Try to fold CMP + SELECT patterns with shared constants (both FP and + // integer). + // Detect when CMP and SELECT use the same constant and fold them to avoid + // loading the constant twice. Specifically handles patterns like: + // %cmp = icmp eq i32 %val, 4242 + // %sel = select i1 %cmp, i32 4242, i32 %other + // It can be optimized to reuse %val instead of 4242 in select. + SDValue Cond = N->getOperand(0); + SDValue TrueVal = N->getOperand(1); + SDValue FalseVal = N->getOperand(2); + + // Check if condition is a comparison. + if (Cond.getOpcode() != ISD::SETCC) + return SDValue(); + + SDValue LHS = Cond.getOperand(0); + SDValue RHS = Cond.getOperand(1); + ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); + + bool isFloatingPoint = LHS.getValueType().isFloatingPoint(); + bool isInteger = LHS.getValueType().isInteger(); + + // Handle simple floating-point and integer types only. + if (!isFloatingPoint && !isInteger) + return SDValue(); + + bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ); + bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE); + if (!isEquality && !isNonEquality) + return SDValue(); + + SDValue ArgVal, ConstVal; + if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) || + (isInteger && isa<ConstantSDNode>(RHS))) { + ConstVal = RHS; + ArgVal = LHS; + } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) || + (isInteger && isa<ConstantSDNode>(LHS))) { + ConstVal = LHS; + ArgVal = RHS; + } else { + return SDValue(); + } + + // Skip optimization for inlinable immediates. + if (isFloatingPoint) { + const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF(); + if (!Val.isNormal() || Subtarget->getInstrInfo()->isInlineConstant(Val)) + return SDValue(); + } else { + if (AMDGPU::isInlinableIntLiteral( + cast<ConstantSDNode>(ConstVal)->getSExtValue())) + return SDValue(); + } + + // For equality and non-equality comparisons, patterns: + // select (setcc x, const), const, y -> select (setcc x, const), x, y + // select (setccinv x, const), y, const -> select (setccinv x, const), y, x + if (!(isEquality && TrueVal == ConstVal) && + !(isNonEquality && FalseVal == ConstVal)) + return SDValue(); + + SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal; + SDValue SelectRHS = + (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal; + return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond, + SelectLHS, SelectRHS); +} + SDValue SITargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { switch (N->getOpcode()) { @@ -15944,6 +16016,10 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N, return performFMulCombine(N, DCI); case ISD::SETCC: return performSetCCCombine(N, DCI); + case ISD::SELECT: + if (auto Res = performSelectCombine(N, DCI)) + return Res; + break; case ISD::FMAXNUM: case ISD::FMINNUM: case ISD::FMAXNUM_IEEE: |