about summary refs log tree commit diff
path: root/llvm/lib/Target/RISCV
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV')
-rw-r--r-- llvm/lib/Target/RISCV/RISCVCallingConv.td 2
-rw-r--r-- llvm/lib/Target/RISCV/RISCVISelLowering.cpp 81
-rw-r--r-- llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp 5
-rw-r--r-- llvm/lib/Target/RISCV/RISCVRegisterInfo.h 2
-rw-r--r-- llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp 38
-rw-r--r-- llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h 6
6 files changed, 132 insertions, 2 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVCallingConv.td b/llvm/lib/Target/RISCV/RISCVCallingConv.td
index ad06f47..98e05b7 100644
--- a/llvm/lib/Target/RISCV/RISCVCallingConv.td
+++ b/llvm/lib/Target/RISCV/RISCVCallingConv.td
@@ -42,6 +42,8 @@ def CSR_ILP32D_LP64D_V
// Needed for implementation of RISCVRegisterInfo::getNoPreservedMask()
def CSR_NoRegs : CalleeSavedRegs<(add)>;
+// Callee-saved set containing only X1. RISCVRegisterInfo::getIPRACSRegs()
+// returns CSR_IPRA_SaveList, presumably the list generated from this def.
+def CSR_IPRA : CalleeSavedRegs<(add X1)>;
+
// Interrupt handler needs to save/restore all registers that are used,
// both Caller and Callee saved registers.
def CSR_Interrupt : CalleeSavedRegs<(add X1, (sequence "X%u", 5, 31))>;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8e3caf5..7c3b583 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17759,6 +17759,83 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
return DAG.getZExtOrTrunc(Pop, DL, VT);
}
+// Combine for a left-shift node N (ISD::SHL or RISCVISD::SHL_VL).
+//
+// First defers to combineOp_VLToVWOp_VL. If that produces nothing, and only
+// after DAG legalization, a shift of a single-use sign/zero extend by a
+// constant amount C is rewritten as a widening multiply of the un-extended
+// value by (1 << C). The rewrite requires 1 < C < narrow element width and a
+// result element exactly twice as wide as the narrow element.
+//
+// Returns the replacement node, or an empty SDValue when nothing applies.
+static SDValue performSHLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const RISCVSubtarget &Subtarget) {
+ // (shl (zext x), y) -> (vwsll x, y)
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ return V;
+
+ // (shl (sext x), C) -> (vwmulsu x, 1u << C)
+ // (shl (zext x), C) -> (vwmulu x, 1u << C)
+
+ // The multiply rewrite is only attempted once the DAG has been legalized.
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ // The extend feeding the shift must have no other use.
+ SDValue LHS = N->getOperand(0);
+ if (!LHS.hasOneUse())
+ return SDValue();
+ // Choose the widening multiply that matches the kind of extension.
+ unsigned Opcode;
+ switch (LHS.getOpcode()) {
+ case ISD::SIGN_EXTEND:
+ case RISCVISD::VSEXT_VL:
+ Opcode = RISCVISD::VWMULSU_VL;
+ break;
+ case ISD::ZERO_EXTEND:
+ case RISCVISD::VZEXT_VL:
+ Opcode = RISCVISD::VWMULU_VL;
+ break;
+ default:
+ return SDValue();
+ }
+
+ // The shift amount must be a constant: either a constant splat vector or a
+ // VMV_V_X_VL whose operand 1 is an ISD::Constant.
+ SDValue RHS = N->getOperand(1);
+ APInt ShAmt;
+ uint64_t ShAmtInt;
+ if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
+ ShAmtInt = ShAmt.getZExtValue();
+ else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
+ RHS.getOperand(1).getOpcode() == ISD::Constant)
+ ShAmtInt = RHS.getConstantOperandVal(1);
+ else
+ return SDValue();
+
+ // Better foldings:
+ // (shl (sext x), 1) -> (vwadd x, x)
+ // (shl (zext x), 1) -> (vwaddu x, x)
+ if (ShAmtInt <= 1)
+ return SDValue();
+
+ // The multiplier (1 << C) is materialized in the narrow element type below,
+ // so C must be smaller than the narrow element width (presumably so that the
+ // constant is representable -- confirm). The result element must also be
+ // exactly twice the narrow element width.
+ SDValue NarrowOp = LHS.getOperand(0);
+ MVT NarrowVT = NarrowOp.getSimpleValueType();
+ uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
+ if (ShAmtInt >= NarrowBits)
+ return SDValue();
+ MVT VT = N->getSimpleValueType(0);
+ if (NarrowBits * 2 != VT.getScalarSizeInBits())
+ return SDValue();
+
+ // ISD::SHL gets an undef passthru plus the mask/VL returned by
+ // getDefaultScalableVLOps; RISCVISD::SHL_VL reuses its own operands 2, 3, 4.
+ SelectionDAG &DAG = DCI.DAG;
+ SDLoc DL(N);
+ SDValue Passthru, Mask, VL;
+ switch (N->getOpcode()) {
+ case ISD::SHL:
+ Passthru = DAG.getUNDEF(VT);
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ break;
+ case RISCVISD::SHL_VL:
+ Passthru = N->getOperand(2);
+ Mask = N->getOperand(3);
+ VL = N->getOperand(4);
+ break;
+ default:
+ llvm_unreachable("Expected SHL");
+ }
+ // Build the widening multiply of the narrow value by the constant 1 << C.
+ return DAG.getNode(Opcode, DL, VT, NarrowOp,
+ DAG.getConstant(1ULL << ShAmtInt, SDLoc(RHS), NarrowVT),
+ Passthru, Mask, VL);
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -18392,7 +18469,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::SHL_VL:
- if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ if (SDValue V = performSHLCombine(N, DCI, Subtarget))
return V;
[[fallthrough]];
case RISCVISD::SRA_VL:
@@ -18417,7 +18494,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SRL:
case ISD::SHL: {
if (N->getOpcode() == ISD::SHL) {
- if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ if (SDValue V = performSHLCombine(N, DCI, Subtarget))
return V;
}
SDValue ShAmt = N->getOperand(1);
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index b0a5269..7a99bfd 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -56,6 +56,11 @@ RISCVRegisterInfo::RISCVRegisterInfo(unsigned HwMode)
/*PC*/0, HwMode) {}
+// Returns CSR_IPRA_SaveList unconditionally; MF is not consulted.
+// NOTE(review): presumably the list generated from the CSR_IPRA def in
+// RISCVCallingConv.td (which lists only X1) -- confirm.
const MCPhysReg *
+RISCVRegisterInfo::getIPRACSRegs(const MachineFunction *MF) const {
+ return CSR_IPRA_SaveList;
+}
+
+const MCPhysReg *
RISCVRegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
auto &Subtarget = MF->getSubtarget<RISCVSubtarget>();
if (MF->getFunction().getCallingConv() == CallingConv::GHC)
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
index 3ab79694..6c4e9c7 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
@@ -62,6 +62,8 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override;
+ // Override defined in RISCVRegisterInfo.cpp; it returns CSR_IPRA_SaveList.
+ const MCPhysReg *getIPRACSRegs(const MachineFunction *MF) const override;
+
BitVector getReservedRegs(const MachineFunction &MF) const override;
bool isAsmClobberable(const MachineFunction &MF,
MCRegister PhysReg) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fa7c7c5..cb2ec1d 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -940,6 +940,44 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
return NumLoads * MemOpCost;
}
+// Cost of a masked compress-store (Opcode == Instruction::Store) or masked
+// expand-load (Opcode == Instruction::Load) of DataTy.
+//
+// Defers to BaseT::getExpandCompressMemoryOpCost unless the operation is legal
+// for (DataTy, Alignment) and CostKind is TTI::TCK_RecipThroughput. Otherwise
+// the result is the plain memory-op cost plus the cost of the surrounding
+// vector instructions, scaled by the first element of the type legalization
+// cost. I is only forwarded to the base implementation.
+InstructionCost RISCVTTIImpl::getExpandCompressMemoryOpCost(
+ unsigned Opcode, Type *DataTy, bool VariableMask, Align Alignment,
+ TTI::TargetCostKind CostKind, const Instruction *I) {
+ // Legal only for a store that is a legal compress-store or a load that is a
+ // legal expand-load; any other opcode takes the fallback path.
+ bool IsLegal = (Opcode == Instruction::Store &&
+ isLegalMaskedCompressStore(DataTy, Alignment)) ||
+ (Opcode == Instruction::Load &&
+ isLegalMaskedExpandLoad(DataTy, Alignment));
+ if (!IsLegal || CostKind != TTI::TCK_RecipThroughput)
+ return BaseT::getExpandCompressMemoryOpCost(Opcode, DataTy, VariableMask,
+ Alignment, CostKind, I);
+ // Example compressstore sequence:
+ // vsetivli zero, 8, e32, m2, ta, ma (ignored)
+ // vcompress.vm v10, v8, v0
+ // vcpop.m a1, v0
+ // vsetvli zero, a1, e32, m2, ta, ma
+ // vse32.v v10, (a0)
+ // Example expandload sequence:
+ // vsetivli zero, 8, e8, mf2, ta, ma (ignored)
+ // vcpop.m a1, v0
+ // vsetvli zero, a1, e32, m2, ta, ma
+ // vle32.v v10, (a0)
+ // vsetivli zero, 8, e32, m2, ta, ma
+ // viota.m v12, v0
+ // vrgather.vv v8, v10, v12, v0.t
+ // Cost of the underlying load/store itself, queried for address space 0.
+ auto MemOpCost =
+ getMemoryOpCost(Opcode, DataTy, Alignment, /*AddressSpace*/ 0, CostKind);
+ auto LT = getTypeLegalizationCost(DataTy);
+ // Extra instructions around the memory op: always one vsetvli, vcpop.m only
+ // when the mask is variable, then vcompress.vm for a store or
+ // vsetivli + viota.m + vrgather.vv for a load.
+ SmallVector<unsigned, 4> Opcodes{RISCV::VSETVLI};
+ if (VariableMask)
+ Opcodes.push_back(RISCV::VCPOP_M);
+ if (Opcode == Instruction::Store)
+ Opcodes.append({RISCV::VCOMPRESS_VM});
+ else
+ Opcodes.append({RISCV::VSETIVLI, RISCV::VIOTA_M, RISCV::VRGATHER_VV});
+ return MemOpCost +
+ LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind);
+}
+
InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 042530b..5389e9b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -174,6 +174,12 @@ public:
TTI::TargetCostKind CostKind,
const Instruction *I);
+ // Cost of a masked compress-store / expand-load; defined in
+ // RISCVTargetTransformInfo.cpp, where the second parameter (Src here) is
+ // named DataTy.
+ InstructionCost getExpandCompressMemoryOpCost(unsigned Opcode, Type *Src,
+ bool VariableMask,
+ Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr);
+
InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
const Value *Ptr, bool VariableMask,
Align Alignment,