diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 45 |
1 files changed, 31 insertions, 14 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 10daca5..f144f17 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: return visitPARTIAL_REDUCE_MLA(N); case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N); case ISD::LIFETIME_END: return visitLIFETIME_END(N); @@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { // // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) +// +// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0)) +// -> partial_reduce_fmla(acc, a, b) SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDLoc DL(N); auto *Context = DAG.getContext(); @@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op2 = N->getOperand(2); unsigned Opc = Op1->getOpcode(); - if (Opc != ISD::MUL && Opc != ISD::SHL) + if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL) return SDValue(); SDValue LHS = Op1->getOperand(0); @@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { Opc = ISD::MUL; } - APInt C; - if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) || - !C.isOne()) + if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) && + !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2))) return SDValue(); + auto IsIntOrFPExtOpcode = [](unsigned int Opcode) { + return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND); + }; + unsigned LHSOpcode = LHS->getOpcode(); - if (!ISD::isExtOpcode(LHSOpcode)) + if (!IsIntOrFPExtOpcode(LHSOpcode)) return SDValue(); SDValue LHSExtOp = LHS->getOperand(0); @@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) + APInt C; if (ISD::isConstantSplatVector(RHS.getNode(), C)) { // TODO: Make use of partial_reduce_sumla here APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits()); @@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { } unsigned RHSOpcode = RHS->getOpcode(); - if (!ISD::isExtOpcode(RHSOpcode)) + if (!IsIntOrFPExtOpcode(RHSOpcode)) return SDValue(); SDValue RHSExtOp = RHS->getOperand(0); @@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) { NewOpc = ISD::PARTIAL_REDUCE_SUMLA; std::swap(LHSExtOp, RHSExtOp); + } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) { + NewOpc = ISD::PARTIAL_REDUCE_FMLA; } else return SDValue(); // For a 2-stage extend the signedness of both of the extends must match @@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { // -> partial.reduce.smla(acc, op, splat(trunc(1))) // partial.reduce.sumla(acc, sext(op), splat(1)) // -> partial.reduce.smla(acc, op, splat(trunc(1))) +// partial.reduce.fmla(acc, fpext(op), splat(1.0)) +// -> partial.reduce.fmla(acc, op, splat(1.0)) SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDLoc DL(N); SDValue Acc = N->getOperand(0); SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - APInt ConstantOne; - if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) || - !ConstantOne.isOne()) + if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2)) return SDValue(); unsigned Op1Opcode = Op1.getOpcode(); - if (!ISD::isExtOpcode(Op1Opcode)) + if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND) return SDValue(); - bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; + bool Op1IsSigned = + Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND; bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA; EVT AccElemVT = Acc.getValueType().getVectorElementType(); if (Op1IsSigned != NodeIsSigned && Op1.getValueType().getVectorElementType() != AccElemVT) return SDValue(); - unsigned NewOpcode = - Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? ISD::PARTIAL_REDUCE_FMLA + : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA + : ISD::PARTIAL_REDUCE_UMLA; SDValue UnextOp1 = Op1.getOperand(0); EVT UnextOp1VT = UnextOp1.getValueType(); @@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { TLI.getTypeToTransformTo(*Context, UnextOp1VT))) return SDValue(); + SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? DAG.getConstantFP(1, DL, UnextOp1VT) + : DAG.getConstant(1, DL, UnextOp1VT); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, - DAG.getConstant(1, DL, UnextOp1VT)); + Constant); } SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) { |
