diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
| -rw-r--r-- | llvm/lib/Target/RISCV/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCV.h | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVFeatures.td | 5 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 171 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp | 7 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp | 213 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetMachine.cpp | 3 |
13 files changed, 330 insertions, 95 deletions
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index 0ff178e..e9088a4 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -58,6 +58,7 @@ add_llvm_target(RISCVCodeGen RISCVMoveMerger.cpp RISCVOptWInstrs.cpp RISCVPostRAExpandPseudoInsts.cpp + RISCVPromoteConstant.cpp RISCVPushPopOptimizer.cpp RISCVRedundantCopyElimination.cpp RISCVRegisterInfo.cpp diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h index ae94101..51e8e85 100644 --- a/llvm/lib/Target/RISCV/RISCV.h +++ b/llvm/lib/Target/RISCV/RISCV.h @@ -20,6 +20,7 @@ namespace llvm { class FunctionPass; class InstructionSelector; +class ModulePass; class PassRegistry; class RISCVRegisterBankInfo; class RISCVSubtarget; @@ -111,6 +112,9 @@ void initializeRISCVO0PreLegalizerCombinerPass(PassRegistry &); FunctionPass *createRISCVPreLegalizerCombiner(); void initializeRISCVPreLegalizerCombinerPass(PassRegistry &); +ModulePass *createRISCVPromoteConstantPass(); +void initializeRISCVPromoteConstantPass(PassRegistry &); + FunctionPass *createRISCVVLOptimizerPass(); void initializeRISCVVLOptimizerPass(PassRegistry &); diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index 526675a..b0453fc 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -131,6 +131,7 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB, case RISCV::PseudoCCMAXU: case RISCV::PseudoCCMIN: case RISCV::PseudoCCMINU: + case RISCV::PseudoCCMUL: case RISCV::PseudoCCADDW: case RISCV::PseudoCCSUBW: case RISCV::PseudoCCSLL: @@ -237,6 +238,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, case RISCV::PseudoCCMIN: NewOpc = RISCV::MIN; break; case RISCV::PseudoCCMAXU: NewOpc = RISCV::MAXU; break; case RISCV::PseudoCCMINU: NewOpc = RISCV::MINU; break; + case RISCV::PseudoCCMUL: NewOpc = RISCV::MUL; break; case RISCV::PseudoCCADDI: NewOpc = RISCV::ADDI; break; case RISCV::PseudoCCSLLI: NewOpc = RISCV::SLLI; break; case RISCV::PseudoCCSRLI: NewOpc = RISCV::SRLI; break; diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index cfee6ab..5b72334 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1856,6 +1856,11 @@ def TuneShortForwardBranchIMinMax "true", "Enable short forward branch optimization for min,max instructions in Zbb", [TuneShortForwardBranchOpt]>; +def TuneShortForwardBranchIMul + : SubtargetFeature<"short-forward-branch-i-mul", "HasShortForwardBranchIMul", + "true", "Enable short forward branch optimization for mul instruction", + [TuneShortForwardBranchOpt]>; + // Some subtargets require a S2V transfer buffer to move scalars into vectors. // FIXME: Forming .vx/.vf/.wx/.wf can reduce register pressure. def TuneNoSinkSplatOperands diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index e0cf739..995ae75 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -9186,7 +9186,7 @@ static SDValue lowerSelectToBinOp(SDNode *N, SelectionDAG &DAG, unsigned ShAmount = Log2_64(TrueM1); if (Subtarget.hasShlAdd(ShAmount)) return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, CondV, - DAG.getConstant(ShAmount, DL, VT), CondV); + DAG.getTargetConstant(ShAmount, DL, VT), CondV); } } // (select c, y, 0) -> -c & y @@ -15463,7 +15463,7 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG, SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0); SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, NL, - DAG.getConstant(Diff, DL, VT), NS); + DAG.getTargetConstant(Diff, DL, VT), NS); return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT)); } @@ -15501,7 +15501,7 @@ static SDValue combineShlAddIAddImpl(SDNode *N, SDValue AddI, SDValue Other, int64_t AddConst = AddVal.getSExtValue(); SDValue SHADD = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, SHLVal->getOperand(0), - DAG.getConstant(ShlConst, DL, VT), Other); + DAG.getTargetConstant(ShlConst, DL, VT), Other); return DAG.getNode(ISD::ADD, DL, VT, SHADD, DAG.getSignedConstant(AddConst, DL, VT)); } @@ -16495,6 +16495,45 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Op, DL, VT, Shift1, Shift2); } +static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, + unsigned ShY, bool AddX) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue X = N->getOperand(0); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359); +} + +static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, + uint64_t MulAmt) { + // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) + switch (MulAmt) { + case 5 * 3: + return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false); + case 9 * 3: + return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false); + case 5 * 5: + return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false); + case 9 * 5: + return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false); + case 9 * 9: + return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false); + default: + break; + } + + // 2/4/8 * 3/5/9 + 1 -> (shXadd (shYadd X, X), X) + int ShX; + if (int ShY = isShifted359(MulAmt - 1, ShX)) { + assert(ShX != 0 && "MulAmt=4,6,10 handled before"); + if (ShX <= 3) + return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true); + } + return SDValue(); +} + // Try to expand a scalar multiply to a faster sequence. static SDValue expandMul(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -16524,18 +16563,17 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue())) return SDValue(); - // WARNING: The code below is knowingly incorrect with regards to undef semantics. - // We're adding additional uses of X here, and in principle, we should be freezing - // X before doing so. However, adding freeze here causes real regressions, and no - // other target properly freezes X in these cases either. - SDValue X = N->getOperand(0); - + // WARNING: The code below is knowingly incorrect with regards to undef + // semantics. We're adding additional uses of X here, and in principle, we + // should be freezing X before doing so. However, adding freeze here causes + // real regressions, and no other target properly freezes X in these cases + // either. if (Subtarget.hasShlAdd(3)) { + SDValue X = N->getOperand(0); int Shift; if (int ShXAmount = isShifted359(MulAmt, Shift)) { // 3/5/9 * 2^N -> shl (shXadd X, X), N SDLoc DL(N); - SDValue X = N->getOperand(0); // Put the shift first if we can fold a zext into the shift forming // a slli.uw. if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && @@ -16543,80 +16581,40 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getConstant(ShXAmount, DL, VT), Shl); + DAG.getTargetConstant(ShXAmount, DL, VT), Shl); } // Otherwise, put the shl second so that it can fold with following // instructions (e.g. sext or add). SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShXAmount, DL, VT), X); + DAG.getTargetConstant(ShXAmount, DL, VT), X); return DAG.getNode(ISD::SHL, DL, VT, Mul359, DAG.getConstant(Shift, DL, VT)); } - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - int ShX; - int ShY; - switch (MulAmt) { - case 3 * 5: - ShY = 1; - ShX = 2; - break; - case 3 * 9: - ShY = 1; - ShX = 3; - break; - case 5 * 5: - ShX = ShY = 2; - break; - case 5 * 9: - ShY = 2; - ShX = 3; - break; - case 9 * 9: - ShX = ShY = 3; - break; - default: - ShX = ShY = 0; - break; - } - if (ShX) { + // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples + // of 25 which happen to be quite common. + // (2/4/8 * 3/5/9 + 1) * 2^N + Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { + if (Shift == 0) + return V; SDLoc DL(N); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShY, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(ShX, DL, VT), Mul359); + return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); } // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's // easy. Then count how many zeros are up to the first bit. - if (isPowerOf2_64(MulAmt & (MulAmt - 1))) { - unsigned ScaleShift = llvm::countr_zero(MulAmt); - if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1))); - SDLoc DL(N); - SDValue Shift1 = - DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ScaleShift, DL, VT), Shift1); - } + if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) { + unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1))); + SDLoc DL(N); + SDValue Shift1 = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(Shift, DL, VT), Shift1); } - // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x) - // This is the two instruction form, there are also three instruction - // variants we could implement. e.g. - // (2^(1,2,3) * 3,5,9 + 1) << C2 - // 2^(C1>3) * 3,5,9 +/- 1 - if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { - assert(Shift != 0 && "MulAmt=4,6,10 handled before"); - if (Shift <= 3) { - SDLoc DL(N); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShXAmount, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(Shift, DL, VT), X); - } - } + // TODO: 2^(C1>3) * 3,5,9 +/- 1 // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X)) if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { @@ -16626,9 +16624,10 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDLoc DL(N); SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(ISD::ADD, DL, VT, Shift1, - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ScaleShift, DL, VT), X)); + return DAG.getNode( + ISD::ADD, DL, VT, Shift1, + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(ScaleShift, DL, VT), X)); } } @@ -16643,29 +16642,10 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT)); SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Offset - 1), DL, VT), X); + DAG.getTargetConstant(Log2_64(Offset - 1), DL, VT), X); return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359); } } - - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples - // of 25 which happen to be quite common. - if (int ShBAmount = isShifted359(MulAmt2, Shift)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(ShBAmount, DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Shift, DL, VT)); - } - } } if (SDValue V = expandMulToAddOrSubOfShl(N, DAG, MulAmt)) @@ -25320,3 +25300,12 @@ ArrayRef<MCPhysReg> RISCVTargetLowering::getRoundingControlRegisters() const { } return {}; } + +bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const { + EVT VT = Y.getValueType(); + + if (VT.isVector()) + return false; + + return VT.getSizeInBits() <= Subtarget.getXLen(); +} diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 9e3e2a9..dd62a9c 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -465,6 +465,8 @@ public: ArrayRef<MCPhysReg> getRoundingControlRegisters() const override; + bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override; + /// Match a mask which "spreads" the leading elements of a vector evenly /// across the result. Factor is the spread amount, and Index is the /// offset applied. diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 636e31c..bf9de0a 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -1583,7 +1583,10 @@ void RISCVInsertVSETVLI::emitVSETVLIs(MachineBasicBlock &MBB) { if (!TII->isAddImmediate(*DeadMI, Reg)) continue; LIS->RemoveMachineInstrFromMaps(*DeadMI); + Register AddReg = DeadMI->getOperand(1).getReg(); DeadMI->eraseFromParent(); + if (AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } } @@ -1869,11 +1872,15 @@ void RISCVInsertVSETVLI::coalesceVSETVLIs(MachineBasicBlock &MBB) const { // Loop over the dead AVL values, and delete them now. This has // to be outside the above loop to avoid invalidating iterators. for (auto *MI : ToDelete) { + assert(MI->getOpcode() == RISCV::ADDI); + Register AddReg = MI->getOperand(1).getReg(); if (LIS) { LIS->removeInterval(MI->getOperand(0).getReg()); LIS->RemoveMachineInstrFromMaps(*MI); } MI->eraseFromParent(); + if (LIS && AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index c9df787..b8ab70b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -1703,6 +1703,7 @@ unsigned getPredicatedOpcode(unsigned Opcode) { case RISCV::MAXU: return RISCV::PseudoCCMAXU; case RISCV::MIN: return RISCV::PseudoCCMIN; case RISCV::MINU: return RISCV::PseudoCCMINU; + case RISCV::MUL: return RISCV::PseudoCCMUL; case RISCV::ADDI: return RISCV::PseudoCCADDI; case RISCV::SLLI: return RISCV::PseudoCCSLLI; @@ -1754,6 +1755,9 @@ static MachineInstr *canFoldAsPredicatedOp(Register Reg, MI->getOpcode() == RISCV::MINU || MI->getOpcode() == RISCV::MAXU)) return nullptr; + if (!STI.hasShortForwardBranchIMul() && MI->getOpcode() == RISCV::MUL) + return nullptr; + // Check if MI can be predicated and folded into the CCMOV. if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END) return nullptr; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td index 5a67a5a..494b1c9 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td @@ -110,6 +110,7 @@ def PseudoCCMAX : SFBALU_rr; def PseudoCCMIN : SFBALU_rr; def PseudoCCMAXU : SFBALU_rr; def PseudoCCMINU : SFBALU_rr; +def PseudoCCMUL : SFBALU_rr; def PseudoCCADDI : SFBALU_ri; def PseudoCCANDI : SFBALU_ri; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td index b37ceaae..c2b25c6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td @@ -60,6 +60,8 @@ def immfour : RISCVOp { let DecoderMethod = "decodeImmFourOperand"; } +def tuimm2 : TImmLeaf<XLenVT, [{return isUInt<2>(Imm);}]>; + //===----------------------------------------------------------------------===// // Instruction class templates //===----------------------------------------------------------------------===// @@ -557,8 +559,8 @@ multiclass VPatTernaryVMAQA_VV_VX<string intrinsic, string instruction, let Predicates = [HasVendorXTHeadBa] in { def : Pat<(add_like_non_imm12 (shl GPR:$rs2, uimm2:$uimm2), (XLenVT GPR:$rs1)), (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>; -def : Pat<(XLenVT (riscv_shl_add GPR:$rs2, uimm2:$uimm2, GPR:$rs1)), - (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>; +def : Pat<(XLenVT (riscv_shl_add GPR:$rs2, tuimm2:$uimm2, GPR:$rs1)), + (TH_ADDSL GPR:$rs1, GPR:$rs2, tuimm2:$uimm2)>; // Reuse complex patterns from StdExtZba def : Pat<(add_like_non_imm12 sh1add_op:$rs2, (XLenVT GPR:$rs1)), diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index 4537bfe..8376da5 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -53,6 +53,8 @@ def uimm5gt3 : RISCVOp<XLenVT>, ImmLeaf<XLenVT, let OperandType = "OPERAND_UIMM5_GT3"; } +def tuimm5gt3 : TImmLeaf<XLenVT, [{return (Imm > 3) && isUInt<5>(Imm);}]>; + def UImm5Plus1AsmOperand : AsmOperandClass { let Name = "UImm5Plus1"; let RenderMethod = "addImmOperands"; @@ -1419,8 +1421,8 @@ def : Pat<(i32 (add GPRNoX0:$rd, (mul GPRNoX0:$rs1, simm12_lo:$imm12))), (QC_MULIADD GPRNoX0:$rd, GPRNoX0:$rs1, simm12_lo:$imm12)>; def : Pat<(i32 (add_like_non_imm12 (shl GPRNoX0:$rs1, (i32 uimm5gt3:$imm)), GPRNoX0:$rs2)), (QC_SHLADD GPRNoX0:$rs1, GPRNoX0:$rs2, uimm5gt3:$imm)>; -def : Pat<(i32 (riscv_shl_add GPRNoX0:$rs1, (i32 uimm5gt3:$imm), GPRNoX0:$rs2)), - (QC_SHLADD GPRNoX0:$rs1, GPRNoX0:$rs2, uimm5gt3:$imm)>; +def : Pat<(i32 (riscv_shl_add GPRNoX0:$rs1, (i32 tuimm5gt3:$imm), GPRNoX0:$rs2)), + (QC_SHLADD GPRNoX0:$rs1, GPRNoX0:$rs2, tuimm5gt3:$imm)>; } // Predicates = [HasVendorXqciac, IsRV32] /// Simple arithmetic operations diff --git a/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp b/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp new file mode 100644 index 0000000..bf1f69f --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp @@ -0,0 +1,213 @@ +//==- RISCVPromoteConstant.cpp - Promote constant fp to global for RISC-V --==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "RISCV.h" +#include "RISCVSubtarget.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; + +#define DEBUG_TYPE "riscv-promote-const" +#define RISCV_PROMOTE_CONSTANT_NAME "RISC-V Promote Constants" + +STATISTIC(NumPromoted, "Number of constant literals promoted to globals"); +STATISTIC(NumPromotedUses, "Number of uses of promoted literal constants"); + +namespace { + +class RISCVPromoteConstant : public ModulePass { +public: + static char ID; + RISCVPromoteConstant() : ModulePass(ID) {} + + StringRef getPassName() const override { return RISCV_PROMOTE_CONSTANT_NAME; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetPassConfig>(); + AU.setPreservesCFG(); + } + + /// Iterate over the functions and promote the double fp constants that + /// would otherwise go into the constant pool to a constant array. + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + // TargetMachine and Subtarget are needed to query isFPImmlegal. + const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>(); + const TargetMachine &TM = TPC.getTM<TargetMachine>(); + bool Changed = false; + for (Function &F : M) { + const RISCVSubtarget &ST = TM.getSubtarget<RISCVSubtarget>(F); + const RISCVTargetLowering *TLI = ST.getTargetLowering(); + Changed |= runOnFunction(F, TLI); + } + return Changed; + } + +private: + bool runOnFunction(Function &F, const RISCVTargetLowering *TLI); +}; +} // end anonymous namespace + +char RISCVPromoteConstant::ID = 0; + +INITIALIZE_PASS(RISCVPromoteConstant, DEBUG_TYPE, RISCV_PROMOTE_CONSTANT_NAME, + false, false) + +ModulePass *llvm::createRISCVPromoteConstantPass() { + return new RISCVPromoteConstant(); +} + +bool RISCVPromoteConstant::runOnFunction(Function &F, + const RISCVTargetLowering *TLI) { + if (F.hasOptNone() || F.hasOptSize()) + return false; + + // Bail out and make no transformation if the target doesn't support + // doubles, or if we're not targeting RV64 as we currently see some + // regressions for those targets. + if (!TLI->isTypeLegal(MVT::f64) || !TLI->isTypeLegal(MVT::i64)) + return false; + + // Collect all unique double constants and their uses in the function. Use + // MapVector to preserve insertion order. + MapVector<ConstantFP *, SmallVector<Use *, 8>> ConstUsesMap; + + for (Instruction &I : instructions(F)) { + for (Use &U : I.operands()) { + auto *C = dyn_cast<ConstantFP>(U.get()); + if (!C || !C->getType()->isDoubleTy()) + continue; + // Do not promote if it wouldn't be loaded from the constant pool. + if (TLI->isFPImmLegal(C->getValueAPF(), MVT::f64, + /*ForCodeSize=*/false)) + continue; + // Do not promote a constant if it is used as an immediate argument + // for an intrinsic. + if (auto *II = dyn_cast<IntrinsicInst>(U.getUser())) { + Function *IntrinsicFunc = II->getFunction(); + unsigned OperandIdx = U.getOperandNo(); + if (IntrinsicFunc && IntrinsicFunc->getAttributes().hasParamAttr( + OperandIdx, Attribute::ImmArg)) { + LLVM_DEBUG(dbgs() << "Skipping promotion of constant in: " << *II + << " because operand " << OperandIdx + << " must be an immediate.\n"); + continue; + } + } + // Note: FP args to inline asm would be problematic if we had a + // constraint that required an immediate floating point operand. At the + // time of writing LLVM doesn't recognise such a constraint. + ConstUsesMap[C].push_back(&U); + } + } + + int PromotableConstants = ConstUsesMap.size(); + LLVM_DEBUG(dbgs() << "Found " << PromotableConstants + << " promotable constants in " << F.getName() << "\n"); + // Bail out if no promotable constants found, or if only one is found. + if (PromotableConstants < 2) { + LLVM_DEBUG(dbgs() << "Performing no promotions as insufficient promotable " + "constants found\n"); + return false; + } + + NumPromoted += PromotableConstants; + + // Create a global array containing the promoted constants. + Module *M = F.getParent(); + Type *DoubleTy = Type::getDoubleTy(M->getContext()); + + SmallVector<Constant *, 16> ConstantVector; + for (auto const &Pair : ConstUsesMap) + ConstantVector.push_back(Pair.first); + + ArrayType *ArrayTy = ArrayType::get(DoubleTy, ConstantVector.size()); + Constant *GlobalArrayInitializer = + ConstantArray::get(ArrayTy, ConstantVector); + + auto *GlobalArray = new GlobalVariable( + *M, ArrayTy, + /*isConstant=*/true, GlobalValue::InternalLinkage, GlobalArrayInitializer, + ".promoted_doubles." + F.getName()); + + // A cache to hold the loaded value for a given constant within a basic block. + DenseMap<std::pair<ConstantFP *, BasicBlock *>, Value *> LocalLoads; + + // Replace all uses with the loaded value. + unsigned Idx = 0; + for (auto const &Pair : ConstUsesMap) { + ConstantFP *Const = Pair.first; + const SmallVector<Use *, 8> &Uses = Pair.second; + + for (Use *U : Uses) { + Instruction *UserInst = cast<Instruction>(U->getUser()); + BasicBlock *InsertionBB; + + // If the user is a PHI node, we must insert the load in the + // corresponding predecessor basic block. Otherwise, it's inserted into + // the same block as the use. + if (auto *PN = dyn_cast<PHINode>(UserInst)) + InsertionBB = PN->getIncomingBlock(*U); + else + InsertionBB = UserInst->getParent(); + + if (isa<CatchSwitchInst>(InsertionBB->getTerminator())) { + LLVM_DEBUG(dbgs() << "Bailing out: catchswitch means thre is no valid " + "insertion point.\n"); + return false; + } + + auto CacheKey = std::make_pair(Const, InsertionBB); + Value *LoadedVal = nullptr; + + // Re-use a load if it exists in the insertion block. + if (LocalLoads.count(CacheKey)) { + LoadedVal = LocalLoads.at(CacheKey); + } else { + // Otherwise, create a new GEP and Load at the correct insertion point. + // It is always safe to insert in the first insertion point in the BB, + // so do that and let other passes reorder. + IRBuilder<> Builder(InsertionBB, InsertionBB->getFirstInsertionPt()); + Value *ElementPtr = Builder.CreateConstInBoundsGEP2_64( + GlobalArray->getValueType(), GlobalArray, 0, Idx, "double.addr"); + LoadedVal = Builder.CreateLoad(DoubleTy, ElementPtr, "double.val"); + + // Cache the newly created load for this block. + LocalLoads[CacheKey] = LoadedVal; + } + + U->set(LoadedVal); + ++NumPromotedUses; + } + ++Idx; + } + + return true; +} diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index ae54ff1..16ef67d 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -139,6 +139,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() { initializeRISCVExpandAtomicPseudoPass(*PR); initializeRISCVRedundantCopyEliminationPass(*PR); initializeRISCVAsmPrinterPass(*PR); + initializeRISCVPromoteConstantPass(*PR); } static Reloc::Model getEffectiveRelocModel(std::optional<Reloc::Model> RM) { @@ -462,6 +463,8 @@ void RISCVPassConfig::addIRPasses() { } bool RISCVPassConfig::addPreISel() { + if (TM->getOptLevel() != CodeGenOptLevel::None) + addPass(createRISCVPromoteConstantPass()); if (TM->getOptLevel() != CodeGenOptLevel::None) { // Add a barrier before instruction selection so that we will not get // deleted block address after enabling default outlining. See D99707 for |
