diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
| -rw-r--r-- | llvm/lib/Target/RISCV/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCV.h | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 81 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp | 7 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp | 213 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetMachine.cpp | 3 |
6 files changed, 264 insertions, 45 deletions
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index 0ff178e..e9088a4 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -58,6 +58,7 @@ add_llvm_target(RISCVCodeGen RISCVMoveMerger.cpp RISCVOptWInstrs.cpp RISCVPostRAExpandPseudoInsts.cpp + RISCVPromoteConstant.cpp RISCVPushPopOptimizer.cpp RISCVRedundantCopyElimination.cpp RISCVRegisterInfo.cpp diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h index ae94101..51e8e85 100644 --- a/llvm/lib/Target/RISCV/RISCV.h +++ b/llvm/lib/Target/RISCV/RISCV.h @@ -20,6 +20,7 @@ namespace llvm { class FunctionPass; class InstructionSelector; +class ModulePass; class PassRegistry; class RISCVRegisterBankInfo; class RISCVSubtarget; @@ -111,6 +112,9 @@ void initializeRISCVO0PreLegalizerCombinerPass(PassRegistry &); FunctionPass *createRISCVPreLegalizerCombiner(); void initializeRISCVPreLegalizerCombinerPass(PassRegistry &); +ModulePass *createRISCVPromoteConstantPass(); +void initializeRISCVPromoteConstantPass(PassRegistry &); + FunctionPass *createRISCVVLOptimizerPass(); void initializeRISCVVLOptimizerPass(PassRegistry &); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c3f100e..995ae75 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16496,32 +16496,42 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, } static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, - unsigned ShY) { + unsigned ShY, bool AddX) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue X = N->getOperand(0); SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, DAG.getTargetConstant(ShY, DL, VT), X); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getTargetConstant(ShX, DL, VT), Mul359); + DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359); } static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt) { + // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) switch (MulAmt) { case 5 * 3: - return getShlAddShlAdd(N, DAG, 2, 1); + return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false); case 9 * 3: - return getShlAddShlAdd(N, DAG, 3, 1); + return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false); case 5 * 5: - return getShlAddShlAdd(N, DAG, 2, 2); + return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false); case 9 * 5: - return getShlAddShlAdd(N, DAG, 3, 2); + return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false); case 9 * 9: - return getShlAddShlAdd(N, DAG, 3, 3); + return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false); default: - return SDValue(); + break; } + + // 2/4/8 * 3/5/9 + 1 -> (shXadd (shYadd X, X), X) + int ShX; + if (int ShY = isShifted359(MulAmt - 1, ShX)) { + assert(ShX != 0 && "MulAmt=4,6,10 handled before"); + if (ShX <= 3) + return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true); + } + return SDValue(); } // Try to expand a scalar multiply to a faster sequence. @@ -16581,41 +16591,30 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, DAG.getConstant(Shift, DL, VT)); } - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt)) - return V; + // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples + // of 25 which happen to be quite common. + // (2/4/8 * 3/5/9 + 1) * 2^N + Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { + if (Shift == 0) + return V; + SDLoc DL(N); + return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); + } // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's // easy. Then count how many zeros are up to the first bit. - if (isPowerOf2_64(MulAmt & (MulAmt - 1))) { - unsigned ScaleShift = llvm::countr_zero(MulAmt); - if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1))); - SDLoc DL(N); - SDValue Shift1 = - DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getTargetConstant(ScaleShift, DL, VT), Shift1); - } + if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) { + unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1))); + SDLoc DL(N); + SDValue Shift1 = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getTargetConstant(Shift, DL, VT), Shift1); } - // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x) - // This is the two instruction form, there are also three instruction - // variants we could implement. e.g. - // (2^(1,2,3) * 3,5,9 + 1) << C2 - // 2^(C1>3) * 3,5,9 +/- 1 - if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { - assert(Shift != 0 && "MulAmt=4,6,10 handled before"); - if (Shift <= 3) { - SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getTargetConstant(ShXAmount, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getTargetConstant(Shift, DL, VT), X); - } - } + // TODO: 2^(C1>3) * 3,5,9 +/- 1 // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X)) if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { @@ -16647,14 +16646,6 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359); } } - - // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples - // of 25 which happen to be quite common. - Shift = llvm::countr_zero(MulAmt); - if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { - SDLoc DL(N); - return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); - } } if (SDValue V = expandMulToAddOrSubOfShl(N, DAG, MulAmt)) diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 636e31c..bf9de0a 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -1583,7 +1583,10 @@ void RISCVInsertVSETVLI::emitVSETVLIs(MachineBasicBlock &MBB) { if (!TII->isAddImmediate(*DeadMI, Reg)) continue; LIS->RemoveMachineInstrFromMaps(*DeadMI); + Register AddReg = DeadMI->getOperand(1).getReg(); DeadMI->eraseFromParent(); + if (AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } } @@ -1869,11 +1872,15 @@ void RISCVInsertVSETVLI::coalesceVSETVLIs(MachineBasicBlock &MBB) const { // Loop over the dead AVL values, and delete them now. This has // to be outside the above loop to avoid invalidating iterators. for (auto *MI : ToDelete) { + assert(MI->getOpcode() == RISCV::ADDI); + Register AddReg = MI->getOperand(1).getReg(); if (LIS) { LIS->removeInterval(MI->getOperand(0).getReg()); LIS->RemoveMachineInstrFromMaps(*MI); } MI->eraseFromParent(); + if (LIS && AddReg.isVirtual()) + LIS->shrinkToUses(&LIS->getInterval(AddReg)); } } diff --git a/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp b/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp new file mode 100644 index 0000000..bf1f69f --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVPromoteConstant.cpp @@ -0,0 +1,213 @@ +//==- RISCVPromoteConstant.cpp - Promote constant fp to global for RISC-V --==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "RISCV.h" +#include "RISCVSubtarget.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; + +#define DEBUG_TYPE "riscv-promote-const" +#define RISCV_PROMOTE_CONSTANT_NAME "RISC-V Promote Constants" + +STATISTIC(NumPromoted, "Number of constant literals promoted to globals"); +STATISTIC(NumPromotedUses, "Number of uses of promoted literal constants"); + +namespace { + +class RISCVPromoteConstant : public ModulePass { +public: + static char ID; + RISCVPromoteConstant() : ModulePass(ID) {} + + StringRef getPassName() const override { return RISCV_PROMOTE_CONSTANT_NAME; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetPassConfig>(); + AU.setPreservesCFG(); + } + + /// Iterate over the functions and promote the double fp constants that + /// would otherwise go into the constant pool to a constant array. + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + // TargetMachine and Subtarget are needed to query isFPImmlegal. + const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>(); + const TargetMachine &TM = TPC.getTM<TargetMachine>(); + bool Changed = false; + for (Function &F : M) { + const RISCVSubtarget &ST = TM.getSubtarget<RISCVSubtarget>(F); + const RISCVTargetLowering *TLI = ST.getTargetLowering(); + Changed |= runOnFunction(F, TLI); + } + return Changed; + } + +private: + bool runOnFunction(Function &F, const RISCVTargetLowering *TLI); +}; +} // end anonymous namespace + +char RISCVPromoteConstant::ID = 0; + +INITIALIZE_PASS(RISCVPromoteConstant, DEBUG_TYPE, RISCV_PROMOTE_CONSTANT_NAME, + false, false) + +ModulePass *llvm::createRISCVPromoteConstantPass() { + return new RISCVPromoteConstant(); +} + +bool RISCVPromoteConstant::runOnFunction(Function &F, + const RISCVTargetLowering *TLI) { + if (F.hasOptNone() || F.hasOptSize()) + return false; + + // Bail out and make no transformation if the target doesn't support + // doubles, or if we're not targeting RV64 as we currently see some + // regressions for those targets. + if (!TLI->isTypeLegal(MVT::f64) || !TLI->isTypeLegal(MVT::i64)) + return false; + + // Collect all unique double constants and their uses in the function. Use + // MapVector to preserve insertion order. + MapVector<ConstantFP *, SmallVector<Use *, 8>> ConstUsesMap; + + for (Instruction &I : instructions(F)) { + for (Use &U : I.operands()) { + auto *C = dyn_cast<ConstantFP>(U.get()); + if (!C || !C->getType()->isDoubleTy()) + continue; + // Do not promote if it wouldn't be loaded from the constant pool. + if (TLI->isFPImmLegal(C->getValueAPF(), MVT::f64, + /*ForCodeSize=*/false)) + continue; + // Do not promote a constant if it is used as an immediate argument + // for an intrinsic. + if (auto *II = dyn_cast<IntrinsicInst>(U.getUser())) { + Function *IntrinsicFunc = II->getFunction(); + unsigned OperandIdx = U.getOperandNo(); + if (IntrinsicFunc && IntrinsicFunc->getAttributes().hasParamAttr( + OperandIdx, Attribute::ImmArg)) { + LLVM_DEBUG(dbgs() << "Skipping promotion of constant in: " << *II + << " because operand " << OperandIdx + << " must be an immediate.\n"); + continue; + } + } + // Note: FP args to inline asm would be problematic if we had a + // constraint that required an immediate floating point operand. At the + // time of writing LLVM doesn't recognise such a constraint. + ConstUsesMap[C].push_back(&U); + } + } + + int PromotableConstants = ConstUsesMap.size(); + LLVM_DEBUG(dbgs() << "Found " << PromotableConstants + << " promotable constants in " << F.getName() << "\n"); + // Bail out if no promotable constants found, or if only one is found. + if (PromotableConstants < 2) { + LLVM_DEBUG(dbgs() << "Performing no promotions as insufficient promotable " + "constants found\n"); + return false; + } + + NumPromoted += PromotableConstants; + + // Create a global array containing the promoted constants. + Module *M = F.getParent(); + Type *DoubleTy = Type::getDoubleTy(M->getContext()); + + SmallVector<Constant *, 16> ConstantVector; + for (auto const &Pair : ConstUsesMap) + ConstantVector.push_back(Pair.first); + + ArrayType *ArrayTy = ArrayType::get(DoubleTy, ConstantVector.size()); + Constant *GlobalArrayInitializer = + ConstantArray::get(ArrayTy, ConstantVector); + + auto *GlobalArray = new GlobalVariable( + *M, ArrayTy, + /*isConstant=*/true, GlobalValue::InternalLinkage, GlobalArrayInitializer, + ".promoted_doubles." + F.getName()); + + // A cache to hold the loaded value for a given constant within a basic block. + DenseMap<std::pair<ConstantFP *, BasicBlock *>, Value *> LocalLoads; + + // Replace all uses with the loaded value. + unsigned Idx = 0; + for (auto const &Pair : ConstUsesMap) { + ConstantFP *Const = Pair.first; + const SmallVector<Use *, 8> &Uses = Pair.second; + + for (Use *U : Uses) { + Instruction *UserInst = cast<Instruction>(U->getUser()); + BasicBlock *InsertionBB; + + // If the user is a PHI node, we must insert the load in the + // corresponding predecessor basic block. Otherwise, it's inserted into + // the same block as the use. + if (auto *PN = dyn_cast<PHINode>(UserInst)) + InsertionBB = PN->getIncomingBlock(*U); + else + InsertionBB = UserInst->getParent(); + + if (isa<CatchSwitchInst>(InsertionBB->getTerminator())) { + LLVM_DEBUG(dbgs() << "Bailing out: catchswitch means thre is no valid " + "insertion point.\n"); + return false; + } + + auto CacheKey = std::make_pair(Const, InsertionBB); + Value *LoadedVal = nullptr; + + // Re-use a load if it exists in the insertion block. + if (LocalLoads.count(CacheKey)) { + LoadedVal = LocalLoads.at(CacheKey); + } else { + // Otherwise, create a new GEP and Load at the correct insertion point. + // It is always safe to insert in the first insertion point in the BB, + // so do that and let other passes reorder. + IRBuilder<> Builder(InsertionBB, InsertionBB->getFirstInsertionPt()); + Value *ElementPtr = Builder.CreateConstInBoundsGEP2_64( + GlobalArray->getValueType(), GlobalArray, 0, Idx, "double.addr"); + LoadedVal = Builder.CreateLoad(DoubleTy, ElementPtr, "double.val"); + + // Cache the newly created load for this block. + LocalLoads[CacheKey] = LoadedVal; + } + + U->set(LoadedVal); + ++NumPromotedUses; + } + ++Idx; + } + + return true; +} diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index ae54ff1..16ef67d 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -139,6 +139,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() { initializeRISCVExpandAtomicPseudoPass(*PR); initializeRISCVRedundantCopyEliminationPass(*PR); initializeRISCVAsmPrinterPass(*PR); + initializeRISCVPromoteConstantPass(*PR); } static Reloc::Model getEffectiveRelocModel(std::optional<Reloc::Model> RM) { @@ -462,6 +463,8 @@ void RISCVPassConfig::addIRPasses() { } bool RISCVPassConfig::addPreISel() { + if (TM->getOptLevel() != CodeGenOptLevel::None) + addPass(createRISCVPromoteConstantPass()); if (TM->getOptLevel() != CodeGenOptLevel::None) { // Add a barrier before instruction selection so that we will not get // deleted block address after enabling default outlining. See D99707 for |
