diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 24 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 4 |
2 files changed, 21 insertions, 7 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 204e1f0..558c5a0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - APInt C; - if (Op1->getOpcode() != ISD::MUL || - !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) + unsigned Opc = Op1->getOpcode(); + if (Opc != ISD::MUL && Opc != ISD::SHL) return SDValue(); SDValue LHS = Op1->getOperand(0); SDValue RHS = Op1->getOperand(1); + + // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c. + if (Opc == ISD::SHL) { + APInt C; + if (!ISD::isConstantSplatVector(RHS.getNode(), C)) + return SDValue(); + + RHS = + DAG.getSplatVector(RHS.getValueType(), DL, + DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL, + RHS.getValueType().getScalarType())); + Opc = ISD::MUL; + } + + APInt C; + if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) || + !C.isOne()) + return SDValue(); + unsigned LHSOpcode = LHS->getOpcode(); if (!ISD::isExtOpcode(LHSOpcode)) return SDValue(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index b5201a3..c21890a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8103,10 +8103,6 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, return; } case Intrinsic::vector_partial_reduce_add: { - if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) { - visitTargetIntrinsic(I, Intrinsic); - return; - } SDValue Acc = getValue(I.getOperand(0)); SDValue Input = getValue(I.getOperand(1)); setValue(&I, |