diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 204e1f0..558c5a0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - APInt C; - if (Op1->getOpcode() != ISD::MUL || - !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) + unsigned Opc = Op1->getOpcode(); + if (Opc != ISD::MUL && Opc != ISD::SHL) return SDValue(); SDValue LHS = Op1->getOperand(0); SDValue RHS = Op1->getOperand(1); + + // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c. + if (Opc == ISD::SHL) { + APInt C; + if (!ISD::isConstantSplatVector(RHS.getNode(), C)) + return SDValue(); + + RHS = + DAG.getSplatVector(RHS.getValueType(), DL, + DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL, + RHS.getValueType().getScalarType())); + Opc = ISD::MUL; + } + + APInt C; + if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) || + !C.isOne()) + return SDValue(); + unsigned LHSOpcode = LHS->getOpcode(); if (!ISD::isExtOpcode(LHSOpcode)) return SDValue(); |