aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/SelectionDAG
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp45
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp2
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp2
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp8
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp8
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp2
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp33
7 files changed, 75 insertions, 25 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 10daca5..f144f17 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+//
+// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
+// -> partial_reduce_fmla(acc, a, b)
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
@@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op2 = N->getOperand(2);
unsigned Opc = Op1->getOpcode();
- if (Opc != ISD::MUL && Opc != ISD::SHL)
+ if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
@@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Opc = ISD::MUL;
}
- APInt C;
- if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
- !C.isOne())
+ if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
+ !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
return SDValue();
+ auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
+ return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
+ };
+
unsigned LHSOpcode = LHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode))
+ if (!IsIntOrFPExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
@@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+ APInt C;
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
}
unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(RHSOpcode))
+ if (!IsIntOrFPExtOpcode(RHSOpcode))
return SDValue();
SDValue RHSExtOp = RHS->getOperand(0);
@@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(LHSExtOp, RHSExtOp);
+ } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+ NewOpc = ISD::PARTIAL_REDUCE_FMLA;
} else
return SDValue();
// For a 2-stage extend the signedness of both of the extends must match
@@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.fmla(acc, fpext(op), splat(1.0))
+// -> partial.reduce.fmla(acc, op, splat(1.0))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
unsigned Op1Opcode = Op1.getOpcode();
- if (!ISD::isExtOpcode(Op1Opcode))
+ if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
- bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+ bool Op1IsSigned =
+ Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? ISD::PARTIAL_REDUCE_FMLA
+ : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
@@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();
+ SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(1, DL, UnextOp1VT)
+ : DAG.getConstant(1, DL, UnextOp1VT);
+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
- DAG.getConstant(1, DL, UnextOp1VT));
+ Constant);
}
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4..94751be5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bb4a8d9..dd5c011 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1474,6 +1474,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
case ISD::GET_ACTIVE_LANE_MASK:
@@ -3689,6 +3690,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0a06752..bbc1d73 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8404,7 +8404,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- case ISD::PARTIAL_REDUCE_SUMLA: {
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
@@ -13064,6 +13065,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
return C && C->isOne();
}
+bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
+ ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
+ return C && C->isExactlyValue(1.0);
+}
+
bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
N = peekThroughBitcasts(N);
unsigned BitWidth = N.getScalarValueSizeInBits();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2f598b2..88b0809 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8187,6 +8187,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Input, DAG.getConstant(1, sdl, Input.getValueType())));
return;
}
+ case Intrinsic::vector_partial_reduce_fadd: {
+ SDValue Acc = getValue(I.getOperand(0));
+ SDValue Input = getValue(I.getOperand(1));
+ setValue(&I, DAG.getNode(
+ ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+ Input, DAG.getConstantFP(1.0, sdl, Input.getValueType())));
+ return;
+ }
case Intrinsic::experimental_cttz_elts: {
auto DL = getCurSDLoc();
SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index d3e1628..ec5edd5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -590,6 +590,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_smla";
case ISD::PARTIAL_REDUCE_SUMLA:
return "partial_reduce_sumla";
+ case ISD::PARTIAL_REDUCE_FMLA:
+ return "partial_reduce_fmla";
case ISD::LOOP_DEPENDENCE_WAR_MASK:
return "loop_dep_war";
case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 9bdf822..b51d664 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12074,22 +12074,32 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
- unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
- ? ISD::ZERO_EXTEND
- : ISD::SIGN_EXTEND;
- unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
- ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
+ unsigned ExtOpcLHS, ExtOpcRHS;
+ switch (N->getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case ISD::PARTIAL_REDUCE_UMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND;
+ break;
+ case ISD::PARTIAL_REDUCE_SMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND;
+ break;
+ case ISD::PARTIAL_REDUCE_FMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND;
+ break;
+ }
if (ExtMulOpVT != MulOpVT) {
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
}
SDValue Input = MulLHS;
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+ if (!llvm::isOneOrOneSplatFP(MulRHS))
+ Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+ } else if (!llvm::isOneOrOneSplat(MulRHS)) {
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+ }
unsigned Stride = AccVT.getVectorMinNumElements();
unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
@@ -12099,10 +12109,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
for (unsigned I = 0; I < ScaleFactor; I++)
Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
+ unsigned FlatNode =
+ N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
- DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+ DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}