aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp383
1 files changed, 181 insertions, 202 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 624cff2..007074c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,
Operation.getValue(1));
}
- case Intrinsic::x86_t2rpntlvwz0rs_internal:
- case Intrinsic::x86_t2rpntlvwz0rst1_internal:
- case Intrinsic::x86_t2rpntlvwz1rs_internal:
- case Intrinsic::x86_t2rpntlvwz1rst1_internal:
- case Intrinsic::x86_t2rpntlvwz0_internal:
- case Intrinsic::x86_t2rpntlvwz0t1_internal:
- case Intrinsic::x86_t2rpntlvwz1_internal:
- case Intrinsic::x86_t2rpntlvwz1t1_internal: {
- auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>();
- X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA);
- unsigned IntNo = Op.getConstantOperandVal(1);
- unsigned Opc = 0;
- switch (IntNo) {
- default:
- llvm_unreachable("Unexpected intrinsic!");
- case Intrinsic::x86_t2rpntlvwz0_internal:
- Opc = X86::PT2RPNTLVWZ0V;
- break;
- case Intrinsic::x86_t2rpntlvwz0t1_internal:
- Opc = X86::PT2RPNTLVWZ0T1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1_internal:
- Opc = X86::PT2RPNTLVWZ1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1t1_internal:
- Opc = X86::PT2RPNTLVWZ1T1V;
- break;
- case Intrinsic::x86_t2rpntlvwz0rs_internal:
- Opc = X86::PT2RPNTLVWZ0RSV;
- break;
- case Intrinsic::x86_t2rpntlvwz0rst1_internal:
- Opc = X86::PT2RPNTLVWZ0RST1V;
- break;
- case Intrinsic::x86_t2rpntlvwz1rs_internal:
- Opc = X86::PT2RPNTLVWZ1RSV;
- break;
- case Intrinsic::x86_t2rpntlvwz1rst1_internal:
- Opc = X86::PT2RPNTLVWZ1RST1V;
- break;
- }
-
- SDLoc DL(Op);
- SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other);
-
- SDValue Ops[] = {Op.getOperand(2), // Row
- Op.getOperand(3), // Col0
- Op.getOperand(4), // Col1
- Op.getOperand(5), // Base
- DAG.getTargetConstant(1, DL, MVT::i8), // Scale
- Op.getOperand(6), // Index
- DAG.getTargetConstant(0, DL, MVT::i32), // Disp
- DAG.getRegister(0, MVT::i16), // Segment
- Op.getOperand(0)}; // Chain
-
- MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops);
- SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx,
- SDValue(Res, 0));
- SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx,
- SDValue(Res, 0));
- return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL);
- }
case Intrinsic::x86_atomic_bts_rm:
case Intrinsic::x86_atomic_btc_rm:
case Intrinsic::x86_atomic_btr_rm: {
@@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
assert (Imm < 8 && "Illegal tmm index");
return X86::TMM0 + Imm;
};
- auto TMMImmToTMMPair = [](unsigned Imm) {
- assert(Imm < 8 && "Illegal tmm pair index.");
- return X86::TMM0_TMM1 + Imm / 2;
- };
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected instr type to insert");
@@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case X86::PTDPBHF8PS:
case X86::PTDPHBF8PS:
case X86::PTDPHF8PS:
- case X86::PTTDPBF16PS:
- case X86::PTTDPFP16PS:
- case X86::PTTCMMIMFP16PS:
- case X86::PTTCMMRLFP16PS:
- case X86::PTCONJTCMMIMFP16PS:
- case X86::PTMMULTF32PS:
- case X86::PTTMMULTF32PS: {
+ case X86::PTMMULTF32PS: {
unsigned Opc;
switch (MI.getOpcode()) {
default: llvm_unreachable("illegal opcode!");
+ // clang-format off
case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;
case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;
case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;
case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
- case X86::PTCMMIMFP16PS:
- Opc = X86::TCMMIMFP16PS;
- break;
- case X86::PTCMMRLFP16PS:
- Opc = X86::TCMMRLFP16PS;
- break;
+ case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break;
+ case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break;
case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
- case X86::PTTDPBF16PS:
- Opc = X86::TTDPBF16PS;
- break;
- case X86::PTTDPFP16PS:
- Opc = X86::TTDPFP16PS;
- break;
- case X86::PTTCMMIMFP16PS:
- Opc = X86::TTCMMIMFP16PS;
- break;
- case X86::PTTCMMRLFP16PS:
- Opc = X86::TTCMMRLFP16PS;
- break;
- case X86::PTCONJTCMMIMFP16PS:
- Opc = X86::TCONJTCMMIMFP16PS;
- break;
- case X86::PTMMULTF32PS:
- Opc = X86::TMMULTF32PS;
- break;
- case X86::PTTMMULTF32PS:
- Opc = X86::TTMMULTF32PS;
- break;
+ case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break;
+ // clang-format on
}
MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));
@@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent(); // The pseudo is gone now.
return BB;
}
- case X86::PT2RPNTLVWZ0:
- case X86::PT2RPNTLVWZ0T1:
- case X86::PT2RPNTLVWZ1:
- case X86::PT2RPNTLVWZ1T1:
- case X86::PT2RPNTLVWZ0RS:
- case X86::PT2RPNTLVWZ0RST1:
- case X86::PT2RPNTLVWZ1RS:
- case X86::PT2RPNTLVWZ1RST1: {
- const DebugLoc &DL = MI.getDebugLoc();
- unsigned Opc;
-#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? OPC##_EVEX : OPC)
- switch (MI.getOpcode()) {
- default:
- llvm_unreachable("Unexpected instruction!");
- case X86::PT2RPNTLVWZ0:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0);
- break;
- case X86::PT2RPNTLVWZ0T1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1);
- break;
- case X86::PT2RPNTLVWZ1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1);
- break;
- case X86::PT2RPNTLVWZ1T1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1);
- break;
- case X86::PT2RPNTLVWZ0RS:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS);
- break;
- case X86::PT2RPNTLVWZ0RST1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1);
- break;
- case X86::PT2RPNTLVWZ1RS:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS);
- break;
- case X86::PT2RPNTLVWZ1RST1:
- Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1);
- break;
- }
-#undef GET_EGPR_IF_ENABLED
- MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
- MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define);
-
- MIB.add(MI.getOperand(1)); // base
- MIB.add(MI.getOperand(2)); // scale
- MIB.add(MI.getOperand(3)); // index
- MIB.add(MI.getOperand(4)); // displacement
- MIB.add(MI.getOperand(5)); // segment
- MI.eraseFromParent(); // The pseudo is gone now.
- return BB;
- }
- case X86::PTTRANSPOSED:
- case X86::PTCONJTFP16: {
- const DebugLoc &DL = MI.getDebugLoc();
- unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED
- : X86::TCONJTFP16;
-
- MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
- MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
- MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
-
- MI.eraseFromParent(); // The pseudo is gone now.
- return BB;
- }
case X86::PTCVTROWPS2BF16Hrri:
case X86::PTCVTROWPS2BF16Lrri:
case X86::PTCVTROWPS2PHHrri:
@@ -48778,10 +48621,9 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
SDValue BC0 = peekThroughBitcasts(Op0);
if (BC0.getOpcode() == X86ISD::PCMPEQ &&
ISD::isBuildVectorAllZeros(BC0.getOperand(1).getNode())) {
- SDLoc DL(EFLAGS);
CC = (CC == X86::COND_B ? X86::COND_E : X86::COND_NE);
- SDValue X = DAG.getBitcast(OpVT, BC0.getOperand(0));
- return DAG.getNode(EFLAGS.getOpcode(), DL, VT, X, X);
+ SDValue X = DAG.getBitcast(OpVT, DAG.getFreeze(BC0.getOperand(0)));
+ return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, X, X);
}
}
}
@@ -48837,7 +48679,7 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
MVT FloatSVT = MVT::getFloatingPointVT(EltBits);
MVT FloatVT =
MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits);
- Res = DAG.getBitcast(FloatVT, Res);
+ Res = DAG.getBitcast(FloatVT, DAG.getFreeze(Res));
return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res);
} else if (EltBits == 16) {
MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8;
@@ -48856,8 +48698,30 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
}
// TESTZ(X,-1) == TESTZ(X,X)
- if (ISD::isBuildVectorAllOnes(Op1.getNode()))
+ if (ISD::isBuildVectorAllOnes(Op1.getNode())) {
+ Op0 = DAG.getFreeze(Op0);
return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op0, Op0);
+ }
+
+ // Attempt to convert PTESTZ(X,SIGNMASK) -> VTESTPD/PSZ(X,X) on AVX targets.
+ if (EFLAGS.getOpcode() == X86ISD::PTEST && Subtarget.hasAVX()) {
+ KnownBits KnownOp1 = DAG.computeKnownBits(Op1);
+ assert(KnownOp1.getBitWidth() == 64 &&
+ "Illegal PTEST vector element width");
+ if (KnownOp1.isConstant()) {
+ const APInt &Mask = KnownOp1.getConstant();
+ if (Mask.isSignMask()) {
+ MVT FpVT = MVT::getVectorVT(MVT::f64, OpVT.getSizeInBits() / 64);
+ Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0));
+ return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0);
+ }
+ if (Mask.isSplat(32) && Mask.trunc(32).isSignMask()) {
+ MVT FpVT = MVT::getVectorVT(MVT::f32, OpVT.getSizeInBits() / 32);
+ Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0));
+ return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0);
+ }
+ }
+ }
// TESTZ(OR(LO(X),HI(X)),OR(LO(Y),HI(Y))) -> TESTZ(X,Y)
// TODO: Add COND_NE handling?
@@ -53480,6 +53344,105 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Look for a RMW operation that only touches one bit of a larger than legal
+// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single
+// i32 sub value.
+static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ using namespace SDPatternMatch;
+
+ // Only handle normal stores and its chain was a matching normal load.
+ auto *Ld = dyn_cast<LoadSDNode>(St->getChain());
+ if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld ||
+ !ISD::isNormalLoad(Ld) || !Ld->isSimple() ||
+ Ld->getBasePtr() != St->getBasePtr() ||
+ Ld->getOffset() != St->getOffset())
+ return SDValue();
+
+ SDValue LoadVal(Ld, 0);
+ SDValue StoredVal = St->getValue();
+ EVT VT = StoredVal.getValueType();
+
+ // Only narrow larger than legal scalar integers.
+ if (!VT.isScalarInteger() ||
+ VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32))
+ return SDValue();
+
+ // BTR: X & ~(1 << ShAmt)
+ // BTS: X | (1 << ShAmt)
+ // BTC: X ^ (1 << ShAmt)
+ //
+ // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt)
+ SDValue InsertBit, ShAmt;
+ if (!StoredVal.hasOneUse() ||
+ !(sd_match(StoredVal, m_And(m_Specific(LoadVal),
+ m_Not(m_Shl(m_One(), m_Value(ShAmt))))) ||
+ sd_match(StoredVal,
+ m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
+ sd_match(StoredVal,
+ m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
+ sd_match(StoredVal,
+ m_Or(m_And(m_Specific(LoadVal),
+ m_Not(m_Shl(m_One(), m_Value(ShAmt)))),
+ m_Shl(m_Value(InsertBit), m_Deferred(ShAmt))))))
+ return SDValue();
+
+ // Ensure the shift amount is in bounds.
+ KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
+ if (KnownAmt.getMaxValue().uge(VT.getSizeInBits()))
+ return SDValue();
+
+ // If we're inserting a bit then it must be the LSB.
+ if (InsertBit) {
+ KnownBits KnownInsert = DAG.computeKnownBits(InsertBit);
+ if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1))
+ return SDValue();
+ }
+
+ // Split the shift into an alignment shift that moves the active i32 block to
+ // the bottom bits for truncation and a modulo shift that can act on the i32.
+ EVT AmtVT = ShAmt.getValueType();
+ SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
+ DAG.getSignedConstant(-32LL, DL, AmtVT));
+ SDValue ModuloAmt =
+ DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT));
+ ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8);
+
+ // Compute the byte offset for the i32 block that is changed by the RMW.
+ // combineTruncate will adjust the load for us in a similar way.
+ EVT PtrVT = St->getBasePtr().getValueType();
+ SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT);
+ SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs,
+ DAG.getShiftAmountConstant(3, PtrVT, DL));
+ SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL,
+ SDNodeFlags::NoUnsignedWrap);
+
+ // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store.
+ SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt);
+ X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
+
+ SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32,
+ DAG.getConstant(1, DL, MVT::i32), ModuloAmt);
+
+ SDValue Res;
+ if (InsertBit) {
+ SDValue BitMask =
+ DAG.getNode(ISD::SHL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt);
+ Res =
+ DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32));
+ Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask);
+ } else {
+ if (StoredVal.getOpcode() == ISD::AND)
+ Mask = DAG.getNOT(DL, Mask, MVT::i32);
+ Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);
+ }
+
+ return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(),
+ Align(), St->getMemOperand()->getFlags());
+}
+
static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -53706,6 +53669,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
}
}
+ if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget))
+ return R;
+
// Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC)
// store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC)
if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&
@@ -54493,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
+ using namespace SDPatternMatch;
if (!VT.isVector() || !Subtarget.hasSSSE3())
return SDValue();
@@ -54502,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
return SDValue();
SDValue SSatVal = detectSSatPattern(In, VT);
- if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
- return SDValue();
-
- // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
- // of multiplies from even/odd elements.
- SDValue N0 = SSatVal.getOperand(0);
- SDValue N1 = SSatVal.getOperand(1);
-
- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ if (!SSatVal)
return SDValue();
- SDValue N00 = N0.getOperand(0);
- SDValue N01 = N0.getOperand(1);
- SDValue N10 = N1.getOperand(0);
- SDValue N11 = N1.getOperand(1);
-
+ // See if this is a signed saturation of an ADD, adding pairs of multiplies
+ // from even/odd elements, from zero_extend/sign_extend operands.
+ //
// TODO: Handle constant vectors and use knownbits/computenumsignbits?
- // Canonicalize zero_extend to LHS.
- if (N01.getOpcode() == ISD::ZERO_EXTEND)
- std::swap(N00, N01);
- if (N11.getOpcode() == ISD::ZERO_EXTEND)
- std::swap(N10, N11);
-
- // Ensure we have a zero_extend and a sign_extend.
- if (N00.getOpcode() != ISD::ZERO_EXTEND ||
- N01.getOpcode() != ISD::SIGN_EXTEND ||
- N10.getOpcode() != ISD::ZERO_EXTEND ||
- N11.getOpcode() != ISD::SIGN_EXTEND)
+ SDValue N00, N01, N10, N11;
+ if (!sd_match(SSatVal,
+ m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))),
+ m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))
return SDValue();
- // Peek through the extends.
- N00 = N00.getOperand(0);
- N01 = N01.getOperand(0);
- N10 = N10.getOperand(0);
- N11 = N11.getOperand(0);
-
// Ensure the extend is from vXi8.
if (N00.getValueType().getVectorElementType() != MVT::i8 ||
N01.getValueType().getVectorElementType() != MVT::i8 ||
@@ -54660,8 +54604,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
// truncation, see if we can convert the shift into a pointer offset instead.
// Limit this to normal (non-ext) scalar integer loads.
if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL &&
- Src.hasOneUse() && Src.getOperand(0).hasOneUse() &&
- ISD::isNormalLoad(Src.getOperand(0).getNode())) {
+ Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) &&
+ (Src.getOperand(0).hasOneUse() ||
+ !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) {
auto *Ld = cast<LoadSDNode>(Src.getOperand(0));
if (Ld->isSimple() && VT.isByteSized() &&
isPowerOf2_64(VT.getSizeInBits())) {
@@ -54669,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
// Check the shift amount is byte aligned.
// Check the truncation doesn't use any shifted in (zero) top bits.
+ // Check the shift amount doesn't depend on the original load.
if (KnownAmt.countMinTrailingZeros() >= 3 &&
KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() -
- VT.getSizeInBits())) {
+ VT.getSizeInBits()) &&
+ !Ld->isPredecessorOf(ShAmt.getNode())) {
EVT PtrVT = Ld->getBasePtr().getValueType();
SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);
SDValue PtrByteOfs =
@@ -56459,6 +56406,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
+ using namespace SDPatternMatch;
const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
const SDValue LHS = N->getOperand(0);
const SDValue RHS = N->getOperand(1);
@@ -56517,6 +56465,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
if (SDValue AndN = MatchAndCmpEq(RHS, LHS))
return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
+ // If we're performing a bit test on a larger than legal type, attempt
+ // to (aligned) shift down the value to the bottom 32-bits and then
+ // perform the bittest on the i32 value.
+ // ICMP_ZERO(AND(X,SHL(1,IDX)))
+ // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31))))
+ if (isNullConstant(RHS) &&
+ OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) {
+ SDValue X, ShAmt;
+ if (sd_match(LHS, m_OneUse(m_And(m_Value(X),
+ m_Shl(m_One(), m_Value(ShAmt)))))) {
+ // Only attempt this if the shift amount is known to be in bounds.
+ KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
+ if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) {
+ EVT AmtVT = ShAmt.getValueType();
+ SDValue AlignAmt =
+ DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
+ DAG.getSignedConstant(-32LL, DL, AmtVT));
+ SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
+ DAG.getConstant(31, DL, AmtVT));
+ SDValue Mask = DAG.getNode(
+ ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
+ DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
+ X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt);
+ X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
+ X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask);
+ return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32),
+ CC);
+ }
+ }
+ }
+
// cmpeq(trunc(x),C) --> cmpeq(x,C)
// cmpne(trunc(x),C) --> cmpne(x,C)
// iff x upper bits are zero.