aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorZain Jaffal <z_jaffal@apple.com>2023-01-23 10:08:33 +0000
committerZain Jaffal <z_jaffal@apple.com>2023-01-31 10:23:46 +0000
commit5170610b5789cde77a948fe57a715c512dcfe350 (patch)
tree89131ff602810ecf0c0f77960d3d894844615009 /llvm/lib
parente4bc9898ddbeb70bc49d713bbf863f050f21e03f (diff)
downloadllvm-5170610b5789cde77a948fe57a715c512dcfe350.zip
llvm-5170610b5789cde77a948fe57a715c512dcfe350.tar.gz
llvm-5170610b5789cde77a948fe57a715c512dcfe350.tar.bz2
[AArch64] turn extended vecreduce bigger than v16i8 into udot/sdot
We can do this by splitting the vecreduce operand into v16i8 vectors, generating a udot/sdot for each of them, and concatenating the results. Differential Revision: https://reviews.llvm.org/D141693
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.cpp69
1 files changed, 63 insertions, 6 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ab2b53f..ca6f1aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15173,6 +15173,9 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
+// If we have vectors larger than v16i8, we extract v16i8 subvectors,
+// follow the same steps as above to get a DOT instruction for each,
+// concatenate them, and generate vecreduce.add(concat_vector(DOT, DOT2, ..)).
static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *ST) {
if (!ST->hasDotProd())
@@ -15198,7 +15201,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
EVT Op0VT = A.getOperand(0).getValueType();
- if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
+ bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
+ bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
+ if (!IsValidElementCount || !IsValidSize)
return SDValue();
SDLoc DL(Op0);
@@ -15209,13 +15214,65 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
else
B = B.getOperand(0);
- SDValue Zeros =
- DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
+ unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0;
+ unsigned NumOfVecReduce;
+ EVT TargetType;
+ if (IsMultipleOf16) {
+ NumOfVecReduce = Op0VT.getVectorNumElements() / 16;
+ TargetType = MVT::v4i32;
+ } else {
+ NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
+ TargetType = MVT::v2i32;
+ }
auto DotOpcode =
(ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
- SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
- A.getOperand(0), B);
- return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
+ // Handle the case where we need to generate only one Dot operation.
+ if (NumOfVecReduce == 1) {
+ SDValue Zeros = DAG.getConstant(0, DL, TargetType);
+ SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
+ A.getOperand(0), B);
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
+ }
+ // Generate one Dot instruction for each full 16-element subvector.
+ unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16;
+ SmallVector<SDValue, 4> SDotVec16;
+ unsigned I = 0;
+ for (; I < VecReduce16Num; I += 1) {
+ SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
+ SDValue Op0 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
+ DAG.getConstant(I * 16, DL, MVT::i64));
+ SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B,
+ DAG.getConstant(I * 16, DL, MVT::i64));
+ SDValue Dot =
+ DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1);
+ SDotVec16.push_back(Dot);
+ }
+ // Concatenate dot operations.
+ EVT SDot16EVT =
+ EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num);
+ SDValue ConcatSDot16 =
+ DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16);
+ SDValue VecReduceAdd16 =
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16);
+ unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8;
+ if (VecReduce8Num == 0)
+ return VecReduceAdd16;
+
+ // Generate the Dot operation for the remaining 8-element subvector.
+ SmallVector<SDValue, 4> SDotVec8;
+ SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32);
+ SDValue Vec8Op0 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0),
+ DAG.getConstant(I * 16, DL, MVT::i64));
+ SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B,
+ DAG.getConstant(I * 16, DL, MVT::i64));
+ SDValue Dot =
+ DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1);
+ SDValue VecReudceAdd8 =
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
+ return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16,
+ VecReudceAdd8);
}
// Given an (integer) vecreduce, we know the order of the inputs does not