diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 216 |
1 files changed, 109 insertions, 107 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index e0cf739..3b69eda 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -9186,7 +9186,7 @@ static SDValue lowerSelectToBinOp(SDNode *N, SelectionDAG &DAG, unsigned ShAmount = Log2_64(TrueM1); if (Subtarget.hasShlAdd(ShAmount)) return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, CondV, - DAG.getConstant(ShAmount, DL, VT), CondV); + DAG.getTargetConstant(ShAmount, DL, VT), CondV); } } // (select c, y, 0) -> -c & y @@ -15463,7 +15463,7 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0); SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, NL, - DAG.getConstant(Diff, DL, VT), NS); + DAG.getTargetConstant(Diff, DL, VT), NS); return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT)); } @@ -15501,7 +15501,7 @@ static SDValue combineShlAddIAddImpl(SDNode *N, SDValue AddI, SDValue Other, int64_t AddConst = AddVal.getSExtValue(); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), - DAG.getConstant(ShlConst, DL, VT), Other); + DAG.getTargetConstant(ShlConst, DL, VT), Other); return DAG.getNode(ISD::ADD, DL, VT, SHADD, DAG.getSignedConstant(AddConst, DL, VT)); } @@ -16495,6 +16495,45 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Op, DL, VT, Shift1, Shift2); } +static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, + unsigned ShY, bool AddX) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue X = N->getOperand(0); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359); +} + +static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, + uint64_t MulAmt) { + // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) + switch (MulAmt) { + case 5 * 3: + return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false); + case 9 * 3: + return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false); + case 5 * 5: + return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false); + case 9 * 5: + return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false); + case 9 * 9: + return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false); + default: + break; + } + + // 2/4/8 * 3/5/9 + 1 -> (shXadd (shYadd X, X), X) + int ShX; + if (int ShY = isShifted359(MulAmt - 1, ShX)) { + assert(ShX != 0 && "MulAmt=4,6,10 handled before"); + if (ShX <= 3) + return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true); + } + return SDValue(); +} + // Try to expand a scalar multiply to a faster sequence. static SDValue expandMul(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -16524,18 +16563,17 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue())) return SDValue(); - // WARNING: The code below is knowingly incorrect with regards to undef semantics. - // We're adding additional uses of X here, and in principle, we should be freezing - // X before doing so. However, adding freeze here causes real regressions, and no - // other target properly freezes X in these cases either. - SDValue X = N->getOperand(0); - + // WARNING: The code below is knowingly incorrect with regards to undef + // semantics. We're adding additional uses of X here, and in principle, we + // should be freezing X before doing so. However, adding freeze here causes + // real regressions, and no other target properly freezes X in these cases + // either. if (Subtarget.hasShlAdd(3)) { + SDValue X = N->getOperand(0); int Shift; if (int ShXAmount = isShifted359(MulAmt, Shift)) { // 3/5/9 * 2^N -> shl (shXadd X, X), N SDLoc DL(N); - SDValue X = N->getOperand(0); // Put the shift first if we can fold a zext into the shift forming // a slli.uw. if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && @@ -16543,80 +16581,40 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getConstant(ShXAmount, DL, VT), Shl); + DAG.getTargetConstant(ShXAmount, DL, VT), Shl); } // Otherwise, put the shl second so that it can fold with following // instructions (e.g. sext or add). SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShXAmount, DL, VT), X); + DAG.getTargetConstant(ShXAmount, DL, VT), X); return DAG.getNode(ISD::SHL, DL, VT, Mul359, DAG.getConstant(Shift, DL, VT)); } - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - int ShX; - int ShY; - switch (MulAmt) { - case 3 * 5: - ShY = 1; - ShX = 2; - break; - case 3 * 9: - ShY = 1; - ShX = 3; - break; - case 5 * 5: - ShX = ShY = 2; - break; - case 5 * 9: - ShY = 2; - ShX = 3; - break; - case 9 * 9: - ShX = ShY = 3; - break; - default: - ShX = ShY = 0; - break; - } - if (ShX) { + // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples + // of 25 which happen to be quite common. + // (2/4/8 * 3/5/9 + 1) * 2^N + Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { + if (Shift == 0) + return V; SDLoc DL(N); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShY, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(ShX, DL, VT), Mul359); + return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); } // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's // easy. Then count how many zeros are up to the first bit. - if (isPowerOf2_64(MulAmt & (MulAmt - 1))) { - unsigned ScaleShift = llvm::countr_zero(MulAmt); - if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1))); - SDLoc DL(N); - SDValue Shift1 = - DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ScaleShift, DL, VT), Shift1); - } + if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) { + unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1))); + SDLoc DL(N); + SDValue Shift1 = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(Shift, DL, VT), Shift1); } - // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x) - // This is the two instruction form, there are also three instruction - // variants we could implement. e.g. - // (2^(1,2,3) * 3,5,9 + 1) << C2 - // 2^(C1>3) * 3,5,9 +/- 1 - if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { - assert(Shift != 0 && "MulAmt=4,6,10 handled before"); - if (Shift <= 3) { - SDLoc DL(N); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShXAmount, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(Shift, DL, VT), X); - } - } + // TODO: 2^(C1>3) * 3,5,9 +/- 1 // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X)) if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { @@ -16626,9 +16624,10 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDLoc DL(N); SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(ISD::ADD, DL, VT, Shift1, - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ScaleShift, DL, VT), X)); + return DAG.getNode( + ISD::ADD, DL, VT, Shift1, + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(ScaleShift, DL, VT), X)); } } @@ -16643,29 +16642,10 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT)); SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Offset - 1), DL, VT), X); + DAG.getTargetConstant(Log2_64(Offset - 1), DL, VT), X); return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359); } } - - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples - // of 25 which happen to be quite common. - if (int ShBAmount = isShifted359(MulAmt2, Shift)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(ShBAmount, DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Shift, DL, VT)); - } - } } if (SDValue V = expandMulToAddOrSubOfShl(N, DAG, MulAmt)) @@ -17887,6 +17867,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, SmallVector<SDNode *> Worklist; SmallPtrSet<SDNode *, 8> Inserted; + SmallPtrSet<SDNode *, 8> ExtensionsToRemove; Worklist.push_back(N); Inserted.insert(N); SmallVector<CombineResult> CombinesToApply; @@ -17896,22 +17877,25 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, NodeExtensionHelper LHS(Root, 0, DAG, Subtarget); NodeExtensionHelper RHS(Root, 1, DAG, Subtarget); - auto AppendUsersIfNeeded = [&Worklist, &Subtarget, - &Inserted](const NodeExtensionHelper &Op) { - if (Op.needToPromoteOtherUsers()) { - for (SDUse &Use : Op.OrigOperand->uses()) { - SDNode *TheUser = Use.getUser(); - if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) - return false; - // We only support the first 2 operands of FMA. - if (Use.getOperandNo() >= 2) - return false; - if (Inserted.insert(TheUser).second) - Worklist.push_back(TheUser); - } - } - return true; - }; + auto AppendUsersIfNeeded = + [&Worklist, &Subtarget, &Inserted, + &ExtensionsToRemove](const NodeExtensionHelper &Op) { + if (Op.needToPromoteOtherUsers()) { + // Remember that we're supposed to remove this extension. + ExtensionsToRemove.insert(Op.OrigOperand.getNode()); + for (SDUse &Use : Op.OrigOperand->uses()) { + SDNode *TheUser = Use.getUser(); + if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) + return false; + // We only support the first 2 operands of FMA. + if (Use.getOperandNo() >= 2) + return false; + if (Inserted.insert(TheUser).second) + Worklist.push_back(TheUser); + } + } + return true; + }; // Control the compile time by limiting the number of node we look at in // total. @@ -17932,6 +17916,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, std::optional<CombineResult> Res = FoldingStrategy(Root, LHS, RHS, DAG, Subtarget); if (Res) { + // If this strategy wouldn't remove an extension we're supposed to + // remove, reject it. + if (!Res->LHSExt.has_value() && + ExtensionsToRemove.contains(LHS.OrigOperand.getNode())) + continue; + if (!Res->RHSExt.has_value() && + ExtensionsToRemove.contains(RHS.OrigOperand.getNode())) + continue; + Matched = true; CombinesToApply.push_back(*Res); // All the inputs that are extended need to be folded, otherwise @@ -25320,3 +25313,12 @@ ArrayRef<MCPhysReg> RISCVTargetLowering::getRoundingControlRegisters() const { } return {}; } + +bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const { + EVT VT = Y.getValueType(); + + if (VT.isVector()) + return false; + + return VT.getSizeInBits() <= Subtarget.getXLen(); +} |
