diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 123 |
1 files changed, 86 insertions, 37 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index ff8d2d9..9b45ff1 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -6576,6 +6576,87 @@ static SDValue performANY_EXTENDCombine(SDNode *N, return SDValue(N, 0); } +// Try to form VWMUL or VWMULU. +// FIXME: Support VWMULSU. +static SDValue combineMUL_VLToVWMUL(SDNode *N, SDValue Op0, SDValue Op1, + SelectionDAG &DAG) { + assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode"); + bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL; + bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL; + if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse()) + return SDValue(); + + SDValue Mask = N->getOperand(2); + SDValue VL = N->getOperand(3); + + // Make sure the mask and VL match. + if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL) + return SDValue(); + + MVT VT = N->getSimpleValueType(0); + + // Determine the narrow size for a widening multiply. + unsigned NarrowSize = VT.getScalarSizeInBits() / 2; + MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), + VT.getVectorElementCount()); + + SDLoc DL(N); + + // See if the other operand is the same opcode. + if (Op0.getOpcode() == Op1.getOpcode()) { + if (!Op1.hasOneUse()) + return SDValue(); + + // Make sure the mask and VL match. + if (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL) + return SDValue(); + + Op1 = Op1.getOperand(0); + } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) { + // The operand is a splat of a scalar. + + // The VL must be the same. + if (Op1.getOperand(1) != VL) + return SDValue(); + + // Get the scalar value. + Op1 = Op1.getOperand(0); + + // See if have enough sign bits or zero bits in the scalar to use a + // widening multiply by splatting to smaller element size. + unsigned EltBits = VT.getScalarSizeInBits(); + unsigned ScalarBits = Op1.getValueSizeInBits(); + // Make sure we're getting all element bits from the scalar register. + // FIXME: Support implicit sign extension of vmv.v.x? + if (ScalarBits < EltBits) + return SDValue(); + + if (IsSignExt) { + if (DAG.ComputeNumSignBits(Op1) <= (ScalarBits - NarrowSize)) + return SDValue(); + } else { + APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize); + if (!DAG.MaskedValueIsZero(Op1, Mask)) + return SDValue(); + } + + Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, Op1, VL); + } else + return SDValue(); + + Op0 = Op0.getOperand(0); + + // Re-introduce narrower extends if needed. + unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; + if (Op0.getValueType() != NarrowVT) + Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL); + if (Op1.getValueType() != NarrowVT) + Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL); + + unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; + return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL); +} + SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -7027,45 +7108,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, break; } case RISCVISD::MUL_VL: { - // Try to form VWMUL or VWMULU. - // FIXME: Look for splat of extended scalar as well. - // FIXME: Support VWMULSU. SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); - bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL; - bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL; - if ((!IsSignExt && !IsZeroExt) || Op0.getOpcode() != Op1.getOpcode()) - return SDValue(); - - // Make sure the extends have a single use. - if (!Op0.hasOneUse() || !Op1.hasOneUse()) - return SDValue(); - - SDValue Mask = N->getOperand(2); - SDValue VL = N->getOperand(3); - if (Op0.getOperand(1) != Mask || Op1.getOperand(1) != Mask || - Op0.getOperand(2) != VL || Op1.getOperand(2) != VL) - return SDValue(); - - Op0 = Op0.getOperand(0); - Op1 = Op1.getOperand(0); - - MVT VT = N->getSimpleValueType(0); - MVT NarrowVT = - MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits() / 2), - VT.getVectorElementCount()); - - SDLoc DL(N); - - // Re-introduce narrower extends if needed. - unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; - if (Op0.getValueType() != NarrowVT) - Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL); - if (Op1.getValueType() != NarrowVT) - Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL); - - unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; - return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL); + if (SDValue V = combineMUL_VLToVWMUL(N, Op0, Op1, DAG)) + return V; + if (SDValue V = combineMUL_VLToVWMUL(N, Op1, Op0, DAG)) + return V; + return SDValue(); } } |