diff options
author | Zain Jaffal <z_jaffal@apple.com> | 2023-01-23 10:08:33 +0000 |
---|---|---|
committer | Zain Jaffal <z_jaffal@apple.com> | 2023-01-31 10:23:46 +0000 |
commit | 5170610b5789cde77a948fe57a715c512dcfe350 (patch) | |
tree | 89131ff602810ecf0c0f77960d3d894844615009 /llvm/lib | |
parent | e4bc9898ddbeb70bc49d713bbf863f050f21e03f (diff) | |
download | llvm-5170610b5789cde77a948fe57a715c512dcfe350.zip llvm-5170610b5789cde77a948fe57a715c512dcfe350.tar.gz llvm-5170610b5789cde77a948fe57a715c512dcfe350.tar.bz2 |
[AArch64] turn extended vecreduce bigger than v16i8 into udot/sdot
We can do this by splitting the input of the vecreduce into v16i8 vectors, generating a udot/sdot for each of them, and concatenating the results.
Differential Revision: https://reviews.llvm.org/D141693
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 69 |
1 files changed, 63 insertions, 6 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ab2b53f..ca6f1aa 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -15173,6 +15173,9 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce // vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one)) // vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B)) +// If we have vectors larger than v16i8 we extract v16i8 vectors, +// Follow the same steps above to get DOT instructions concatenate them +// and generate vecreduce.add(concat_vector(DOT, DOT2, ..)). static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, const AArch64Subtarget *ST) { if (!ST->hasDotProd()) @@ -15198,7 +15201,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); EVT Op0VT = A.getOperand(0).getValueType(); - if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8) + bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0; + bool IsValidSize = Op0VT.getScalarSizeInBits() == 8; + if (!IsValidElementCount || !IsValidSize) return SDValue(); SDLoc DL(Op0); @@ -15209,13 +15214,65 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, else B = B.getOperand(0); - SDValue Zeros = - DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32); + unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0; + unsigned NumOfVecReduce; + EVT TargetType; + if (IsMultipleOf16) { + NumOfVecReduce = Op0VT.getVectorNumElements() / 16; + TargetType = MVT::v4i32; + } else { + NumOfVecReduce = Op0VT.getVectorNumElements() / 8; + TargetType = MVT::v2i32; + } auto DotOpcode = (ExtOpcode == ISD::ZERO_EXTEND) ? 
AArch64ISD::UDOT : AArch64ISD::SDOT; - SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, - A.getOperand(0), B); - return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); + // Handle the case where we need to generate only one Dot operation. + if (NumOfVecReduce == 1) { + SDValue Zeros = DAG.getConstant(0, DL, TargetType); + SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, + A.getOperand(0), B); + return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); + } + // Generate Dot instructions that are multiple of 16. + unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16; + SmallVector<SDValue, 4> SDotVec16; + unsigned I = 0; + for (; I < VecReduce16Num; I += 1) { + SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32); + SDValue Op0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0), + DAG.getConstant(I * 16, DL, MVT::i64)); + SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B, + DAG.getConstant(I * 16, DL, MVT::i64)); + SDValue Dot = + DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1); + SDotVec16.push_back(Dot); + } + // Concatenate dot operations. + EVT SDot16EVT = + EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num); + SDValue ConcatSDot16 = + DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16); + SDValue VecReduceAdd16 = + DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16); + unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8; + if (VecReduce8Num == 0) + return VecReduceAdd16; + + // Generate the remainder Dot operation that is multiple of 8. 
+ SmallVector<SDValue, 4> SDotVec8; + SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32); + SDValue Vec8Op0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0), + DAG.getConstant(I * 16, DL, MVT::i64)); + SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B, + DAG.getConstant(I * 16, DL, MVT::i64)); + SDValue Dot = + DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1); + SDValue VecReudceAdd8 = + DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); + return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16, + VecReudceAdd8); } // Given an (integer) vecreduce, we know the order of the inputs does not |