diff options
Diffstat (limited to 'llvm/lib/Target/AArch64')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 38 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrGISel.td | 7 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h | 2 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64PostCoalescerPass.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 2 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 108 |
9 files changed, 145 insertions, 27 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9926a4d..91c1f59 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1561,6 +1561,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); setOperationAction(ISD::VECREDUCE_AND, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_MUL, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); @@ -1717,6 +1718,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Custom); setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMUL, VT, Custom); setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom); @@ -7775,6 +7777,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::VECREDUCE_FMAXIMUM: case ISD::VECREDUCE_FMINIMUM: return LowerVECREDUCE(Op, DAG); + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_FMUL: + return LowerVECREDUCE_MUL(Op, DAG); case ISD::ATOMIC_LOAD_AND: return LowerATOMIC_LOAD_AND(Op, DAG); case ISD::DYNAMIC_STACKALLOC: @@ -16254,7 +16259,7 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const { SplatVal > 1) { SDValue Pg = getPredicateForScalableVector(DAG, DL, VT); SDValue Res = - DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, DL, VT, Pg, Op->getOperand(0), + DAG.getNode(AArch64ISD::ASRD_MERGE_OP1, DL, VT, Pg, Op->getOperand(0), DAG.getTargetConstant(Log2_64(SplatVal), DL, MVT::i32)); if (Negated) Res = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res); @@ -16794,6 +16799,33 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, } } +SDValue AArch64TargetLowering::LowerVECREDUCE_MUL(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Src = Op.getOperand(0); + EVT SrcVT = Src.getValueType(); + assert(SrcVT.isScalableVector() && "Unexpected operand type!"); + + SDVTList SrcVTs = DAG.getVTList(SrcVT, SrcVT); + unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Op.getOpcode()); + SDValue Identity = DAG.getNeutralElement(BaseOpc, DL, SrcVT, Op->getFlags()); + + // Whilst we don't know the size of the vector we do know the maximum size so + // can perform a tree reduction with an identity vector, which means once we + // arrive at the result the remaining stages (when the vector is smaller than + // the maximum) have no affect. + + unsigned Segments = AArch64::SVEMaxBitsPerVector / AArch64::SVEBitsPerBlock; + unsigned Stages = llvm::Log2_32(Segments * SrcVT.getVectorMinNumElements()); + + for (unsigned I = 0; I < Stages; ++I) { + Src = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, SrcVTs, Src, Identity); + Src = DAG.getNode(BaseOpc, DL, SrcVT, Src.getValue(0), Src.getValue(1)); + } + + return DAG.getExtractVectorElt(DL, Op.getValueType(), Src, 0); +} + SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const { auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); @@ -22942,7 +22974,7 @@ static SDValue performIntrinsicCombine(SDNode *N, return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2)); case Intrinsic::aarch64_sve_asrd: - return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0), + return DAG.getNode(AArch64ISD::ASRD_MERGE_OP1, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2), N->getOperand(3)); case Intrinsic::aarch64_sve_cmphs: if (!N->getOperand(2).getValueType().isFloatingPoint()) @@ -30047,7 +30079,7 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE( SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT); SDValue Res = - DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, DL, ContainerVT, Pg, Op1, Op2); + DAG.getNode(AArch64ISD::ASRD_MERGE_OP1, DL, ContainerVT, Pg, Op1, Op2); if (Negated) Res = DAG.getNode(ISD::SUB, DL, ContainerVT, DAG.getConstant(0, DL, ContainerVT), Res); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 00956fd..9495c9f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -752,6 +752,7 @@ private: SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVECREDUCE_MUL(SDValue Op, SelectionDAG &DAG) const; SDValue LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const; SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td index 7322212..fe84193 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td +++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td @@ -233,6 +233,12 @@ def G_SDOT : AArch64GenericInstruction { let hasSideEffects = 0; } +def G_USDOT : AArch64GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3); + let hasSideEffects = 0; +} + // Generic instruction for the BSP pseudo. It is expanded into BSP, which // expands into BSL/BIT/BIF after register allocation. def G_BSP : AArch64GenericInstruction { @@ -278,6 +284,7 @@ def : GINodeEquiv<G_UADDLV, AArch64uaddlv>; def : GINodeEquiv<G_UDOT, AArch64udot>; def : GINodeEquiv<G_SDOT, AArch64sdot>; +def : GINodeEquiv<G_USDOT, AArch64usdot>; def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>; diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp index b3c9656..343fd81 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp @@ -40,7 +40,11 @@ yaml::AArch64FunctionInfo::AArch64FunctionInfo( getSVEStackSize(MFI, &llvm::AArch64FunctionInfo::getStackSizePPR)), HasStackFrame(MFI.hasStackFrame() ? std::optional<bool>(MFI.hasStackFrame()) - : std::nullopt) {} + : std::nullopt), + HasStreamingModeChanges( + MFI.hasStreamingModeChanges() + ? std::optional<bool>(MFI.hasStreamingModeChanges()) + : std::nullopt) {} void yaml::AArch64FunctionInfo::mappingImpl(yaml::IO &YamlIO) { MappingTraits<AArch64FunctionInfo>::mapping(YamlIO, *this); @@ -55,6 +59,8 @@ void AArch64FunctionInfo::initializeBaseYamlFields( YamlMFI.StackSizePPR.value_or(0)); if (YamlMFI.HasStackFrame) setHasStackFrame(*YamlMFI.HasStackFrame); + if (YamlMFI.HasStreamingModeChanges) + setHasStreamingModeChanges(*YamlMFI.HasStreamingModeChanges); } static std::pair<bool, bool> GetSignReturnAddress(const Function &F) { diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index bd0a17d..d1832f4 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -645,6 +645,7 @@ struct AArch64FunctionInfo final : public yaml::MachineFunctionInfo { std::optional<uint64_t> StackSizeZPR; std::optional<uint64_t> StackSizePPR; std::optional<bool> HasStackFrame; + std::optional<bool> HasStreamingModeChanges; AArch64FunctionInfo() = default; AArch64FunctionInfo(const llvm::AArch64FunctionInfo &MFI); @@ -659,6 +660,7 @@ template <> struct MappingTraits<AArch64FunctionInfo> { YamlIO.mapOptional("stackSizeZPR", MFI.StackSizeZPR); YamlIO.mapOptional("stackSizePPR", MFI.StackSizePPR); YamlIO.mapOptional("hasStackFrame", MFI.HasStackFrame); + YamlIO.mapOptional("hasStreamingModeChanges", MFI.HasStreamingModeChanges); } }; diff --git a/llvm/lib/Target/AArch64/AArch64PostCoalescerPass.cpp b/llvm/lib/Target/AArch64/AArch64PostCoalescerPass.cpp index cdf2822..a90950d 100644 --- a/llvm/lib/Target/AArch64/AArch64PostCoalescerPass.cpp +++ b/llvm/lib/Target/AArch64/AArch64PostCoalescerPass.cpp @@ -75,6 +75,10 @@ bool AArch64PostCoalescer::runOnMachineFunction(MachineFunction &MF) { if (Src != Dst) MRI->replaceRegWith(Dst, Src); + if (MI.getOperand(1).isUndef()) + for (MachineOperand &MO : MRI->use_operands(Dst)) + MO.setIsUndef(); + // MI must be erased from the basic block before recalculating the live // interval. LIS->RemoveMachineInstrFromMaps(MI); diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index bc6b931..98a128e 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -265,7 +265,7 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [ SDTCVecEltisVT<1,i1>, SDTCisSameAs<0,2> ]>; -def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>; +def AArch64asrd_m1 : SDNode<"AArch64ISD::ASRD_MERGE_OP1", SDT_AArch64Arith_Imm>; def AArch64urshri_p_node : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>; def AArch64urshri_p : PatFrags<(ops node:$op1, node:$op2, node:$op3), diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index 9e2d698..05a4313 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -1855,6 +1855,8 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, return LowerTriOp(AArch64::G_UDOT); case Intrinsic::aarch64_neon_sdot: return LowerTriOp(AArch64::G_SDOT); + case Intrinsic::aarch64_neon_usdot: + return LowerTriOp(AArch64::G_USDOT); case Intrinsic::aarch64_neon_sqxtn: return LowerUnaryOp(TargetOpcode::G_TRUNC_SSAT_S); case Intrinsic::aarch64_neon_sqxtun: diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index 4749748..434ea67 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -294,6 +294,12 @@ struct MachineSMEABI : public MachineFunctionPass { MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs); + /// Attempts to find an insertion point before \p Inst where the status flags + /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of + /// the block is found. + std::pair<MachineBasicBlock::iterator, LiveRegs> + findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block, + SmallVectorImpl<InstInfo>::const_iterator Inst); void emitStateChange(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, ZAState From, ZAState To, LiveRegs PhysLiveRegs); @@ -337,6 +343,28 @@ private: MachineRegisterInfo *MRI = nullptr; }; +static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) { + LiveRegs PhysLiveRegs = LiveRegs::None; + if (!LiveUnits.available(AArch64::NZCV)) + PhysLiveRegs |= LiveRegs::NZCV; + // We have to track W0 and X0 separately as otherwise things can get + // confused if we attempt to preserve X0 but only W0 was defined. + if (!LiveUnits.available(AArch64::W0)) + PhysLiveRegs |= LiveRegs::W0; + if (!LiveUnits.available(AArch64::W0_HI)) + PhysLiveRegs |= LiveRegs::W0_HI; + return PhysLiveRegs; +} + +static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) { + if (PhysLiveRegs & LiveRegs::NZCV) + LiveUnits.addReg(AArch64::NZCV); + if (PhysLiveRegs & LiveRegs::W0) + LiveUnits.addReg(AArch64::W0); + if (PhysLiveRegs & LiveRegs::W0_HI) + LiveUnits.addReg(AArch64::W0_HI); +} + FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) && @@ -362,26 +390,13 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { LiveRegUnits LiveUnits(*TRI); LiveUnits.addLiveOuts(MBB); - auto GetPhysLiveRegs = [&] { - LiveRegs PhysLiveRegs = LiveRegs::None; - if (!LiveUnits.available(AArch64::NZCV)) - PhysLiveRegs |= LiveRegs::NZCV; - // We have to track W0 and X0 separately as otherwise things can get - // confused if we attempt to preserve X0 but only W0 was defined. - if (!LiveUnits.available(AArch64::W0)) - PhysLiveRegs |= LiveRegs::W0; - if (!LiveUnits.available(AArch64::W0_HI)) - PhysLiveRegs |= LiveRegs::W0_HI; - return PhysLiveRegs; - }; - - Block.PhysLiveRegsAtExit = GetPhysLiveRegs(); + Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits); auto FirstTerminatorInsertPt = MBB.getFirstTerminator(); auto FirstNonPhiInsertPt = MBB.getFirstNonPHI(); for (MachineInstr &MI : reverse(MBB)) { MachineBasicBlock::iterator MBBI(MI); LiveUnits.stepBackward(MI); - LiveRegs PhysLiveRegs = GetPhysLiveRegs(); + LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits); // The SMEStateAllocPseudo marker is added to a function if the save // buffer was allocated in SelectionDAG. It marks the end of the // allocation -- which is a safe point for this pass to insert any TPIDR2 @@ -476,6 +491,49 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles, return BundleStates; } +std::pair<MachineBasicBlock::iterator, LiveRegs> +MachineSMEABI::findStateChangeInsertionPoint( + MachineBasicBlock &MBB, const BlockInfo &Block, + SmallVectorImpl<InstInfo>::const_iterator Inst) { + LiveRegs PhysLiveRegs; + MachineBasicBlock::iterator InsertPt; + if (Inst != Block.Insts.end()) { + InsertPt = Inst->InsertPt; + PhysLiveRegs = Inst->PhysLiveRegs; + } else { + InsertPt = MBB.getFirstTerminator(); + PhysLiveRegs = Block.PhysLiveRegsAtExit; + } + + if (!(PhysLiveRegs & LiveRegs::NZCV)) + return {InsertPt, PhysLiveRegs}; // Nothing to do (no live flags). + + // Find the previous state change. We can not move before this point. + MachineBasicBlock::iterator PrevStateChangeI; + if (Inst == Block.Insts.begin()) { + PrevStateChangeI = MBB.begin(); + } else { + // Note: `std::prev(Inst)` is the previous InstInfo. We only create an + // InstInfo object for instructions that require a specific ZA state, so the + // InstInfo is the site of the previous state change in the block (which can + // be several MIs earlier). + PrevStateChangeI = std::prev(Inst)->InsertPt; + } + + // Note: LiveUnits will only accurately track X0 and NZCV. + LiveRegUnits LiveUnits(*TRI); + setPhysLiveRegs(LiveUnits, PhysLiveRegs); + for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) { + // Don't move before/into a call (which may have a state change before it). + if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall()) + break; + LiveUnits.stepBackward(*I); + if (LiveUnits.available(AArch64::NZCV)) + return {I, getPhysLiveRegs(LiveUnits)}; + } + return {InsertPt, PhysLiveRegs}; +} + void MachineSMEABI::insertStateChanges(EmitContext &Context, const FunctionInfo &FnInfo, const EdgeBundles &Bundles, @@ -490,10 +548,13 @@ void MachineSMEABI::insertStateChanges(EmitContext &Context, CurrentState = InState; for (auto &Inst : Block.Insts) { - if (CurrentState != Inst.NeededState) - emitStateChange(Context, MBB, Inst.InsertPt, CurrentState, - Inst.NeededState, Inst.PhysLiveRegs); - CurrentState = Inst.NeededState; + if (CurrentState != Inst.NeededState) { + auto [InsertPt, PhysLiveRegs] = + findStateChangeInsertionPoint(MBB, Block, &Inst); + emitStateChange(Context, MBB, InsertPt, CurrentState, Inst.NeededState, + PhysLiveRegs); + CurrentState = Inst.NeededState; + } } if (MBB.succ_empty()) @@ -501,9 +562,12 @@ void MachineSMEABI::insertStateChanges(EmitContext &Context, ZAState OutState = BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)]; - if (CurrentState != OutState) - emitStateChange(Context, MBB, MBB.getFirstTerminator(), CurrentState, - OutState, Block.PhysLiveRegsAtExit); + if (CurrentState != OutState) { + auto [InsertPt, PhysLiveRegs] = + findStateChangeInsertionPoint(MBB, Block, Block.Insts.end()); + emitStateChange(Context, MBB, InsertPt, CurrentState, OutState, + PhysLiveRegs); + } } } |