aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp24
1 files changed, 21 insertions, 3 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 204e1f0..558c5a0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt C;
- if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+ unsigned Opc = Op1->getOpcode();
+ if (Opc != ISD::MUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
+
+ // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
+ if (Opc == ISD::SHL) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(RHS.getNode(), C))
+ return SDValue();
+
+ RHS =
+ DAG.getSplatVector(RHS.getValueType(), DL,
+ DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
+ RHS.getValueType().getScalarType()));
+ Opc = ISD::MUL;
+ }
+
+ APInt C;
+ if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
+ !C.isOne())
+ return SDValue();
+
unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();