diff options
author | Krzysztof Parzyszek <kparzysz@quicinc.com> | 2022-10-25 14:23:34 -0700 |
---|---|---|
committer | Krzysztof Parzyszek <kparzysz@quicinc.com> | 2022-10-29 11:13:28 -0700 |
commit | 9422a8d94c0f038b78427f58b31c1200aed7524c (patch) | |
tree | 197ef50559fc97b59bbedb967be68e7ff4ef8cbd /llvm/lib | |
parent | 63a46385f2c6dd39cf68d9811548c53e8d460cd9 (diff) | |
download | llvm-9422a8d94c0f038b78427f58b31c1200aed7524c.zip llvm-9422a8d94c0f038b78427f58b31c1200aed7524c.tar.gz llvm-9422a8d94c0f038b78427f58b31c1200aed7524c.tar.bz2 |
[Hexagon] Break up vectors into HVX-sized chunks in HvxIdioms
This will allow recognizing Q.31 multiplications on vectors that are
multiplies of HVX vectors. At the moment this comes at the expense of
Q.15 multiplications, which now are handled as 32-bit multiplications
with shifts.
In the longer term this will likely be replaced by a different scheme
of "legalizing" vectors, which is necessary for idiom recognition, at
least where using direct HVX instrinsics is desired.
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/Hexagon/HexagonISelLowering.cpp | 21 | ||||
-rw-r--r-- | llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp | 37 | ||||
-rw-r--r-- | llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp | 233 |
3 files changed, 187 insertions, 104 deletions
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp index 9bd377a..167d622 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -3384,13 +3384,28 @@ HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) return SDValue(); } - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - SDValue Op(N, 0); const SDLoc &dl(Op); unsigned Opc = Op.getOpcode(); + if (Opc == ISD::TRUNCATE) { + SDValue Op0 = Op.getOperand(0); + // fold (truncate (build pair x, y)) -> (truncate x) or x + if (Op0.getOpcode() == ISD::BUILD_PAIR) { + EVT TruncTy = Op.getValueType(); + SDValue Elem0 = Op0.getOperand(0); + // if we match the low element of the pair, just return it. + if (Elem0.getValueType() == TruncTy) + return Elem0; + // otherwise, if the low part is still too large, apply the truncate. + if (Elem0.getValueType().bitsGT(TruncTy)) + return DCI.DAG.getNode(ISD::TRUNCATE, dl, TruncTy, Elem0); + } + } + + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + if (Opc == HexagonISD::P2D) { SDValue P = Op.getOperand(0); switch (P.getOpcode()) { diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp index e15508e..c50ea22 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp @@ -429,7 +429,7 @@ HexagonTargetLowering::initializeHVXLowering() { } } - setTargetDAGCombine({ISD::SPLAT_VECTOR, ISD::VSELECT}); + setTargetDAGCombine({ISD::SPLAT_VECTOR, ISD::VSELECT, ISD::TRUNCATE}); } unsigned @@ -3512,11 +3512,42 @@ HexagonTargetLowering::PerformHvxDAGCombine(SDNode *N, DAGCombinerInfo &DCI) SelectionDAG &DAG = DCI.DAG; SDValue Op(N, 0); unsigned Opc = Op.getOpcode(); - if (DCI.isBeforeLegalizeOps()) - return SDValue(); SmallVector<SDValue, 4> Ops(N->ops().begin(), N->ops().end()); + if (Opc == ISD::TRUNCATE) { + // Simplify V:v2NiB --(bitcast)--> vNi2B --(truncate)--> vNiB + // to extract-subvector (shuffle V, pick even, pick odd) + if (Ops[0].getOpcode() == ISD::BITCAST) + return SDValue(); + SDValue Cast = Ops[0]; + SDValue Src = Cast.getOperand(0); + + EVT TruncTy = Op.getValueType(); + EVT CastTy = Cast.getValueType(); + EVT SrcTy = Src.getValueType(); + if (SrcTy.isSimple()) + return SDValue(); + if (SrcTy.getVectorElementType() != TruncTy.getVectorElementType()) + return SDValue(); + unsigned SrcLen = SrcTy.getVectorNumElements(); + unsigned CastLen = CastTy.getVectorNumElements(); + if (2 * CastLen != SrcLen) + return SDValue(); + + SmallVector<int, 128> Mask(SrcLen); + for (int i = 0; i != static_cast<int>(CastLen); ++i) { + Mask[i] = 2 * i; + Mask[i + CastLen] = 2 * i + 1; + } + SDValue Deal = + DAG.getVectorShuffle(SrcTy, dl, Src, DAG.getUNDEF(SrcTy), Mask); + return opSplit(Deal, dl, DAG).first; + } + + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + switch (Opc) { case ISD::VSELECT: { // (vselect (xor x, qtrue), v0, v1) -> (vselect x, v1, v0) diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp index 1dc7bfb..f3913a9 100644 --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -125,7 +126,8 @@ public: Value *vshuff(IRBuilderBase &Builder, Value *Val0, Value *Val1) const; Value *createHvxIntrinsic(IRBuilderBase &Builder, Intrinsic::ID IntID, - Type *RetTy, ArrayRef<Value *> Args) const; + Type *RetTy, ArrayRef<Value *> Args, + ArrayRef<Type *> ArgTys = None) const; SmallVector<Value *> splitVectorElements(IRBuilderBase &Builder, Value *Vec, unsigned ToWidth) const; Value *joinVectorElements(IRBuilderBase &Builder, ArrayRef<Value *> Values, @@ -346,6 +348,9 @@ private: std::optional<FxpOp> matchFxpMul(Instruction &In) const; Value *processFxpMul(Instruction &In, const FxpOp &Op) const; + + Value *processFxpMulChopped(IRBuilderBase &Builder, Instruction &In, + const FxpOp &Op) const; Value *createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y, bool Rounding) const; Value *createMulQ31(IRBuilderBase &Builder, Value *X, Value *Y, @@ -1042,8 +1047,12 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> { // Fixed-point multiplication is always shifted right (except when the // fraction is 0 bits). + auto m_Shr = [](auto &&V, auto &&S) { + return m_CombineOr(m_LShr(V, S), m_AShr(V, S)); + }; + const APInt *Qn = nullptr; - if (Value * T; match(Exp, m_LShr(m_Value(T), m_APInt(Qn)))) { + if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) { Op.Frac = Qn->getZExtValue(); Exp = T; } else { @@ -1075,12 +1084,56 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> { auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const -> Value * { + assert(Op.X->getType() == Op.Y->getType()); + + auto *VecTy = cast<VectorType>(Op.X->getType()); + auto *ElemTy = cast<IntegerType>(VecTy->getElementType()); + unsigned ElemWidth = ElemTy->getBitWidth(); + if (ElemWidth < 8 || !isPowerOf2_32(ElemWidth)) + return nullptr; + + unsigned VecLen = HVC.length(VecTy); + unsigned HvxLen = (8 * HVC.HST.getVectorLength()) / std::min(ElemWidth, 32u); + if (VecLen % HvxLen != 0) + return nullptr; + + // FIXME: handle 8-bit multiplications + if (ElemWidth < 16) + return nullptr; + + SmallVector<Value *> Results; + FxpOp ChopOp; + ChopOp.Opcode = Op.Opcode; + ChopOp.Frac = Op.Frac; + ChopOp.RoundAt = Op.RoundAt; + + IRBuilder<InstSimplifyFolder> Builder(In.getParent(), In.getIterator(), + InstSimplifyFolder(HVC.DL)); + + for (unsigned V = 0; V != VecLen / HvxLen; ++V) { + ChopOp.X = HVC.subvector(Builder, Op.X, V * HvxLen, HvxLen); + ChopOp.Y = HVC.subvector(Builder, Op.Y, V * HvxLen, HvxLen); + Results.push_back(processFxpMulChopped(Builder, In, ChopOp)); + if (Results.back() == nullptr) + break; + } + + if (Results.back() == nullptr) { + // FIXME: clean up leftover instructions + return nullptr; + } + + return HVC.concat(Builder, Results); +} + +auto HvxIdioms::processFxpMulChopped(IRBuilderBase &Builder, Instruction &In, + const FxpOp &Op) const -> Value * { // FIXME: make this more elegant struct TempValues { - void insert(Value* V) { + void insert(Value *V) { // Values.push_back(V); } - void insert(ArrayRef<Value*> Vs) { + void insert(ArrayRef<Value *> Vs) { Values.insert(Values.end(), Vs.begin(), Vs.end()); } void clear() { // @@ -1092,48 +1145,68 @@ auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const In->eraseFromParent(); } } - SmallVector<Value*> Values; + SmallVector<Value *> Values; }; TempValues DeleteOnFailure; // TODO: Make it general. - if (Op.Frac != 15 && Op.Frac != 31) - return nullptr; + // if (Op.Frac != 15 && Op.Frac != 31) + // return nullptr; + + enum Signedness { Positive, Signed, Unsigned }; + auto getNumSignificantBits = + [this, &In](Value *V) -> std::pair<unsigned, Signedness> { + unsigned Bits = HVC.getNumSignificantBits(V, &In); + // The significant bits are calculated including the sign bit. This may + // add an extra bit for zero-extended values, e.g. (zext i32 to i64) may + // result in 33 significant bits. To avoid extra words, skip the extra + // sign bit, but keep information that the value is to be treated as + // unsigned. + KnownBits Known = HVC.getKnownBits(V, &In); + Signedness Sign = Signed; + if (Bits > 1 && isPowerOf2_32(Bits - 1)) { + if (Known.Zero.ashr(Bits - 1).isAllOnes()) { + Sign = Unsigned; + Bits--; + } + } + // If the top bit of the nearest power-of-2 is zero, this value is + // positive. It could be treated as either signed or unsigned. + if (unsigned Pow2 = PowerOf2Ceil(Bits); Pow2 != Bits) { + if (Known.Zero.ashr(Pow2 - 1).isAllOnes()) + Sign = Positive; + } + return {Bits, Sign}; + }; auto *OrigTy = dyn_cast<VectorType>(Op.X->getType()); if (OrigTy == nullptr) return nullptr; - unsigned BitsX = HVC.getNumSignificantBits(Op.X, &In); - unsigned BitsY = HVC.getNumSignificantBits(Op.Y, &In); - - unsigned SigBits = std::max(BitsX, BitsY); - unsigned Width = PowerOf2Ceil(SigBits); - auto *TruncTy = VectorType::get(HVC.getIntTy(Width), OrigTy); - - IRBuilder<InstSimplifyFolder> Builder(In.getParent(), In.getIterator(), - InstSimplifyFolder(HVC.DL)); - // These may end up dead, but should be removed in isel. - Value *NewX = Builder.CreateTrunc(Op.X, TruncTy); - Value *NewY = Builder.CreateTrunc(Op.Y, TruncTy); - if (NewX != Op.X) - DeleteOnFailure.insert(NewX); - if (NewY != Op.Y) - DeleteOnFailure.insert(NewY); + auto [BitsX, SignX] = getNumSignificantBits(Op.X); + auto [BitsY, SignY] = getNumSignificantBits(Op.Y); + unsigned Width = PowerOf2Ceil(std::max(BitsX, BitsY)); if (!Op.RoundAt || *Op.RoundAt == Op.Frac - 1) { bool Rounding = Op.RoundAt.has_value(); - if (Width == Op.Frac + 1) { + // The fixed-point intrinsics do signed multiplication. + if (Width == Op.Frac + 1 && SignX != Unsigned && SignY != Unsigned) { + auto *TruncTy = VectorType::get(HVC.getIntTy(Width), OrigTy); + Value *TruncX = Builder.CreateTrunc(Op.X, TruncTy); + Value *TruncY = Builder.CreateTrunc(Op.Y, TruncTy); Value *QMul = nullptr; if (Width == 16) { - QMul = createMulQ15(Builder, NewX, NewY, Rounding); + QMul = createMulQ15(Builder, TruncX, TruncY, Rounding); } else if (Width == 32) { - QMul = createMulQ31(Builder, NewX, NewY, Rounding); + QMul = createMulQ31(Builder, TruncX, TruncY, Rounding); } - if (QMul != nullptr) { - DeleteOnFailure.clear(); + if (QMul != nullptr) return Builder.CreateSExt(QMul, OrigTy); - } + + if (TruncX != Op.X && isa<Instruction>(TruncX)) + cast<Instruction>(TruncX)->eraseFromParent(); + if (TruncY != Op.Y && isa<Instruction>(TruncY)) + cast<Instruction>(TruncY)->eraseFromParent(); } } @@ -1141,24 +1214,25 @@ auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const if (!HVC.HST.useHVXV62Ops()) return nullptr; - // The check for Frac will make sure of this, but keep this check for when - // this function handles all Frac cases. - assert(Width > 32); + // FIXME: make it general + if (OrigTy->getScalarSizeInBits() < 32) + return nullptr; + if (Width > 64) return nullptr; // At this point, NewX and NewY may be truncated to different element // widths to save on the number of multiplications to perform. - unsigned WidthX = PowerOf2Ceil(BitsX); - unsigned WidthY = PowerOf2Ceil(BitsY); - Value *OldX = NewX, *OldY = NewY; - NewX = Builder.CreateTrunc( - NewX, VectorType::get(HVC.getIntTy(WidthX), HVC.length(NewX), false)); - NewY = Builder.CreateTrunc( - NewY, VectorType::get(HVC.getIntTy(WidthY), HVC.length(NewY), false)); - if (NewX != OldX) + unsigned WidthX = + PowerOf2Ceil(std::max(BitsX, 32u)); // FIXME: handle shorter ones + unsigned WidthY = PowerOf2Ceil(std::max(BitsY, 32u)); + Value *NewX = Builder.CreateTrunc( + Op.X, VectorType::get(HVC.getIntTy(WidthX), HVC.length(Op.X), false)); + Value *NewY = Builder.CreateTrunc( + Op.Y, VectorType::get(HVC.getIntTy(WidthY), HVC.length(Op.Y), false)); + if (NewX != Op.X) DeleteOnFailure.insert(NewX); - if (NewY != OldY) + if (NewY != Op.Y) DeleteOnFailure.insert(NewY); // Break up the arguments NewX and NewY into vectors of smaller widths @@ -1179,7 +1253,8 @@ auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const // that is halves 2(i+j), 2(i+j)+1, 2(i+j)+2, 2(i+j)+3. for (int i = 0, e = WordX.size(); i != e; ++i) { for (int j = 0, f = WordY.size(); j != f; ++j) { - bool SgnX = (i + 1 == e), SgnY = (j + 1 == f); + bool SgnX = (i + 1 == e) && SignX != Unsigned; + bool SgnY = (j + 1 == f) && SignY != Unsigned; auto [Lo, Hi] = createMul32(Builder, {WordX[i], SgnX}, {WordY[j], SgnY}); Products[i + j + 0].push_back(Lo); Products[i + j + 1].push_back(Hi); @@ -1242,7 +1317,8 @@ auto HvxIdioms::processFxpMul(Instruction &In, const FxpOp &Op) const WordP.resize(WordP.size() - SkipWords); DeleteOnFailure.clear(); - return HVC.joinVectorElements(Builder, WordP, OrigTy); + Value *Ret = HVC.joinVectorElements(Builder, WordP, OrigTy); + return Ret; } auto HvxIdioms::createMulQ15(IRBuilderBase &Builder, Value *X, Value *Y, @@ -1305,60 +1381,21 @@ auto HvxIdioms::createMul32(IRBuilderBase &Builder, SValue X, SValue Y) const assert(X.Val->getType() == Y.Val->getType()); assert(X.Val->getType() == HVC.getHvxTy(HVC.getIntTy(32), /*Pair=*/false)); - assert(HVC.HST.useHVXV62Ops()); - - auto simplifyOrSame = [this](Value *V) { - if (Value *S = HVC.simplify(V)) - return S; - return V; - }; - Value *VX = simplifyOrSame(X.Val); - Value *VY = simplifyOrSame(Y.Val); - - if (isa<Constant>(VX) || isa<Constant>(VY)) { - auto getSplatValue = [](Constant *CV) -> ConstantInt * { - if (auto T = dyn_cast<ConstantVector>(CV)) - return dyn_cast<ConstantInt>(T->getSplatValue()); - if (auto T = dyn_cast<ConstantDataVector>(CV)) - return dyn_cast<ConstantInt>(T->getSplatValue()); - return nullptr; - }; - - if (isa<Constant>(VX) && isa<Constant>(VY)) { - // Both are constants, fold the multiplication. - auto *Ty = cast<VectorType>(VX->getType()); - auto *ExtTy = VectorType::getExtendedElementVectorType(Ty); - Value *EX = X.Signed ? Builder.CreateSExt(VX, ExtTy) - : Builder.CreateZExt(VX, ExtTy); - Value *EY = Y.Signed ? Builder.CreateSExt(VY, ExtTy) - : Builder.CreateZExt(VY, ExtTy); - Value *EXY = simplifyOrSame(Builder.CreateMul(EX, EY)); - auto WordXY = HVC.splitVectorElements(Builder, EXY, /*ToWidth=*/32); - return {simplifyOrSame(WordXY[0]), simplifyOrSame(WordXY[1])}; - } - // Make VX = constant. - if (isa<Constant>(VY)) - std::swap(VX, VY); - - if (auto *SplatX = getSplatValue(cast<Constant>(VX))) { - APInt S = SplatX->getValue(); - if (S == 1) { - if (!X.Signed && !Y.Signed) - return {VY, HVC.getConstSplat(HvxI32Ty, 0)}; - return {VY, Builder.CreateAShr(VY, HVC.getConstSplat(HvxI32Ty, 31))}; - } - } + Intrinsic::ID V6_vmpy_parts; + if (X.Signed == Y.Signed) { + V6_vmpy_parts = X.Signed ? Intrinsic::hexagon_V6_vmpyss_parts + : Intrinsic::hexagon_V6_vmpyuu_parts; + } else { + if (X.Signed) + std::swap(X, Y); + V6_vmpy_parts = Intrinsic::hexagon_V6_vmpyus_parts; } - auto V6_vmpyewuh_64 = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyewuh_64); - auto V6_vmpyowh_64_acc = HVC.HST.getIntrinsicId(Hexagon::V6_vmpyowh_64_acc); - - Value *Vxx = - HVC.createHvxIntrinsic(Builder, V6_vmpyewuh_64, HvxP32Ty, {X.Val, Y.Val}); - Value *Vdd = HVC.createHvxIntrinsic(Builder, V6_vmpyowh_64_acc, HvxP32Ty, - {Vxx, X.Val, Y.Val}); - - return {HVC.sublo(Builder, Vdd), HVC.subhi(Builder, Vdd)}; + Value *Parts = HVC.createHvxIntrinsic(Builder, V6_vmpy_parts, nullptr, + {X.Val, Y.Val}, {HvxI32Ty}); + Value *Hi = Builder.CreateExtractValue(Parts, {0}); + Value *Lo = Builder.CreateExtractValue(Parts, {1}); + return {Lo, Hi}; } auto HvxIdioms::run() -> bool { @@ -1778,7 +1815,8 @@ auto HexagonVectorCombine::vshuff(IRBuilderBase &Builder, Value *Val0, auto HexagonVectorCombine::createHvxIntrinsic(IRBuilderBase &Builder, Intrinsic::ID IntID, Type *RetTy, - ArrayRef<Value *> Args) const + ArrayRef<Value *> Args, + ArrayRef<Type *> ArgTys) const -> Value * { auto getCast = [&](IRBuilderBase &Builder, Value *Val, Type *DestTy) -> Value * { @@ -1803,7 +1841,7 @@ auto HexagonVectorCombine::createHvxIntrinsic(IRBuilderBase &Builder, return Builder.CreateCall(FI, {Val}); }; - Function *IntrFn = Intrinsic::getDeclaration(F.getParent(), IntID); + Function *IntrFn = Intrinsic::getDeclaration(F.getParent(), IntID, ArgTys); FunctionType *IntrTy = IntrFn->getFunctionType(); SmallVector<Value *, 4> IntrArgs; @@ -1846,7 +1884,6 @@ auto HexagonVectorCombine::splitVectorElements(IRBuilderBase &Builder, assert(VecTy->getElementType()->isIntegerTy()); unsigned FromWidth = VecTy->getScalarSizeInBits(); assert(isPowerOf2_32(ToWidth) && isPowerOf2_32(FromWidth)); - assert(ToWidth <= FromWidth && "Breaking up into wider elements?"); unsigned NumResults = FromWidth / ToWidth; |