aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp123
1 files changed, 86 insertions, 37 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ff8d2d9..9b45ff1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6576,6 +6576,87 @@ static SDValue performANY_EXTENDCombine(SDNode *N,
return SDValue(N, 0);
}
+// Try to form VWMUL or VWMULU.
+// FIXME: Support VWMULSU.
+static SDValue combineMUL_VLToVWMUL(SDNode *N, SDValue Op0, SDValue Op1,
+ SelectionDAG &DAG) {
+ assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode");
+ bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
+ bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
+ if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse())
+ return SDValue();
+
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+
+ // Make sure the mask and VL match.
+ if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)
+ return SDValue();
+
+ MVT VT = N->getSimpleValueType(0);
+
+ // Determine the narrow size for a widening multiply.
+ unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+ MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
+ VT.getVectorElementCount());
+
+ SDLoc DL(N);
+
+ // See if the other operand is the same opcode.
+ if (Op0.getOpcode() == Op1.getOpcode()) {
+ if (!Op1.hasOneUse())
+ return SDValue();
+
+ // Make sure the mask and VL match.
+ if (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+ return SDValue();
+
+ Op1 = Op1.getOperand(0);
+ } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) {
+ // The operand is a splat of a scalar.
+
+ // The VL must be the same.
+ if (Op1.getOperand(1) != VL)
+ return SDValue();
+
+ // Get the scalar value.
+ Op1 = Op1.getOperand(0);
+
+ // See if have enough sign bits or zero bits in the scalar to use a
+ // widening multiply by splatting to smaller element size.
+ unsigned EltBits = VT.getScalarSizeInBits();
+ unsigned ScalarBits = Op1.getValueSizeInBits();
+ // Make sure we're getting all element bits from the scalar register.
+ // FIXME: Support implicit sign extension of vmv.v.x?
+ if (ScalarBits < EltBits)
+ return SDValue();
+
+ if (IsSignExt) {
+ if (DAG.ComputeNumSignBits(Op1) <= (ScalarBits - NarrowSize))
+ return SDValue();
+ } else {
+ APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
+ if (!DAG.MaskedValueIsZero(Op1, Mask))
+ return SDValue();
+ }
+
+ Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, Op1, VL);
+ } else
+ return SDValue();
+
+ Op0 = Op0.getOperand(0);
+
+ // Re-introduce narrower extends if needed.
+ unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+ if (Op0.getValueType() != NarrowVT)
+ Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
+ if (Op1.getValueType() != NarrowVT)
+ Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
+
+ unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -7027,45 +7108,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::MUL_VL: {
- // Try to form VWMUL or VWMULU.
- // FIXME: Look for splat of extended scalar as well.
- // FIXME: Support VWMULSU.
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
- bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
- bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
- if ((!IsSignExt && !IsZeroExt) || Op0.getOpcode() != Op1.getOpcode())
- return SDValue();
-
- // Make sure the extends have a single use.
- if (!Op0.hasOneUse() || !Op1.hasOneUse())
- return SDValue();
-
- SDValue Mask = N->getOperand(2);
- SDValue VL = N->getOperand(3);
- if (Op0.getOperand(1) != Mask || Op1.getOperand(1) != Mask ||
- Op0.getOperand(2) != VL || Op1.getOperand(2) != VL)
- return SDValue();
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
-
- MVT VT = N->getSimpleValueType(0);
- MVT NarrowVT =
- MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits() / 2),
- VT.getVectorElementCount());
-
- SDLoc DL(N);
-
- // Re-introduce narrower extends if needed.
- unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
- if (Op0.getValueType() != NarrowVT)
- Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
- if (Op1.getValueType() != NarrowVT)
- Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
-
- unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
- return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
+ if (SDValue V = combineMUL_VLToVWMUL(N, Op0, Op1, DAG))
+ return V;
+ if (SDValue V = combineMUL_VLToVWMUL(N, Op1, Op0, DAG))
+ return V;
+ return SDValue();
}
}