diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVFrameLowering.cpp | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 126 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp | 7 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td | 66 |
4 files changed, 134 insertions, 71 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp index b37b740..f881c4c 100644 --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp @@ -789,6 +789,8 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB, // Unroll the probe loop depending on the number of iterations. if (Offset < ProbeSize * 5) { + uint64_t CFAAdjust = RealStackSize - Offset; + uint64_t CurrentOffset = 0; while (CurrentOffset + ProbeSize <= Offset) { RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, @@ -802,7 +804,7 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB, CurrentOffset += ProbeSize; if (EmitCFI) - CFIBuilder.buildDefCFAOffset(CurrentOffset); + CFIBuilder.buildDefCFAOffset(CurrentOffset + CFAAdjust); } uint64_t Residual = Offset - CurrentOffset; @@ -810,7 +812,7 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB, RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getFixed(-Residual), Flag, getStackAlign()); if (EmitCFI) - CFIBuilder.buildDefCFAOffset(Offset); + CFIBuilder.buildDefCFAOffset(RealStackSize); if (DynAllocation) { // s[d|w] zero, 0(sp) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c3f100e..3b69eda 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16496,32 +16496,42 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, } static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, - unsigned ShY) { + unsigned ShY, bool AddX) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue X = N->getOperand(0); SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, DAG.getTargetConstant(ShY, DL, VT), X); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getTargetConstant(ShX, DL, VT), Mul359); + DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359); } static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt) { + // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) switch (MulAmt) { case 5 * 3: - return getShlAddShlAdd(N, DAG, 2, 1); + return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false); case 9 * 3: - return getShlAddShlAdd(N, DAG, 3, 1); + return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false); case 5 * 5: - return getShlAddShlAdd(N, DAG, 2, 2); + return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false); case 9 * 5: - return getShlAddShlAdd(N, DAG, 3, 2); + return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false); case 9 * 9: - return getShlAddShlAdd(N, DAG, 3, 3); + return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false); default: - return SDValue(); + break; + } + + // 2/4/8 * 3/5/9 + 1 -> (shXadd (shYadd X, X), X) + int ShX; + if (int ShY = isShifted359(MulAmt - 1, ShX)) { + assert(ShX != 0 && "MulAmt=4,6,10 handled before"); + if (ShX <= 3) + return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true); } + return SDValue(); } // Try to expand a scalar multiply to a faster sequence. @@ -16581,41 +16591,30 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, DAG.getConstant(Shift, DL, VT)); } - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt)) - return V; + // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples + // of 25 which happen to be quite common. + // (2/4/8 * 3/5/9 + 1) * 2^N + Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { + if (Shift == 0) + return V; + SDLoc DL(N); + return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); + } // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's // easy. Then count how many zeros are up to the first bit. - if (isPowerOf2_64(MulAmt & (MulAmt - 1))) { - unsigned ScaleShift = llvm::countr_zero(MulAmt); - if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1))); - SDLoc DL(N); - SDValue Shift1 = - DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getTargetConstant(ScaleShift, DL, VT), Shift1); - } + if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) { + unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1))); + SDLoc DL(N); + SDValue Shift1 = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(Shift, DL, VT), Shift1); } - // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x) - // This is the two instruction form, there are also three instruction - // variants we could implement. e.g. - // (2^(1,2,3) * 3,5,9 + 1) << C2 - // 2^(C1>3) * 3,5,9 +/- 1 - if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { - assert(Shift != 0 && "MulAmt=4,6,10 handled before"); - if (Shift <= 3) { - SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getTargetConstant(ShXAmount, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getTargetConstant(Shift, DL, VT), X); - } - } + // TODO: 2^(C1>3) * 3,5,9 +/- 1 // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X)) if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { @@ -16647,14 +16646,6 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359); } } - - // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples - // of 25 which happen to be quite common. - Shift = llvm::countr_zero(MulAmt); - if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { - SDLoc DL(N); - return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); - } } if (SDValue V = expandMulToAddOrSubOfShl(N, DAG, MulAmt)) @@ -17876,6 +17867,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, SmallVector<SDNode *> Worklist; SmallPtrSet<SDNode *, 8> Inserted; + SmallPtrSet<SDNode *, 8> ExtensionsToRemove; Worklist.push_back(N); Inserted.insert(N); SmallVector<CombineResult> CombinesToApply; @@ -17885,22 +17877,25 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, NodeExtensionHelper LHS(Root, 0, DAG, Subtarget); NodeExtensionHelper RHS(Root, 1, DAG, Subtarget); - auto AppendUsersIfNeeded = [&Worklist, &Subtarget, - &Inserted](const NodeExtensionHelper &Op) { - if (Op.needToPromoteOtherUsers()) { - for (SDUse &Use : Op.OrigOperand->uses()) { - SDNode *TheUser = Use.getUser(); - if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) - return false; - // We only support the first 2 operands of FMA. - if (Use.getOperandNo() >= 2) - return false; - if (Inserted.insert(TheUser).second) - Worklist.push_back(TheUser); - } - } - return true; - }; + auto AppendUsersIfNeeded = + [&Worklist, &Subtarget, &Inserted, + &ExtensionsToRemove](const NodeExtensionHelper &Op) { + if (Op.needToPromoteOtherUsers()) { + // Remember that we're supposed to remove this extension. + ExtensionsToRemove.insert(Op.OrigOperand.getNode()); + for (SDUse &Use : Op.OrigOperand->uses()) { + SDNode *TheUser = Use.getUser(); + if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) + return false; + // We only support the first 2 operands of FMA. + if (Use.getOperandNo() >= 2) + return false; + if (Inserted.insert(TheUser).second) + Worklist.push_back(TheUser); + } + } + return true; + }; // Control the compile time by limiting the number of node we look at in // total. @@ -17921,6 +17916,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, std::optional<CombineResult> Res = FoldingStrategy(Root, LHS, RHS, DAG, Subtarget); if (Res) { + // If this strategy wouldn't remove an extension we're supposed to + // remove, reject it. + if (!Res->LHSExt.has_value() && + ExtensionsToRemove.contains(LHS.OrigOperand.getNode())) + continue; + if (!Res->RHSExt.has_value() && + ExtensionsToRemove.contains(RHS.OrigOperand.getNode())) + continue; + Matched = true; CombinesToApply.push_back(*Res); // All the inputs that are extended need to be folded, otherwise diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 636e31c..bf9de0a 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -1583,7 +1583,10 @@ void RISCVInsertVSETVLI::emitVSETVLIs(MachineBasicBlock &MBB) { if (!TII->isAddImmediate(*DeadMI, Reg)) continue; LIS->RemoveMachineInstrFromMaps(*DeadMI); + Register AddReg = DeadMI->getOperand(1).getReg(); DeadMI->eraseFromParent(); + if (AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } } @@ -1869,11 +1872,15 @@ void RISCVInsertVSETVLI::coalesceVSETVLIs(MachineBasicBlock &MBB) const { // Loop over the dead AVL values, and delete them now. This has // to be outside the above loop to avoid invalidating iterators. for (auto *MI : ToDelete) { + assert(MI->getOpcode() == RISCV::ADDI); + Register AddReg = MI->getOperand(1).getReg(); if (LIS) { LIS->removeInterval(MI->getOperand(0).getReg()); LIS->RemoveMachineInstrFromMaps(*MI); } MI->eraseFromParent(); + if (LIS && AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } diff --git a/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td b/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td index 24ebbc3..41071b2 100644 --- a/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td +++ b/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td @@ -654,8 +654,17 @@ foreach mx = SchedMxList in { foreach sew = SchedSEWSet<mx>.val in { defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxList>.c; - defm "" : LMULSEWWriteResMXSEW<"WriteVIRedV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; - defm "" : LMULSEWWriteResMXSEW<"WriteVIRedMinMaxV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; + defvar VIRedLat = GetLMULValue<[5, 5, 5, 7, 11, 19, 35], mx>.c; + defvar VIRedOcc = GetLMULValue<[1, 1, 2, 2, 4, 10, 35], mx>.c; + let Latency = VIRedLat, ReleaseAtCycles = [VIRedOcc] in { + defm "" : LMULSEWWriteResMXSEW<"WriteVIRedMinMaxV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; + + // Pattern for vredsum: 5/5/5/7/11/19/35 + // Pattern for vredand, vredor, vredxor: 4/4/4/6/10/18/34 + // They are grouped together, so we use the worst-case vredsum latency. + // TODO: split vredand, vredor, vredxor into separate scheduling classe. + defm "" : LMULSEWWriteResMXSEW<"WriteVIRedV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; + } } } @@ -663,7 +672,27 @@ foreach mx = SchedMxListWRed in { foreach sew = SchedSEWSet<mx, 0, 1>.val in { defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxListWRed>.c; - defm "" : LMULSEWWriteResMXSEW<"WriteVIWRedV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; + defvar VIRedLat = GetLMULValue<[5, 5, 5, 7, 11, 19, 35], mx>.c; + defvar VIRedOcc = GetLMULValue<[1, 1, 2, 2, 4, 10, 35], mx>.c; + let Latency = VIRedLat, ReleaseAtCycles = [VIRedOcc] in { + defm "" : LMULSEWWriteResMXSEW<"WriteVIWRedV_From", [SMX60_VIEU], mx, sew, IsWorstCase>; + } + } +} + +foreach mx = SchedMxListF in { + foreach sew = SchedSEWSet<mx, 1>.val in { + defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxListF, 1>.c; + + // Latency for vfredmax.vs, vfredmin.vs: 12/12/15/21/33/57 + // Latency for vfredusum.vs is slightly lower for e16/e32 + // We use the worst-case + defvar VFRedLat = GetLMULValue<[12, 12, 12, 15, 21, 33, 57], mx>.c; + defvar VFRedOcc = GetLMULValue<[8, 8, 8, 8, 14, 20, 57], mx>.c; + let Latency = VFRedLat, ReleaseAtCycles = [VFRedOcc] in { + defm "" : LMULSEWWriteResMXSEW<"WriteVFRedV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + defm "" : LMULSEWWriteResMXSEW<"WriteVFRedMinMaxV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + } } } @@ -671,9 +700,20 @@ foreach mx = SchedMxListF in { foreach sew = SchedSEWSet<mx, 1>.val in { defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxListF, 1>.c; - defm "" : LMULSEWWriteResMXSEW<"WriteVFRedV_From", [SMX60_VFP], mx, sew, IsWorstCase>; - defm "" : LMULSEWWriteResMXSEW<"WriteVFRedOV_From", [SMX60_VFP], mx, sew, IsWorstCase>; - defm "" : LMULSEWWriteResMXSEW<"WriteVFRedMinMaxV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + // Compute latency based on SEW + defvar VFRedOV_FromLat = !cond( + !eq(sew, 16) : ConstValueUntilLMULThenDouble<"MF4", 12, mx>.c, + !eq(sew, 32) : ConstValueUntilLMULThenDouble<"MF2", 12, mx>.c, + !eq(sew, 64) : ConstValueUntilLMULThenDouble<"M1", 12, mx>.c + ); + defvar VFRedOV_FromOcc = !cond( + !eq(sew, 16) : GetLMULValue<[8, 8, 20, 24, 48, 96, 384], mx>.c, + !eq(sew, 32) : GetLMULValue<[8, 8, 8, 12, 24, 48, 192], mx>.c, + !eq(sew, 64) : GetLMULValue<[6, 6, 6, 6, 12, 24, 96], mx>.c + ); + let Latency = VFRedOV_FromLat, ReleaseAtCycles = [VFRedOV_FromOcc] in { + defm "" : LMULSEWWriteResMXSEW<"WriteVFRedOV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + } } } @@ -681,8 +721,18 @@ foreach mx = SchedMxListFWRed in { foreach sew = SchedSEWSet<mx, 1, 1>.val in { defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxListFWRed, 1>.c; - defm "" : LMULSEWWriteResMXSEW<"WriteVFWRedV_From", [SMX60_VFP], mx, sew, IsWorstCase>; - defm "" : LMULSEWWriteResMXSEW<"WriteVFWRedOV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + defvar VFRedOVLat = !cond( + !eq(sew, 16) : ConstValueUntilLMULThenDouble<"MF4", 16, mx>.c, + !eq(sew, 32) : ConstValueUntilLMULThenDouble<"MF2", 16, mx>.c, + ); + defvar VFRedOVOcc = !cond( + !eq(sew, 16) : GetLMULValue<[11, 11, 27, 32, 64, 128, 512], mx>.c, + !eq(sew, 32) : GetLMULValue<[11, 11, 11, 16, 32, 64, 256], mx>.c, + ); + let Latency = VFRedOVLat, ReleaseAtCycles = [VFRedOVOcc] in { + defm "" : LMULSEWWriteResMXSEW<"WriteVFWRedV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + defm "" : LMULSEWWriteResMXSEW<"WriteVFWRedOV_From", [SMX60_VFP], mx, sew, IsWorstCase>; + } } } |
