aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp45
1 files changed, 38 insertions, 7 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7341914..17703f5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12843,22 +12843,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
SDLoc DL(HG);
EVT MemVT = HG->getMemoryVT();
+ EVT DataVT = Index.getValueType();
MachineMemOperand *MMO = HG->getMemOperand();
ISD::MemIndexType IndexType = HG->getIndexType();
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return Chain;
- SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
- HG->getScale(), HG->getIntID()};
- if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL))
+ if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL) ||
+ refineIndexType(Index, IndexType, DataVT, DAG)) {
+ SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
+ HG->getScale(), HG->getIntID()};
return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
MMO, IndexType);
+ }
- EVT DataVT = Index.getValueType();
- if (refineIndexType(Index, IndexType, DataVT, DAG))
- return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
- MMO, IndexType);
return SDValue();
}
@@ -16343,6 +16342,38 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
DAG, DL);
}
break;
+ case ISD::ABDU:
+ case ISD::ABDS:
+ // (trunc (abdu/abds a, b)) → (abdu/abds (trunc a), (trunc b))
+ if (!LegalOperations || N0.hasOneUse()) {
+ EVT SrcVT = N0.getValueType();
+ EVT TruncVT = VT;
+ unsigned SrcBits = SrcVT.getScalarSizeInBits();
+ unsigned TruncBits = TruncVT.getScalarSizeInBits();
+ unsigned NeededBits = SrcBits - TruncBits;
+
+ SDValue A = N0.getOperand(0);
+ SDValue B = N0.getOperand(1);
+ bool CanFold = false;
+
+ if (N0.getOpcode() == ISD::ABDU) {
+ KnownBits KnownA = DAG.computeKnownBits(A);
+ KnownBits KnownB = DAG.computeKnownBits(B);
+ CanFold = KnownA.countMinLeadingZeros() >= NeededBits &&
+ KnownB.countMinLeadingZeros() >= NeededBits;
+ } else {
+ unsigned SignBitsA = DAG.ComputeNumSignBits(A);
+ unsigned SignBitsB = DAG.ComputeNumSignBits(B);
+ CanFold = SignBitsA > NeededBits && SignBitsB > NeededBits;
+ }
+
+ if (CanFold && TLI.isOperationLegal(N0.getOpcode(), VT)) {
+ SDValue NewA = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, A);
+ SDValue NewB = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, B);
+ return DAG.getNode(N0.getOpcode(), DL, TruncVT, NewA, NewB);
+ }
+ }
+ break;
}
return SDValue();