aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Pilgrim <llvm-dev@redking.me.uk>2024-07-15 11:42:12 +0100
committerSimon Pilgrim <llvm-dev@redking.me.uk>2024-07-15 11:42:12 +0100
commitc2580afed7e55f13762d56400dc346f222ea5884 (patch)
treecfbd3c72b9f716cd951c520ba8ce76e7e82591d8
parent054d7b1283a5ebdf724f3ebc38b47e419f8f7a7f (diff)
downloadllvm-c2580afed7e55f13762d56400dc346f222ea5884.zip
llvm-c2580afed7e55f13762d56400dc346f222ea5884.tar.gz
llvm-c2580afed7e55f13762d56400dc346f222ea5884.tar.bz2
[X86] Convert shift+clamp -> avx2 shift folds to use SDPatternMatch::m_SetCC. NFC.
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp50
1 files changed, 21 insertions, 29 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a731541..91a5526 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46193,15 +46193,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (N->getOpcode() == ISD::VSELECT &&
(LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
- APInt SV;
+ using namespace llvm::SDPatternMatch;
// fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
// fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
- if (Cond.getOpcode() == ISD::SETCC &&
- Cond.getOperand(0) == LHS.getOperand(1) &&
- cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
- ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
- ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
- SV == VT.getScalarSizeInBits()) {
+ if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
+ sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
+ m_SpecificInt(VT.getScalarSizeInBits()),
+ m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
: X86ISD::VSHLV,
DL, VT, LHS.getOperand(0), LHS.getOperand(1));
@@ -48020,10 +48018,12 @@ static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
+ using namespace llvm::SDPatternMatch;
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
EVT VT = N0.getValueType();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
SDLoc DL(N);
// Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
@@ -48033,21 +48033,16 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
- APInt SV;
// fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
- if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
- cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
- ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
- ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
- SV == VT.getScalarSizeInBits()) {
+ if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+ sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+ m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
}
// fold shl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psllv(x,amt)
- if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
- cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
- ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
- ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
- SV == VT.getScalarSizeInBits()) {
+ if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+ sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+ m_SpecificCondCode(ISD::SETUGE)))) {
return DAG.getNode(X86ISD::VSHLV, DL, VT, N01, N1);
}
}
@@ -48160,9 +48155,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
+ using namespace llvm::SDPatternMatch;
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
SDLoc DL(N);
if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
@@ -48175,21 +48172,16 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
- APInt SV;
// fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
- if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
- cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
- ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
- ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
- SV == VT.getScalarSizeInBits()) {
+ if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+ sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+ m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
}
// fold srl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psrlv(x,amt)
- if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
- cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
- ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
- ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
- SV == VT.getScalarSizeInBits()) {
+ if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+ sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+ m_SpecificCondCode(ISD::SETUGE)))) {
return DAG.getNode(X86ISD::VSRLV, DL, VT, N01, N1);
}
}