Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 83
1 file changed, 83 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 280f87b..3d040fb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4843,11 +4843,94 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
  return SDValue();
}
+// Detect when CMP and SELECT use the same constant and fold them to avoid
+// loading the constant twice. Specifically handles patterns like:
+// %cmp = icmp eq i32 %val, 4242
+// %sel = select i1 %cmp, i32 4242, i32 %other
+// The select can then reuse %val instead of materializing 4242 a second time.
+static SDValue
+foldCmpSelectWithSharedConstant(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                                const AMDGPUSubtarget *ST) {
+  SDValue Cond = N->getOperand(0);
+  SDValue TrueVal = N->getOperand(1);
+  SDValue FalseVal = N->getOperand(2);
+
+  // Check if condition is a comparison.
+  if (Cond.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  SDValue LHS = Cond.getOperand(0);
+  SDValue RHS = Cond.getOperand(1);
+  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+
+  bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
+  bool isInteger = LHS.getValueType().isInteger();
+
+  // Handle simple floating-point and integer types only.
+  if (!isFloatingPoint && !isInteger)
+    return SDValue();
+
+  bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
+  bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
+  if (!isEquality && !isNonEquality)
+    return SDValue();
+
+  SDValue ArgVal, ConstVal;
+  if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
+      (isInteger && isa<ConstantSDNode>(RHS))) {
+    ConstVal = RHS;
+    ArgVal = LHS;
+  } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
+             (isInteger && isa<ConstantSDNode>(LHS))) {
+    ConstVal = LHS;
+    ArgVal = RHS;
+  } else {
+    return SDValue();
+  }
+
+  // Bail out early if the constant is not a profitable candidate for this fold.
+  if (isFloatingPoint) {
+    const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
+    const GCNSubtarget *GCNST = static_cast<const GCNSubtarget *>(ST);
+
+    // Only optimize normal floating-point values (finite, non-zero, and
+    // non-subnormal as per IEEE 754); skip the optimization for inlinable
+    // floating-point constants.
+    if (!Val.isNormal() || GCNST->getInstrInfo()->isInlineConstant(Val))
+      return SDValue();
+  } else {
+    int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue();
+
+    // Skip the optimization for inlinable integer immediates, which span
+    // -16 to 64 (inclusive).
+    if (IntVal >= -16 && IntVal <= 64)
+      return SDValue();
+  }
+
+  // Fold the following patterns for equality and non-equality comparisons:
+  //   select (setcc x, const), const, y -> select (setcc x, const), x, y
+  //   select (setccinv x, const), y, const -> select (setccinv x, const), y, x
+  if (!(isEquality && TrueVal == ConstVal) &&
+      !(isNonEquality && FalseVal == ConstVal))
+    return SDValue();
+
+  SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
+  SDValue SelectRHS =
+      (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
+  return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
+                         SelectLHS, SelectRHS);
+}
+
SDValue AMDGPUTargetLowering::performSelectCombine(SDNode *N,
                                                   DAGCombinerInfo &DCI) const {
  if (SDValue Folded = foldFreeOpFromSelect(DCI, SDValue(N, 0)))
    return Folded;
+  // Try to fold CMP + SELECT patterns with shared constants (both FP and
+  // integer).
+  if (SDValue Folded = foldCmpSelectWithSharedConstant(N, DCI, Subtarget))
+    return Folded;
+
  SDValue Cond = N->getOperand(0);
  if (Cond.getOpcode() != ISD::SETCC)
    return SDValue();
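
For illustration only, not part of the commit above: a minimal standalone C++ sketch of the rewrite this combine performs, using plain integers in place of SelectionDAG nodes. The helper names are hypothetical; the constant 4242 is taken from the comment in the patch, and the [-16, 64] bound mirrors the integer inline-immediate check above.

// Illustration only -- standalone sketch of the shared-constant fold.
#include <cassert>
#include <cstdint>

// Before the fold: the constant 4242 is needed twice, once for the compare
// and once for the select, so it has to be materialized twice.
int64_t selectBeforeFold(int64_t Val, int64_t Other) {
  return Val == 4242 ? 4242 : Other;
}

// After the fold: on the true edge Val already equals 4242, so the select
// can reuse Val and the constant is materialized only once.
int64_t selectAfterFold(int64_t Val, int64_t Other) {
  return Val == 4242 ? Val : Other;
}

// Hypothetical mirror of the integer gate above: AMDGPU encodes immediates
// in [-16, 64] inline for free, so the fold only pays off outside that range.
bool worthFolding(int64_t Imm) { return Imm < -16 || Imm > 64; }

int main() {
  // The rewrite preserves the selected value on both edges of the compare.
  assert(selectBeforeFold(4242, 7) == selectAfterFold(4242, 7));
  assert(selectBeforeFold(5, 7) == selectAfterFold(5, 7));
  assert(worthFolding(4242) && !worthFolding(64));
  return 0;
}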