Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 168
1 file changed, 124 insertions(+), 44 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 5b5565a..197aae6 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3007,9 +3007,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
llvm_unreachable("Unsupported register kind");
}
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
- ArrayRef<const Value *> Args,
- Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+ unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+ Type *SrcOverrideTy) const {
// A helper that returns a vector type from the given type. The number of
// elements in type Ty determines the vector width.
auto toVectorTy = [&](Type *ArgTy) {
@@ -3027,48 +3027,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
return false;
- // Determine if the operation has a widening variant. We consider both the
- // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
- // instructions.
- //
- // TODO: Add additional widening operations (e.g., shl, etc.) once we
- // verify that their extending operands are eliminated during code
- // generation.
Type *SrcTy = SrcOverrideTy;
switch (Opcode) {
- case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
- case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+ case Instruction::Add: // UADDW(2), SADDW(2).
+ case Instruction::Sub: { // USUBW(2), SSUBW(2).
// The second operand needs to be an extend
if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
if (!SrcTy)
SrcTy =
toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
- } else
+ break;
+ }
+
+ if (Opcode == Instruction::Sub)
return false;
- break;
- case Instruction::Mul: { // SMULL(2), UMULL(2)
- // Both operands need to be extends of the same type.
- if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
- (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+ // UADDW(2), SADDW(2) can be commuted.
+ if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
if (!SrcTy)
SrcTy =
toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
- } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
- // If one of the operands is a Zext and the other has enough zero bits to
- // be treated as unsigned, we can still general a umull, meaning the zext
- // is free.
- KnownBits Known =
- computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
- if (Args[0]->getType()->getScalarSizeInBits() -
- Known.Zero.countLeadingOnes() >
- DstTy->getScalarSizeInBits() / 2)
- return false;
- if (!SrcTy)
- SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
- DstTy->getScalarSizeInBits() / 2));
- } else
- return false;
- break;
+ break;
+ }
+ return false;
}
default:
return false;
@@ -3099,6 +3080,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
}
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+ ArrayRef<const Value *> Args,
+ Type *SrcOverrideTy) const {
+ if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+ Opcode != Instruction::Mul)
+ return nullptr;
+
+ // Exit early if DstTy is not a vector type whose elements are one of [i16,
+ // i32, i64]. SVE doesn't generally have the same set of instructions to
+ // perform an extend with the add/sub/mul. There are SMULLB style
+ // instructions, but they operate on top/bottom, requiring some sort of lane
+ // interleaving to be used with zext/sext.
+ unsigned DstEltSize = DstTy->getScalarSizeInBits();
+ if (!useNeonVector(DstTy) || Args.size() != 2 ||
+ (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+ return nullptr;
+
+ auto getScalarSizeWithOverride = [&](const Value *V) {
+ if (SrcOverrideTy)
+ return SrcOverrideTy->getScalarSizeInBits();
+ return cast<Instruction>(V)
+ ->getOperand(0)
+ ->getType()
+ ->getScalarSizeInBits();
+ };
+
+ unsigned MaxEltSize = 0;
+ if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+ (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+ unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+ unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+ MaxEltSize = std::max(EltSize0, EltSize1);
+ } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+ isa<SExtInst, ZExtInst>(Args[1])) {
+ unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+ unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+ // mul(sext, zext) will become smull(sext, zext) if the extends are large
+ // enough.
+ if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+ return nullptr;
+ MaxEltSize = DstEltSize / 2;
+ } else if (Opcode == Instruction::Mul &&
+ (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+ // If one of the operands is a Zext and the other has enough zero bits
+ // to be treated as unsigned, we can still generate a umull, meaning the
+ // zext is free.
+ KnownBits Known =
+ computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+ if (Args[0]->getType()->getScalarSizeInBits() -
+ Known.Zero.countLeadingOnes() >
+ DstTy->getScalarSizeInBits() / 2)
+ return nullptr;
+
+ MaxEltSize =
+ getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+ } else
+ return nullptr;
+
+ if (MaxEltSize * 2 > DstEltSize)
+ return nullptr;
+
+ Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+ if (ExtTy->getPrimitiveSizeInBits() <= 64)
+ return nullptr;
+ return ExtTy;
+}
+
// s/urhadd instructions implement the following pattern, making the
// extends free:
// %x = add ((zext i8 -> i16), 1)
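Note (illustrative, not part of the patch): the split above separates the two AArch64 widening shapes. The "wide" forms (uaddw/saddw, usubw/ssubw) take one already-wide operand and one extended operand, while the "long" forms (uaddl/saddl, usubl/ssubl, smull/umull) extend both operands. A minimal standalone sketch of that classification, with hypothetical names and simplified inputs (the real helpers also check element counts and sizes, and the bin-ext helper additionally accepts a mul whose non-zext operand is provably narrow via known zero bits):

enum class WideningKind {
  None,
  SingleExt, // "wide" forms: uaddw/saddw, usubw/ssubw
  BinExt     // "long" forms: uaddl/saddl, usubl/ssubl, smull/umull
};

// Op0IsExt/Op1IsExt: whether each operand is a sext/zext of a narrower vector.
WideningKind classifyWidening(bool IsAdd, bool IsSub, bool IsMul,
                              bool Op0IsExt, bool Op1IsExt) {
  if (Op0IsExt && Op1IsExt && (IsAdd || IsSub || IsMul))
    return WideningKind::BinExt;    // both operands fold into the instruction
  if (Op1IsExt && (IsAdd || IsSub))
    return WideningKind::SingleExt; // only the extended operand is free
  if (Op0IsExt && IsAdd)
    return WideningKind::SingleExt; // add commutes, sub does not
  return WideningKind::None;        // mul has no single-extend "wide" form
}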
@@ -3159,7 +3207,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (I && I->hasOneUser()) {
auto *SingleUser = cast<Instruction>(*I->user_begin());
SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
- if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+ if (Type *ExtTy = isBinExtWideningInstruction(
+ SingleUser->getOpcode(), Dst, Operands,
+ Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+ // The cost from Src->Src*2 needs to be added if required; the cost from
+ // Src*2->ExtTy is free.
+ if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+ Type *DoubleSrcTy =
+ Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+ return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+ TTI::CastContextHint::None, CostKind);
+ }
+
+ return 0;
+ }
+
+ if (isSingleExtWideningInstruction(
+ SingleUser->getOpcode(), Dst, Operands,
+ Src != I->getOperand(0)->getType() ? Src : nullptr)) {
// For adds only count the second operand as free if both operands are
// extends but not the same operation. (i.e both operands are not free in
// add(sext, zext)).
@@ -3168,8 +3233,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
(isa<CastInst>(SingleUser->getOperand(1)) &&
cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
return 0;
- } else // Others are free so long as isWideningInstruction returned true.
+ } else {
+ // Others are free so long as isSingleExtWideningInstruction
+ // returned true.
return 0;
+ }
}
// The cast will be free for the s/urhadd instructions
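Note (illustrative, not part of the patch): when the extend's single user is accepted by isBinExtWideningInstruction, the widening instruction absorbs one doubling of the source element size, so at most an extend up to twice the source width is charged. A hedged sketch of that charging rule, where CastCostSrcToDouble stands in for the recursive getCastInstrCost query on Src -> 2*Src (all names here are stand-ins, not LLVM API):

// Cost charged for an extend from SrcEltBits-wide elements when its single
// user is a both-operands-extended widening op running at ExtEltBits-wide
// elements.
unsigned chargedExtendCost(unsigned SrcEltBits, unsigned ExtEltBits,
                           unsigned CastCostSrcToDouble) {
  // The widening op itself performs the final doubling for free.
  if (ExtEltBits > SrcEltBits * 2)
    return CastCostSrcToDouble; // pay only for Src -> 2*Src
  return 0;                     // fully folded into the widening op
}

For example, a zext from i8 elements feeding a multiply that can run as an i16-element umull is costed as free, while the same i8 zext in a multiply whose other operand extends from i16 (so the op runs on i32 elements) is charged as an i8 -> i16 extend only.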
@@ -4148,6 +4216,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
}))
return *PromotedCost;
+ // If the operation is a widening instruction (smull or umull) and both
+ // operands are extends the cost can be cheaper by considering that the
+ // operation will operate on the narrowest type size possible (double the
+ // largest input size) and a further extend.
+ if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+ if (ExtTy != Ty)
+ return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+ getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+ TTI::CastContextHint::None, CostKind);
+ return LT.first;
+ }
+
switch (ISD) {
default:
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
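Note (illustrative, not part of the patch): the block added at the top of getArithmeticInstrCost models a both-operands-extended add/sub/mul as running at ExtTy (double the widest source element) followed by a single zext back up to Ty. A minimal sketch of that decomposition, where ArithCostAtExtTy and ZExtCostExtToTy stand in for the recursive TTI queries (names are hypothetical):

// Cost model for a both-operands-extended widening op producing type Ty.
unsigned binExtWideningArithCost(bool ExtTyNarrowerThanTy,
                                 unsigned ArithCostAtExtTy,
                                 unsigned ZExtCostExtToTy,
                                 unsigned LegalizationFactor) {
  if (ExtTyNarrowerThanTy)
    // e.g. a v4i64 mul of two i16 extends: costed as a v4i32 widening
    // multiply plus one i32 -> i64 vector extend.
    return ArithCostAtExtTy + ZExtCostExtToTy;
  // ExtTy == Ty: a single legal widening instruction (smull/umull/uaddl/...)
  // covers it, so only the legalization factor (LT.first) is paid.
  return LegalizationFactor;
}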
@@ -4381,10 +4461,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
// - two 2-cost i64 inserts, and
// - two 1-cost muls.
// So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
- // LT.first = 2 the cost is 28. If both operands are extensions it will not
- // need to scalarize so the cost can be cheaper (smull or umull).
- // so the cost can be cheaper (smull or umull).
- if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+ // LT.first = 2 the cost is 28.
+ if (LT.second != MVT::v2i64)
return LT.first;
return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
(getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
@@ -6129,7 +6207,8 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
}
static bool containsDecreasingPointers(Loop *TheLoop,
- PredicatedScalarEvolution *PSE) {
+ PredicatedScalarEvolution *PSE,
+ const DominatorTree &DT) {
const auto &Strides = DenseMap<Value *, const SCEV *>();
for (BasicBlock *BB : TheLoop->blocks()) {
// Scan the instructions in the block and look for addresses that are
@@ -6138,8 +6217,8 @@ static bool containsDecreasingPointers(Loop *TheLoop,
if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
Value *Ptr = getLoadStorePointerOperand(&I);
Type *AccessTy = getLoadStoreType(&I);
- if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
- /*ShouldCheckWrap=*/false)
+ if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, DT, Strides,
+ /*Assume=*/true, /*ShouldCheckWrap=*/false)
.value_or(0) < 0)
return true;
}
@@ -6184,7 +6263,8 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const {
// negative strides. This will require extra work to reverse the loop
// predicate, which may be expensive.
if (containsDecreasingPointers(TFI->LVL->getLoop(),
- TFI->LVL->getPredicatedScalarEvolution()))
+ TFI->LVL->getPredicatedScalarEvolution(),
+ *TFI->LVL->getDominatorTree()))
Required |= TailFoldingOpts::Reverse;
if (Required == TailFoldingOpts::Disabled)
Required |= TailFoldingOpts::Simple;