diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
| -rw-r--r-- | llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp | 51 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVCombine.td | 11 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoZb.td | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td | 34 |
6 files changed, 100 insertions, 6 deletions
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp b/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp index 67b510d..f2b216b 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp @@ -27,6 +27,7 @@ #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/Support/FormatVariadic.h" #define GET_GICOMBINER_DEPS #include "RISCVGenPostLegalizeGICombiner.inc" @@ -42,6 +43,56 @@ namespace { #include "RISCVGenPostLegalizeGICombiner.inc" #undef GET_GICOMBINER_TYPES +/// Match: G_STORE (G_FCONSTANT +0.0), addr +/// Return the source vreg in MatchInfo if matched. +bool matchFoldFPZeroStore(MachineInstr &MI, MachineRegisterInfo &MRI, + const RISCVSubtarget &STI, Register &MatchInfo) { + if (MI.getOpcode() != TargetOpcode::G_STORE) + return false; + + Register SrcReg = MI.getOperand(0).getReg(); + if (!SrcReg.isVirtual()) + return false; + + MachineInstr *Def = MRI.getVRegDef(SrcReg); + if (!Def || Def->getOpcode() != TargetOpcode::G_FCONSTANT) + return false; + + auto *CFP = Def->getOperand(1).getFPImm(); + if (!CFP || !CFP->getValueAPF().isPosZero()) + return false; + + unsigned ValBits = MRI.getType(SrcReg).getSizeInBits(); + if ((ValBits == 16 && !STI.hasStdExtZfh()) || + (ValBits == 32 && !STI.hasStdExtF()) || + (ValBits == 64 && (!STI.hasStdExtD() || !STI.is64Bit()))) + return false; + + MatchInfo = SrcReg; + return true; +} + +/// Apply: rewrite to G_STORE (G_CONSTANT 0 [XLEN]), addr +void applyFoldFPZeroStore(MachineInstr &MI, MachineRegisterInfo &MRI, + MachineIRBuilder &B, const RISCVSubtarget &STI, + Register &MatchInfo) { + const unsigned XLen = STI.getXLen(); + + auto Zero = B.buildConstant(LLT::scalar(XLen), 0); + MI.getOperand(0).setReg(Zero.getReg(0)); + + MachineInstr *Def = MRI.getVRegDef(MatchInfo); + if (Def && MRI.use_nodbg_empty(MatchInfo)) + Def->eraseFromParent(); + +#ifndef NDEBUG + unsigned ValBits = MRI.getType(MatchInfo).getSizeInBits(); + LLVM_DEBUG(dbgs() << formatv("[{0}] Fold FP zero store -> int zero " + "(XLEN={1}, ValBits={2}):\n {3}\n", + DEBUG_TYPE, XLen, ValBits, MI)); +#endif +} + class RISCVPostLegalizerCombinerImpl : public Combiner { protected: const CombinerHelper Helper; diff --git a/llvm/lib/Target/RISCV/RISCVCombine.td b/llvm/lib/Target/RISCV/RISCVCombine.td index 995dd0c..a06b60d 100644 --- a/llvm/lib/Target/RISCV/RISCVCombine.td +++ b/llvm/lib/Target/RISCV/RISCVCombine.td @@ -19,11 +19,20 @@ def RISCVO0PreLegalizerCombiner: GICombiner< "RISCVO0PreLegalizerCombinerImpl", [optnone_combines]> { } +// Rule: fold store (fp +0.0) -> store (int zero [XLEN]) +def fp_zero_store_matchdata : GIDefMatchData<"Register">; +def fold_fp_zero_store : GICombineRule< + (defs root:$root, fp_zero_store_matchdata:$matchinfo), + (match (G_STORE $src, $addr):$root, + [{ return matchFoldFPZeroStore(*${root}, MRI, STI, ${matchinfo}); }]), + (apply [{ applyFoldFPZeroStore(*${root}, MRI, B, STI, ${matchinfo}); }])>; + // Post-legalization combines which are primarily optimizations. // TODO: Add more combines. def RISCVPostLegalizerCombiner : GICombiner<"RISCVPostLegalizerCombinerImpl", [sub_to_add, combines_for_extload, redundant_and, identity_combines, shift_immed_chain, - commute_constant_to_rhs, simplify_neg_minmax]> { + commute_constant_to_rhs, simplify_neg_minmax, + fold_fp_zero_store]> { } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 26fe9ed..219e3f2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14797,7 +14797,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits. SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0)); - SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src); + SDValue Abs = DAG.getNode(RISCVISD::NEGW_MAX, DL, MVT::i64, Src); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs)); return; } @@ -21813,7 +21813,7 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode( // Output is either all zero or operand 0. We can propagate sign bit count // from operand 0. return DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1); - case RISCVISD::ABSW: { + case RISCVISD::NEGW_MAX: { // We expand this at isel to negw+max. The result will have 33 sign bits // if the input has at least 33 sign bits. unsigned Tmp = diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td index 4104abd..4c2f7f6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td @@ -482,7 +482,7 @@ let Predicates = [HasVendorXSfvfwmaccqqq] in { defm SF_VFWMACC_4x4x4 : VPseudoSiFiveVFWMACC; } -let Predicates = [HasVendorXSfvfnrclipxfqf] in { +let Predicates = [HasVendorXSfvfnrclipxfqf], AltFmtType = IS_NOT_ALTFMT in { defm SF_VFNRCLIP_XU_F_QF : VPseudoSiFiveVFNRCLIP; defm SF_VFNRCLIP_X_F_QF : VPseudoSiFiveVFNRCLIP; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td index 62b7bcd..6b9a75f 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -51,7 +51,7 @@ def riscv_zip : RVSDNode<"ZIP", SDTIntUnaryOp>; def riscv_unzip : RVSDNode<"UNZIP", SDTIntUnaryOp>; // RV64IZbb absolute value for i32. Expanded to (max (negw X), X) during isel. -def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>; +def riscv_negw_max : RVSDNode<"NEGW_MAX", SDTIntUnaryOp>; // Scalar cryptography def riscv_clmul : RVSDNode<"CLMUL", SDTIntBinOp>; @@ -610,7 +610,7 @@ def : PatGpr<riscv_clzw, CLZW>; def : PatGpr<riscv_ctzw, CTZW>; def : Pat<(i64 (ctpop (i64 (zexti32 (i64 GPR:$rs1))))), (CPOPW GPR:$rs1)>; -def : Pat<(i64 (riscv_absw GPR:$rs1)), +def : Pat<(i64 (riscv_negw_max GPR:$rs1)), (MAX GPR:$rs1, (XLenVT (SUBW (XLenVT X0), GPR:$rs1)))>; } // Predicates = [HasStdExtZbb, IsRV64] diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index f7d1a09..b9c5b75 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -668,4 +668,38 @@ foreach vti = NoGroupBF16Vectors in { def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)), (vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>; } + +let Predicates = [HasStdExtZvfbfa] in { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + def : Pat<(fwti.Vector (any_riscv_fpextend_vl + (fvti.Vector fvti.RegClass:$rs1), + (fvti.Mask VMV0:$vm), + VLOpFrag)), + (!cast<Instruction>("PseudoVFWCVT_F_F_ALT_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, + (fvti.Mask VMV0:$vm), + GPR:$vl, fvti.Log2SEW, TA_MA)>; + + def : Pat<(fvti.Vector (any_riscv_fpround_vl + (fwti.Vector fwti.RegClass:$rs1), + (fwti.Mask VMV0:$vm), VLOpFrag)), + (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, + (fwti.Mask VMV0:$vm), + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + FRM_DYN, + GPR:$vl, fvti.Log2SEW, TA_MA)>; + def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), + (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW) + (fvti.Vector (IMPLICIT_DEF)), + fwti.RegClass:$rs1, + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + FRM_DYN, + fvti.AVL, fvti.Log2SEW, TA_MA)>; + } +} } // Predicates = [HasStdExtZvfbfa] |
