Diffstat (limited to 'llvm/lib/Target/AArch64')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp |  14
-rw-r--r--  llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp    | 150
-rw-r--r--  llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h      |   6
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 167
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h   |  14
5 files changed, 219 insertions, 132 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index 1169f26..97298f9 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -655,16 +655,10 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
   IRBuilder<> B(BB);

-  // Load the global symbol as a pointer to the check function.
-  Value *GuardFn;
-  if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
-    GuardFn = GuardFnCFGlobal;
-  else
-    GuardFn = GuardFnGlobal;
-  LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn);
-
-  // Create new call instruction. The CFGuard check should always be a call,
-  // even if the original CallBase is an Invoke or CallBr instruction.
+  // Create new call instruction. The call check should always be a call,
+  // even if the original CallBase is an Invoke or CallBr instruction.
+  // This is treated as a direct call, so do not use GuardFnCFGlobal.
+  LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFnGlobal);
   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
   CallInst *GuardCheck = B.CreateCall(
       GuardFnType, GuardCheckLoad, {F, Thunk});
diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
index 7e03b97..45b7120 100644
--- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
@@ -370,6 +370,22 @@ SVEFrameSizes AArch64PrologueEpilogueCommon::getSVEStackFrameSizes() const {
           {ZPRCalleeSavesSize, PPRLocalsSize + ZPRLocalsSize}};
 }

+SVEStackAllocations AArch64PrologueEpilogueCommon::getSVEStackAllocations(
+    SVEFrameSizes const &SVE) {
+  StackOffset AfterZPRs = SVE.ZPR.LocalsSize;
+  StackOffset BeforePPRs = SVE.ZPR.CalleeSavesSize + SVE.PPR.CalleeSavesSize;
+  StackOffset AfterPPRs = {};
+  if (SVELayout == SVEStackLayout::Split) {
+    BeforePPRs = SVE.PPR.CalleeSavesSize;
+    // If there are no ZPR CSRs, place all local allocations after the ZPRs.
+    if (SVE.ZPR.CalleeSavesSize)
+      AfterPPRs += SVE.PPR.LocalsSize + SVE.ZPR.CalleeSavesSize;
+    else
+      AfterZPRs += SVE.PPR.LocalsSize; // Group allocation of locals.
+  }
+  return {BeforePPRs, AfterPPRs, AfterZPRs};
+}
+
 struct SVEPartitions {
   struct {
     MachineBasicBlock::iterator Begin, End;
@@ -687,16 +703,19 @@ void AArch64PrologueEmitter::emitPrologue() {
   // All of the remaining stack allocations are for locals.
   determineLocalsStackSize(NumBytes, PrologueSaveSize);

+  auto [PPR, ZPR] = getSVEStackFrameSizes();
+  SVEStackAllocations SVEAllocs = getSVEStackAllocations({PPR, ZPR});
+
   MachineBasicBlock::iterator FirstGPRSaveI = PrologueBeginI;
   if (SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord) {
+    assert(!SVEAllocs.AfterPPRs &&
+           "unexpected SVE allocs after PPRs with CalleeSavesAboveFrameRecord");
     // If we're doing SVE saves first, we need to immediately allocate space
     // for fixed objects, then space for the SVE callee saves.
     //
     // Windows unwind requires that the scalable size is a multiple of 16;
     // that's handled when the callee-saved size is computed.
-    auto SaveSize =
-        StackOffset::getScalable(AFI->getSVECalleeSavedStackSize()) +
-        StackOffset::getFixed(FixedObject);
+    auto SaveSize = SVEAllocs.BeforePPRs + StackOffset::getFixed(FixedObject);
     allocateStackSpace(PrologueBeginI, 0, SaveSize, false, StackOffset{},
                        /*FollowupAllocs=*/true);
     NumBytes -= FixedObject;
@@ -764,12 +783,11 @@ void AArch64PrologueEmitter::emitPrologue() {
   if (AFL.windowsRequiresStackProbe(MF, NumBytes + RealignmentPadding))
     emitWindowsStackProbe(AfterGPRSavesI, DL, NumBytes, RealignmentPadding);

-  auto [PPR, ZPR] = getSVEStackFrameSizes();
-  StackOffset SVECalleeSavesSize = ZPR.CalleeSavesSize + PPR.CalleeSavesSize;
   StackOffset NonSVELocalsSize = StackOffset::getFixed(NumBytes);
+  SVEAllocs.AfterZPRs += NonSVELocalsSize;
+
   StackOffset CFAOffset =
       StackOffset::getFixed(MFI.getStackSize()) - NonSVELocalsSize;
-
   MachineBasicBlock::iterator AfterSVESavesI = AfterGPRSavesI;
   // Allocate space for the callee saves and PPR locals (if any).
   if (SVELayout != SVEStackLayout::CalleeSavesAboveFrameRecord) {
@@ -780,31 +798,23 @@ void AArch64PrologueEmitter::emitPrologue() {
     if (EmitAsyncCFI)
       emitCalleeSavedSVELocations(AfterSVESavesI);

-    StackOffset AllocateBeforePPRs = SVECalleeSavesSize;
-    StackOffset AllocateAfterPPRs = PPR.LocalsSize;
-    if (SVELayout == SVEStackLayout::Split) {
-      AllocateBeforePPRs = PPR.CalleeSavesSize;
-      AllocateAfterPPRs = PPR.LocalsSize + ZPR.CalleeSavesSize;
-    }
-    allocateStackSpace(PPRRange.Begin, 0, AllocateBeforePPRs,
+    allocateStackSpace(PPRRange.Begin, 0, SVEAllocs.BeforePPRs,
                        EmitAsyncCFI && !HasFP, CFAOffset,
-                       MFI.hasVarSizedObjects() || AllocateAfterPPRs ||
-                           ZPR.LocalsSize || NonSVELocalsSize);
-    CFAOffset += AllocateBeforePPRs;
+                       MFI.hasVarSizedObjects() || SVEAllocs.AfterPPRs ||
+                           SVEAllocs.AfterZPRs);
+    CFAOffset += SVEAllocs.BeforePPRs;
     assert(PPRRange.End == ZPRRange.Begin &&
            "Expected ZPR callee saves after PPR locals");
-    allocateStackSpace(PPRRange.End, RealignmentPadding, AllocateAfterPPRs,
+    allocateStackSpace(PPRRange.End, RealignmentPadding, SVEAllocs.AfterPPRs,
                        EmitAsyncCFI && !HasFP, CFAOffset,
-                       MFI.hasVarSizedObjects() || ZPR.LocalsSize ||
-                           NonSVELocalsSize);
-    CFAOffset += AllocateAfterPPRs;
+                       MFI.hasVarSizedObjects() || SVEAllocs.AfterZPRs);
+    CFAOffset += SVEAllocs.AfterPPRs;
   } else {
     assert(SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord);
-    // Note: With CalleeSavesAboveFrameRecord, the SVE CS have already been
-    // allocated (and separate PPR locals are not supported, all SVE locals,
-    // both PPR and ZPR, are within the ZPR locals area).
-    assert(!PPR.LocalsSize && "Unexpected PPR locals!");
-    CFAOffset += SVECalleeSavesSize;
+    // Note: With CalleeSavesAboveFrameRecord, the SVE CS (BeforePPRs) have
+    // already been allocated. PPR locals (included in AfterPPRs) are not
+    // supported (note: this is asserted above).
+    CFAOffset += SVEAllocs.BeforePPRs;
   }

   // Allocate space for the rest of the frame including ZPR locals. Align the
@@ -815,9 +825,9 @@ void AArch64PrologueEmitter::emitPrologue() {
     // FIXME: in the case of dynamic re-alignment, NumBytes doesn't have the
     // correct value here, as NumBytes also includes padding bytes, which
     // shouldn't be counted here.
-    allocateStackSpace(
-        AfterSVESavesI, RealignmentPadding, ZPR.LocalsSize + NonSVELocalsSize,
-        EmitAsyncCFI && !HasFP, CFAOffset, MFI.hasVarSizedObjects());
+    allocateStackSpace(AfterSVESavesI, RealignmentPadding, SVEAllocs.AfterZPRs,
+                       EmitAsyncCFI && !HasFP, CFAOffset,
+                       MFI.hasVarSizedObjects());
   }

   // If we need a base pointer, set it up here. It's whatever the value of the
@@ -1472,27 +1482,26 @@ void AArch64EpilogueEmitter::emitEpilogue() {
   assert(NumBytes >= 0 && "Negative stack allocation size!?");

   StackOffset SVECalleeSavesSize = ZPR.CalleeSavesSize + PPR.CalleeSavesSize;
-  StackOffset SVEStackSize =
-      SVECalleeSavesSize + PPR.LocalsSize + ZPR.LocalsSize;
+  SVEStackAllocations SVEAllocs = getSVEStackAllocations({PPR, ZPR});
   MachineBasicBlock::iterator RestoreBegin = ZPRRange.Begin;
-  MachineBasicBlock::iterator RestoreEnd = PPRRange.End;

   // Deallocate the SVE area.
   if (SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord) {
-    StackOffset SVELocalsSize = ZPR.LocalsSize + PPR.LocalsSize;
+    assert(!SVEAllocs.AfterPPRs &&
+           "unexpected SVE allocs after PPRs with CalleeSavesAboveFrameRecord");
     // If the callee-save area is before FP, restoring the FP implicitly
-    // deallocates non-callee-save SVE allocations.  Otherwise, deallocate them
+    // deallocates non-callee-save SVE allocations. Otherwise, deallocate them
     // explicitly.
     if (!AFI->isStackRealigned() && !MFI.hasVarSizedObjects()) {
       emitFrameOffset(MBB, FirstGPRRestoreI, DL, AArch64::SP, AArch64::SP,
-                      SVELocalsSize, TII, MachineInstr::FrameDestroy, false,
-                      NeedsWinCFI, &HasWinCFI);
+                      SVEAllocs.AfterZPRs, TII, MachineInstr::FrameDestroy,
+                      false, NeedsWinCFI, &HasWinCFI);
     }

     // Deallocate callee-save SVE registers.
-    emitFrameOffset(MBB, RestoreEnd, DL, AArch64::SP, AArch64::SP,
-                    SVECalleeSavesSize, TII, MachineInstr::FrameDestroy, false,
-                    NeedsWinCFI, &HasWinCFI);
+    emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP,
+                    SVEAllocs.BeforePPRs, TII, MachineInstr::FrameDestroy,
+                    false, NeedsWinCFI, &HasWinCFI);
   } else if (AFI->hasSVEStackSize()) {
     // If we have stack realignment or variable-sized objects we must use the FP
     // to restore SVE callee saves (as there is an unknown amount of
@@ -1524,46 +1533,33 @@ void AArch64EpilogueEmitter::emitEpilogue() {
       emitFrameOffset(MBB, RestoreBegin, DL, AArch64::SP, CalleeSaveBase,
                       -SVECalleeSavesSize, TII, MachineInstr::FrameDestroy);
     } else if (BaseForSVEDealloc == AArch64::SP) {
-      auto CFAOffset =
-          SVEStackSize + StackOffset::getFixed(NumBytes + PrologueSaveSize);
-
-      if (SVECalleeSavesSize) {
-        // Deallocate the non-SVE locals first before we can deallocate (and
-        // restore callee saves) from the SVE area.
-        auto NonSVELocals = StackOffset::getFixed(NumBytes);
-        emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP,
-                        NonSVELocals, TII, MachineInstr::FrameDestroy, false,
-                        NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, CFAOffset);
-        CFAOffset -= NonSVELocals;
-        NumBytes = 0;
-      }
-
-      if (ZPR.LocalsSize) {
-        emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP,
-                        ZPR.LocalsSize, TII, MachineInstr::FrameDestroy, false,
-                        NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, CFAOffset);
-        CFAOffset -= ZPR.LocalsSize;
+      auto NonSVELocals = StackOffset::getFixed(NumBytes);
+      auto CFAOffset = NonSVELocals + StackOffset::getFixed(PrologueSaveSize) +
+                       SVEAllocs.totalSize();
+
+      if (SVECalleeSavesSize || SVELayout == SVEStackLayout::Split) {
+        // Deallocate non-SVE locals now. This is needed to reach the SVE callee
+        // saves, but may also allow combining stack hazard bumps for split SVE.
+        SVEAllocs.AfterZPRs += NonSVELocals;
+        NumBytes -= NonSVELocals.getFixed();
       }
-
-      StackOffset SVECalleeSavesToDealloc = SVECalleeSavesSize;
-      if (SVELayout == SVEStackLayout::Split &&
-          (PPR.LocalsSize || ZPR.CalleeSavesSize)) {
-        assert(PPRRange.Begin == ZPRRange.End &&
-               "Expected PPR restores after ZPR");
-        emitFrameOffset(MBB, PPRRange.Begin, DL, AArch64::SP, AArch64::SP,
-                        PPR.LocalsSize + ZPR.CalleeSavesSize, TII,
-                        MachineInstr::FrameDestroy, false, NeedsWinCFI,
-                        &HasWinCFI, EmitCFI && !HasFP, CFAOffset);
-        CFAOffset -= PPR.LocalsSize + ZPR.CalleeSavesSize;
-        SVECalleeSavesToDealloc -= ZPR.CalleeSavesSize;
-      }
-
-      // If split SVE is on, this dealloc PPRs, otherwise, deallocs ZPRs + PPRs:
-      if (SVECalleeSavesToDealloc)
-        emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP,
-                        SVECalleeSavesToDealloc, TII,
-                        MachineInstr::FrameDestroy, false, NeedsWinCFI,
-                        &HasWinCFI, EmitCFI && !HasFP, CFAOffset);
+      // To deallocate the SVE stack, adjust by the allocations in reverse.
+      emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP,
+                      SVEAllocs.AfterZPRs, TII, MachineInstr::FrameDestroy,
+                      false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP,
+                      CFAOffset);
+      CFAOffset -= SVEAllocs.AfterZPRs;
+      assert(PPRRange.Begin == ZPRRange.End &&
+             "Expected PPR restores after ZPR");
+      emitFrameOffset(MBB, PPRRange.Begin, DL, AArch64::SP, AArch64::SP,
+                      SVEAllocs.AfterPPRs, TII, MachineInstr::FrameDestroy,
+                      false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP,
+                      CFAOffset);
+      CFAOffset -= SVEAllocs.AfterPPRs;
+      emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP,
+                      SVEAllocs.BeforePPRs, TII, MachineInstr::FrameDestroy,
+                      false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP,
+                      CFAOffset);
    }

    if (EmitCFI)
diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h
index bccadda..6e0e283 100644
--- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h
+++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h
@@ -33,6 +33,11 @@ struct SVEFrameSizes {
   } PPR, ZPR;
 };

+struct SVEStackAllocations {
+  StackOffset BeforePPRs, AfterPPRs, AfterZPRs;
+  StackOffset totalSize() const { return BeforePPRs + AfterPPRs + AfterZPRs; }
+};
+
 class AArch64PrologueEpilogueCommon {
 public:
   AArch64PrologueEpilogueCommon(MachineFunction &MF, MachineBasicBlock &MBB,
@@ -66,6 +71,7 @@ protected:
   bool shouldCombineCSRLocalStackBump(uint64_t StackBumpBytes) const;

   SVEFrameSizes getSVEStackFrameSizes() const;
+  SVEStackAllocations getSVEStackAllocations(SVEFrameSizes const &);

   MachineFunction &MF;
   MachineBasicBlock &MBB;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e8352be..10f2c80 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3007,9 +3007,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }

-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) const {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) const {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -3027,48 +3027,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;

-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  //       verify that their extending operands are eliminated during code
-  //       generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
@@ -3099,6 +3080,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }

+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) const {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
@@ -3159,7 +3207,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required; the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
@@ -3168,8 +3233,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
             (isa<CastInst>(SingleUser->getOperand(1)) &&
              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
           return 0;
-      } else // Others are free so long as isWideningInstruction returned true.
+      } else {
+        // Others are free so long as isSingleExtWideningInstruction
+        // returned true.
         return 0;
+      }
     }

     // The cast will be free for the s/urhadd instructions
@@ -4148,6 +4216,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
             }))
       return *PromotedCost;

+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends, the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -4381,10 +4461,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
@@ -6657,10 +6735,15 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
           Ops.push_back(&Ext->getOperandUse(0));
         Ops.push_back(&Op);

-        if (isa<SExtInst>(Ext))
+        if (isa<SExtInst>(Ext)) {
           NumSExts++;
-        else
+        } else {
           NumZExts++;
+          // A zext(a) is also a sext(zext(a)), if we take more than 2 steps.
+          if (Ext->getOperand(0)->getType()->getScalarSizeInBits() * 2 <
+              I->getType()->getScalarSizeInBits())
+            NumSExts++;
+        }

         continue;
       }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index b39546a..e3b0a1b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -59,9 +59,17 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
     VECTOR_LDST_FOUR_ELEMENTS
   };

-  bool isWideningInstruction(Type *DstTy, unsigned Opcode,
-                             ArrayRef<const Value *> Args,
-                             Type *SrcOverrideTy = nullptr) const;
+  /// Given an add/sub/mul operation, detect a widening addl/subl/mull pattern
+  /// where both operands can be treated like extends. Returns the minimal type
+  /// needed to compute the operation.
+  Type *isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                    ArrayRef<const Value *> Args,
+                                    Type *SrcOverrideTy = nullptr) const;
+  /// Given an add/sub operation with a single extend operand, detect a
+  /// widening addw/subw pattern.
+  bool isSingleExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                      ArrayRef<const Value *> Args,
+                                      Type *SrcOverrideTy = nullptr) const;

   // A helper function called by 'getVectorInstrCost'.
   //
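
For readers following the prologue/epilogue change above: the new getSVEStackAllocations() folds the PPR/ZPR callee-save and locals sizes into three stack-pointer bumps (BeforePPRs, AfterPPRs, AfterZPRs). A minimal standalone sketch of that partitioning, using plain byte counts instead of LLVM's StackOffset and a hypothetical SplitSVEObjects flag, not the LLVM API itself:

#include <cstdio>

// Illustrative mock of the patch's partitioning logic; names mirror the diff.
struct Sizes { int CalleeSavesSize, LocalsSize; };
struct Allocs { int BeforePPRs, AfterPPRs, AfterZPRs; };

static Allocs partition(Sizes PPR, Sizes ZPR, bool SplitSVEObjects) {
  // Default (non-split) layout: all SVE callee saves before the PPR area,
  // ZPR locals afterwards.
  Allocs A{ZPR.CalleeSavesSize + PPR.CalleeSavesSize, 0, ZPR.LocalsSize};
  if (SplitSVEObjects) {
    A.BeforePPRs = PPR.CalleeSavesSize;
    if (ZPR.CalleeSavesSize)
      A.AfterPPRs = PPR.LocalsSize + ZPR.CalleeSavesSize;
    else
      A.AfterZPRs += PPR.LocalsSize; // no ZPR CSRs: group locals into one bump
  }
  return A;
}

int main() {
  // E.g. 16 bytes of PPR callee saves, 32 of PPR locals, 64/128 for ZPRs.
  Allocs A = partition({16, 32}, {64, 128}, /*SplitSVEObjects=*/true);
  std::printf("BeforePPRs=%d AfterPPRs=%d AfterZPRs=%d\n", A.BeforePPRs,
              A.AfterPPRs, A.AfterZPRs);
}

The epilogue side of the patch then walks the same three quantities in reverse order (AfterZPRs, AfterPPRs, BeforePPRs) to deallocate the frame.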

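The AArch64TargetTransformInfo change costs a widening add/sub/mul (e.g. smull/umull) whose operands are both extends as an operation on the narrowest element type that still covers the sources: double the widest source element, provided that does not exceed the destination element width. A rough standalone sketch of just that size calculation, with plain bit widths and none of the KnownBits or vector-width checks from the patch:

#include <algorithm>
#include <cstdio>
#include <optional>

// Hypothetical helper: element width a widening mul/add/sub could operate on,
// or nullopt if no widening form applies.
static std::optional<unsigned> widenedEltBits(unsigned DstEltBits,
                                              unsigned Src0Bits,
                                              unsigned Src1Bits) {
  unsigned MaxSrc = std::max(Src0Bits, Src1Bits);
  if (MaxSrc * 2 > DstEltBits) // the extends are already too wide to narrow
    return std::nullopt;
  return MaxSrc * 2;
}

int main() {
  // mul(sext i8 -> i32, sext i8 -> i32): can be costed as a 16-bit widening
  // multiply plus an extend from i16 elements to i32 elements.
  if (auto Bits = widenedEltBits(32, 8, 8))
    std::printf("operate on i%u, then extend to i32\n", *Bits);
}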