about | summary | refs | log | tree | commit | diff
path: root/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp 114
1 files changed, 72 insertions, 42 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 816b7ba..f144f17 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+//
+// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
+// -> partial_reduce_fmla(acc, a, b)
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
@@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op2 = N->getOperand(2);
unsigned Opc = Op1->getOpcode();
- if (Opc != ISD::MUL && Opc != ISD::SHL)
+ if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
@@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Opc = ISD::MUL;
}
- APInt C;
- if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
- !C.isOne())
+ if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
+ !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
return SDValue();
+ auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
+ return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
+ };
+
unsigned LHSOpcode = LHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode))
+ if (!IsIntOrFPExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
@@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+ APInt C;
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
}
unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(RHSOpcode))
+ if (!IsIntOrFPExtOpcode(RHSOpcode))
return SDValue();
SDValue RHSExtOp = RHS->getOperand(0);
@@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(LHSExtOp, RHSExtOp);
+ } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+ NewOpc = ISD::PARTIAL_REDUCE_FMLA;
} else
return SDValue();
// For a 2-stage extend the signedness of both of the extends must match
@@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.fmla(acc, fpext(op), splat(1.0))
+// -> partial.reduce.fmla(acc, op, splat(1.0))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
unsigned Op1Opcode = Op1.getOpcode();
- if (!ISD::isExtOpcode(Op1Opcode))
+ if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
- bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+ bool Op1IsSigned =
+ Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? ISD::PARTIAL_REDUCE_FMLA
+ : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
@@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();
+ SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(1, DL, UnextOp1VT)
+ : DAG.getConstant(1, DL, UnextOp1VT);
+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
- DAG.getConstant(1, DL, UnextOp1VT));
+ Constant);
}
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
@@ -16736,38 +16753,51 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
}
// fold (conv (load x)) -> (load (conv*)x)
+ // fold (conv (freeze (load x))) -> (freeze (load (conv*)x))
// If the resultant load doesn't need a higher alignment than the original!
- if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
- // Do not remove the cast if the types differ in endian layout.
- TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
- TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
- // If the load is volatile, we only want to change the load type if the
- // resulting load is legal. Otherwise we might increase the number of
- // memory accesses. We don't care if the original type was legal or not
- // as we assume software couldn't rely on the number of accesses of an
- // illegal type.
- ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
- TLI.isOperationLegal(ISD::LOAD, VT))) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
+ auto CastLoad = [this, &VT](SDValue N0, const SDLoc &DL) {
+ if (!ISD::isNormalLoad(N0.getNode()) || !N0.hasOneUse())
+ return SDValue();
- if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
- *LN0->getMemOperand())) {
- // If the range metadata type does not match the new memory
- // operation type, remove the range metadata.
- if (const MDNode *MD = LN0->getRanges()) {
- ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
- if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
- !VT.isInteger()) {
- LN0->getMemOperand()->clearRanges();
- }
+ // Do not remove the cast if the types differ in endian layout.
+ if (TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) !=
+ TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()))
+ return SDValue();
+
+ // If the load is volatile, we only want to change the load type if the
+ // resulting load is legal. Otherwise we might increase the number of
+ // memory accesses. We don't care if the original type was legal or not
+ // as we assume software couldn't rely on the number of accesses of an
+ // illegal type.
+ auto *LN0 = cast<LoadSDNode>(N0);
+ if ((LegalOperations || !LN0->isSimple()) &&
+ !TLI.isOperationLegal(ISD::LOAD, VT))
+ return SDValue();
+
+ if (!TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
+ *LN0->getMemOperand()))
+ return SDValue();
+
+ // If the range metadata type does not match the new memory
+ // operation type, remove the range metadata.
+ if (const MDNode *MD = LN0->getRanges()) {
+ ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
+ if (Lower->getBitWidth() != VT.getScalarSizeInBits() || !VT.isInteger()) {
+ LN0->getMemOperand()->clearRanges();
}
- SDValue Load =
- DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
- LN0->getMemOperand());
- DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
- return Load;
}
- }
+ SDValue Load = DAG.getLoad(VT, DL, LN0->getChain(), LN0->getBasePtr(),
+ LN0->getMemOperand());
+ DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
+ return Load;
+ };
+
+ if (SDValue NewLd = CastLoad(N0, SDLoc(N)))
+ return NewLd;
+
+ if (N0.getOpcode() == ISD::FREEZE && N0.hasOneUse())
+ if (SDValue NewLd = CastLoad(N0.getOperand(0), SDLoc(N)))
+ return DAG.getFreeze(NewLd);
if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
return V;