aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/AArch64/AArch64FrameLowering.cpp357
-rw-r--r--llvm/lib/Target/AArch64/AArch64InstrInfo.cpp64
-rw-r--r--llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp3
-rw-r--r--llvm/lib/Target/AArch64/AArch64RegisterInfo.td11
-rw-r--r--llvm/lib/Target/AArch64/AArch64Subtarget.cpp19
-rw-r--r--llvm/lib/Target/AArch64/AArch64Subtarget.h2
-rw-r--r--llvm/lib/Target/AArch64/SMEInstrFormats.td14
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPU.td21
-rw-r--r--llvm/lib/Target/AMDGPU/GCNSubtarget.h4
-rw-r--r--llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp10
-rw-r--r--llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp3
-rw-r--r--llvm/lib/Target/NVPTX/NVPTX.h1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp76
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.h5
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td54
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td97
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp134
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp192
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVUtils.cpp6
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVUtils.h3
-rw-r--r--llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp1
-rw-r--r--llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp49
-rw-r--r--llvm/lib/Target/X86/X86InstrAVX512.td90
-rw-r--r--llvm/lib/TargetParser/TargetParser.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/SCCPSolver.cpp39
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp149
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorize.cpp214
-rw-r--r--llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp6
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp199
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.h13
30 files changed, 1004 insertions, 833 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 4357264d..c76689f 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -345,12 +345,6 @@ static unsigned getStackHazardSize(const MachineFunction &MF) {
return MF.getSubtarget<AArch64Subtarget>().getStreamingHazardSize();
}
-/// Returns true if PPRs are spilled as ZPRs.
-static bool arePPRsSpilledAsZPR(const MachineFunction &MF) {
- return MF.getSubtarget().getRegisterInfo()->getSpillSize(
- AArch64::PPRRegClass) == 16;
-}
-
StackOffset
AArch64FrameLowering::getZPRStackSize(const MachineFunction &MF) const {
const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
@@ -1966,8 +1960,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
break;
case RegPairInfo::PPR:
- StrOpc =
- Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
+ StrOpc = AArch64::STR_PXI;
break;
case RegPairInfo::VG:
StrOpc = AArch64::STRXui;
@@ -2178,8 +2171,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
break;
case RegPairInfo::PPR:
- LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
- : AArch64::LDR_PXI;
+ LdrOpc = AArch64::LDR_PXI;
break;
case RegPairInfo::VG:
continue;
@@ -2286,9 +2278,7 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
// Returns true if the LDST MachineInstr \p MI is a PPR access.
static bool isPPRAccess(const MachineInstr &MI) {
- return MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
- MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
- AArch64::PPRRegClass.contains(MI.getOperand(0).getReg());
+ return AArch64::PPRRegClass.contains(MI.getOperand(0).getReg());
}
// Check if a Hazard slot is needed for the current function, and if so create
@@ -2390,12 +2380,6 @@ void AArch64FrameLowering::determineStackHazardSlot(
return;
}
- if (arePPRsSpilledAsZPR(MF)) {
- LLVM_DEBUG(dbgs() << "SplitSVEObjects is not supported with "
- "-aarch64-enable-zpr-predicate-spills");
- return;
- }
-
// If another calling convention is explicitly set FPRs can't be promoted to
// ZPR callee-saves.
if (!is_contained({CallingConv::C, CallingConv::Fast,
@@ -2519,14 +2503,6 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
continue;
}
- // Always save P4 when PPR spills are ZPR-sized and a predicate above p8 is
- // spilled. If all of p0-p3 are used as return values p4 is must be free
- // to reload p8-p15.
- if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
- AArch64::PPR_p8to15RegClass.contains(Reg)) {
- SavedRegs.set(AArch64::P4);
- }
-
// MachO's compact unwind format relies on all registers being stored in
// pairs.
// FIXME: the usual format is actually better if unwinding isn't needed.
@@ -2587,7 +2563,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
auto SpillSize = TRI->getSpillSize(*RC);
bool IsZPR = AArch64::ZPRRegClass.contains(Reg);
bool IsPPR = !IsZPR && AArch64::PPRRegClass.contains(Reg);
- if (IsZPR || (IsPPR && arePPRsSpilledAsZPR(MF)))
+ if (IsZPR)
ZPRCSStackSize += SpillSize;
else if (IsPPR)
PPRCSStackSize += SpillSize;
@@ -2902,7 +2878,7 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF,
StackTop += MFI.getObjectSize(FI);
StackTop = alignTo(StackTop, Alignment);
- assert(StackTop < std::numeric_limits<int64_t>::max() &&
+ assert(StackTop < (uint64_t)std::numeric_limits<int64_t>::max() &&
"SVE StackTop far too large?!");
int64_t Offset = -int64_t(StackTop);
@@ -2961,314 +2937,8 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF,
return SVEStack;
}
-/// Attempts to scavenge a register from \p ScavengeableRegs given the used
-/// registers in \p UsedRegs.
-static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
- BitVector const &ScavengeableRegs,
- Register PreferredReg) {
- if (PreferredReg != AArch64::NoRegister && UsedRegs.available(PreferredReg))
- return PreferredReg;
- for (auto Reg : ScavengeableRegs.set_bits()) {
- if (UsedRegs.available(Reg))
- return Reg;
- }
- return AArch64::NoRegister;
-}
-
-/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
-/// \p MachineInstrs.
-static void propagateFrameFlags(MachineInstr &SourceMI,
- ArrayRef<MachineInstr *> MachineInstrs) {
- for (MachineInstr *MI : MachineInstrs) {
- if (SourceMI.getFlag(MachineInstr::FrameSetup))
- MI->setFlag(MachineInstr::FrameSetup);
- if (SourceMI.getFlag(MachineInstr::FrameDestroy))
- MI->setFlag(MachineInstr::FrameDestroy);
- }
-}
-
-/// RAII helper class for scavenging or spilling a register. On construction
-/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
-/// AllocatableRegs), if no register can be found spills \p SpillCandidate to \p
-/// MaybeSpillFI to free a register. The free'd register is returned via the \p
-/// FreeReg output parameter. On destruction, if there is a spill, its previous
-/// value is reloaded. The spilling and scavenging is only valid at the
-/// insertion point \p MBBI, this class should _not_ be used in places that
-/// create or manipulate basic blocks, moving the expected insertion point.
-struct ScopedScavengeOrSpill {
- ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
- ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;
-
- ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
- MachineBasicBlock::iterator MBBI,
- Register SpillCandidate, const TargetRegisterClass &RC,
- LiveRegUnits const &UsedRegs,
- BitVector const &AllocatableRegs,
- std::optional<int> *MaybeSpillFI,
- Register PreferredReg = AArch64::NoRegister)
- : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
- *MF.getSubtarget().getInstrInfo())),
- TRI(*MF.getSubtarget().getRegisterInfo()) {
- FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs, PreferredReg);
- if (FreeReg != AArch64::NoRegister)
- return;
- assert(MaybeSpillFI && "Expected emergency spill slot FI information "
- "(attempted to spill in prologue/epilogue?)");
- if (!MaybeSpillFI->has_value()) {
- MachineFrameInfo &MFI = MF.getFrameInfo();
- *MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
- TRI.getSpillAlign(RC));
- }
- FreeReg = SpillCandidate;
- SpillFI = MaybeSpillFI->value();
- TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI,
- Register());
- }
-
- bool hasSpilled() const { return SpillFI.has_value(); }
-
- /// Returns the free register (found from scavenging or spilling a register).
- Register freeRegister() const { return FreeReg; }
-
- Register operator*() const { return freeRegister(); }
-
- ~ScopedScavengeOrSpill() {
- if (hasSpilled())
- TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI,
- Register());
- }
-
-private:
- MachineBasicBlock &MBB;
- MachineBasicBlock::iterator MBBI;
- const TargetRegisterClass &RC;
- const AArch64InstrInfo &TII;
- const TargetRegisterInfo &TRI;
- Register FreeReg = AArch64::NoRegister;
- std::optional<int> SpillFI;
-};
-
-/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
-/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
-struct EmergencyStackSlots {
- std::optional<int> ZPRSpillFI;
- std::optional<int> PPRSpillFI;
- std::optional<int> GPRSpillFI;
-};
-
-/// Registers available for scavenging (ZPR, PPR3b, GPR).
-struct ScavengeableRegs {
- BitVector ZPRRegs;
- BitVector PPR3bRegs;
- BitVector GPRRegs;
-};
-
-static bool isInPrologueOrEpilogue(const MachineInstr &MI) {
- return MI.getFlag(MachineInstr::FrameSetup) ||
- MI.getFlag(MachineInstr::FrameDestroy);
-}
-
-/// Expands:
-/// ```
-/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
-/// ```
-/// To:
-/// ```
-/// $z0 = CPY_ZPzI_B $p0, 1, 0
-/// STR_ZXI $z0, $stack.0, 0
-/// ```
-/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
-/// spilling if necessary).
-static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
- MachineInstr &MI,
- const TargetRegisterInfo &TRI,
- LiveRegUnits const &UsedRegs,
- ScavengeableRegs const &SR,
- EmergencyStackSlots &SpillSlots) {
- MachineFunction &MF = *MBB.getParent();
- auto *TII =
- static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
-
- ScopedScavengeOrSpill ZPredReg(
- MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);
-
- SmallVector<MachineInstr *, 2> MachineInstrs;
- const DebugLoc &DL = MI.getDebugLoc();
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
- .addReg(*ZPredReg, RegState::Define)
- .add(MI.getOperand(0))
- .addImm(1)
- .addImm(0)
- .getInstr());
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
- .addReg(*ZPredReg)
- .add(MI.getOperand(1))
- .addImm(MI.getOperand(2).getImm())
- .setMemRefs(MI.memoperands())
- .getInstr());
- propagateFrameFlags(MI, MachineInstrs);
-}
-
-/// Expands:
-/// ```
-/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
-/// ```
-/// To:
-/// ```
-/// $z0 = LDR_ZXI %stack.0, 0
-/// $p0 = PTRUE_B 31, implicit $vg
-/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv, implicit-def $nzcv
-/// ```
-/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
-/// spilling if necessary). If the status flags are in use at the point of
-/// expansion they are preserved (by moving them to/from a GPR). This may cause
-/// an additional spill if no GPR is free at the expansion point.
-static bool expandFillPPRFromZPRSlotPseudo(
- MachineBasicBlock &MBB, MachineInstr &MI, const TargetRegisterInfo &TRI,
- LiveRegUnits const &UsedRegs, ScavengeableRegs const &SR,
- MachineInstr *&LastPTrue, EmergencyStackSlots &SpillSlots) {
- MachineFunction &MF = *MBB.getParent();
- auto *TII =
- static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
-
- ScopedScavengeOrSpill ZPredReg(
- MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);
-
- ScopedScavengeOrSpill PredReg(
- MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI,
- /*PreferredReg=*/
- LastPTrue ? LastPTrue->getOperand(0).getReg() : AArch64::NoRegister);
-
- // Elide NZCV spills if we know it is not used.
- bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
- std::optional<ScopedScavengeOrSpill> NZCVSaveReg;
- if (IsNZCVUsed)
- NZCVSaveReg.emplace(
- MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI);
- SmallVector<MachineInstr *, 4> MachineInstrs;
- const DebugLoc &DL = MI.getDebugLoc();
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
- .addReg(*ZPredReg, RegState::Define)
- .add(MI.getOperand(1))
- .addImm(MI.getOperand(2).getImm())
- .setMemRefs(MI.memoperands())
- .getInstr());
- if (IsNZCVUsed)
- MachineInstrs.push_back(
- BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
- .addReg(NZCVSaveReg->freeRegister(), RegState::Define)
- .addImm(AArch64SysReg::NZCV)
- .addReg(AArch64::NZCV, RegState::Implicit)
- .getInstr());
-
- // Reuse previous ptrue if we know it has not been clobbered.
- if (LastPTrue) {
- assert(*PredReg == LastPTrue->getOperand(0).getReg());
- LastPTrue->moveBefore(&MI);
- } else {
- LastPTrue = BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
- .addReg(*PredReg, RegState::Define)
- .addImm(31);
- }
- MachineInstrs.push_back(LastPTrue);
- MachineInstrs.push_back(
- BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
- .addReg(MI.getOperand(0).getReg(), RegState::Define)
- .addReg(*PredReg)
- .addReg(*ZPredReg)
- .addImm(0)
- .addReg(AArch64::NZCV, RegState::ImplicitDefine)
- .getInstr());
- if (IsNZCVUsed)
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
- .addImm(AArch64SysReg::NZCV)
- .addReg(NZCVSaveReg->freeRegister())
- .addReg(AArch64::NZCV, RegState::ImplicitDefine)
- .getInstr());
-
- propagateFrameFlags(MI, MachineInstrs);
- return PredReg.hasSpilled();
-}
-
-/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
-/// operations within the MachineBasicBlock \p MBB.
-static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
- const TargetRegisterInfo &TRI,
- ScavengeableRegs const &SR,
- EmergencyStackSlots &SpillSlots) {
- LiveRegUnits UsedRegs(TRI);
- UsedRegs.addLiveOuts(MBB);
- bool HasPPRSpills = false;
- MachineInstr *LastPTrue = nullptr;
- for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
- UsedRegs.stepBackward(MI);
- switch (MI.getOpcode()) {
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
- if (LastPTrue &&
- MI.definesRegister(LastPTrue->getOperand(0).getReg(), &TRI))
- LastPTrue = nullptr;
- HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR,
- LastPTrue, SpillSlots);
- MI.eraseFromParent();
- break;
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots);
- MI.eraseFromParent();
- [[fallthrough]];
- default:
- LastPTrue = nullptr;
- break;
- }
- }
-
- return HasPPRSpills;
-}
-
void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
MachineFunction &MF, RegScavenger *RS) const {
-
- AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
- const TargetSubtargetInfo &TSI = MF.getSubtarget();
- const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();
-
- // If predicates spills are 16-bytes we may need to expand
- // SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
- if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
- auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
- BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
- assert(Regs.count() > 0 && "Expected scavengeable registers");
- return Regs;
- };
-
- ScavengeableRegs SR{};
- SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
- // Only p0-7 are possible as the second operand of cmpne (needed for fills).
- SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
- SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);
-
- EmergencyStackSlots SpillSlots;
- for (MachineBasicBlock &MBB : MF) {
- // In the case we had to spill a predicate (in the range p0-p7) to reload
- // a predicate (>= p8), additional spill/fill pseudos will be created.
- // These need an additional expansion pass. Note: There will only be at
- // most two expansion passes, as spilling/filling a predicate in the range
- // p0-p7 never requires spilling another predicate.
- for (int Pass = 0; Pass < 2; Pass++) {
- bool HasPPRSpills =
- expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots);
- assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills");
- if (!HasPPRSpills)
- break;
- }
- }
- }
-
- MachineFrameInfo &MFI = MF.getFrameInfo();
-
assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
"Upwards growing stack unsupported");
@@ -3279,6 +2949,9 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
if (!MF.hasEHFunclets())
return;
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+
// Win64 C++ EH needs to allocate space for the catch objects in the fixed
// object area right next to the UnwindHelp object.
WinEHFuncInfo &EHInfo = *MF.getWinEHFuncInfo();
@@ -4280,18 +3953,10 @@ void AArch64FrameLowering::emitRemarks(
}
unsigned RegTy = StackAccess::AccessType::GPR;
- if (MFI.hasScalableStackID(FrameIdx)) {
- // SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
- // spill/fill the predicate as a data vector (so are an FPR access).
- if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
- MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
- AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) {
- RegTy = StackAccess::PPR;
- } else
- RegTy = StackAccess::FPR;
- } else if (AArch64InstrInfo::isFpOrNEON(MI)) {
+ if (MFI.hasScalableStackID(FrameIdx))
+ RegTy = isPPRAccess(MI) ? StackAccess::PPR : StackAccess::FPR;
+ else if (AArch64InstrInfo::isFpOrNEON(MI))
RegTy = StackAccess::FPR;
- }
StackAccesses[ArrIdx].AccessTypes |= RegTy;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a90da1..b8761d97 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2579,8 +2579,6 @@ unsigned AArch64InstrInfo::getLoadStoreImmIdx(unsigned Opc) {
case AArch64::STZ2Gi:
case AArch64::STZGi:
case AArch64::TAGPstack:
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
return 2;
case AArch64::LD1B_D_IMM:
case AArch64::LD1B_H_IMM:
@@ -4387,8 +4385,6 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
MinOffset = -256;
MaxOffset = 254;
break;
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
case AArch64::LDR_ZXI:
case AArch64::STR_ZXI:
Scale = TypeSize::getScalable(16);
@@ -5098,33 +5094,31 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+ } else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
+ !Subtarget.hasZeroCycleRegMoveGPR32()) {
+ // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
+ MCRegister DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(DestRegX.isValid() && "Destination super-reg not valid");
+ MCRegister SrcRegX =
+ SrcReg == AArch64::WZR
+ ? AArch64::XZR
+ : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(SrcRegX.isValid() && "Source super-reg not valid");
+ // This instruction is reading and writing X registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegX, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX)
+ .addReg(AArch64::XZR)
+ .addReg(SrcRegX, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
} else {
- if (Subtarget.hasZeroCycleRegMoveGPR64() &&
- !Subtarget.hasZeroCycleRegMoveGPR32()) {
- // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
- MCRegister DestRegX = TRI->getMatchingSuperReg(
- DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
- assert(DestRegX.isValid() && "Destination super-reg not valid");
- MCRegister SrcRegX =
- SrcReg == AArch64::WZR
- ? AArch64::XZR
- : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
- assert(SrcRegX.isValid() && "Source super-reg not valid");
- // This instruction is reading and writing X registers. This may upset
- // the register scavenger and machine verifier, so we need to indicate
- // that we are reading an undefined value from SrcRegX, but a proper
- // value from SrcReg.
- BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX)
- .addReg(AArch64::XZR)
- .addReg(SrcRegX, RegState::Undef)
- .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
- } else {
- // Otherwise, expand to ORR WZR.
- BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
- .addReg(AArch64::WZR)
- .addReg(SrcReg, getKillRegState(KillSrc));
- }
+ // Otherwise, expand to ORR WZR.
+ BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
+ .addReg(AArch64::WZR)
+ .addReg(SrcReg, getKillRegState(KillSrc));
}
return;
}
@@ -5650,11 +5644,6 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZXI;
StackID = TargetStackID::ScalableVector;
- } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
- assert(Subtarget.isSVEorStreamingSVEAvailable() &&
- "Unexpected predicate store without SVE store instructions");
- Opc = AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO;
- StackID = TargetStackID::ScalableVector;
}
break;
case 24:
@@ -5835,11 +5824,6 @@ void AArch64InstrInfo::loadRegFromStackSlot(
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZXI;
StackID = TargetStackID::ScalableVector;
- } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
- assert(Subtarget.isSVEorStreamingSVEAvailable() &&
- "Unexpected predicate load without SVE load instructions");
- Opc = AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO;
- StackID = TargetStackID::ScalableVector;
}
break;
case 24:
diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
index aed137c..1568161 100644
--- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
@@ -57,10 +57,7 @@ static bool isPartOfZPRCalleeSaves(MachineBasicBlock::iterator I) {
case AArch64::ST1B_2Z_IMM:
case AArch64::STR_ZXI:
case AArch64::LDR_ZXI:
- case AArch64::CPY_ZPzI_B:
- case AArch64::CMPNE_PPzZI_B:
case AArch64::PTRUE_C_B:
- case AArch64::PTRUE_B:
return I->getFlag(MachineInstr::FrameSetup) ||
I->getFlag(MachineInstr::FrameDestroy);
case AArch64::SEH_SavePReg:
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 5d89862..ef974df 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -980,19 +980,10 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
//******************************************************************************
// SVE predicate register classes.
-
-// Note: This hardware mode is enabled in AArch64Subtarget::getHwModeSet()
-// (without the use of the table-gen'd predicates).
-def SMEWithZPRPredicateSpills : HwMode<[Predicate<"false">]>;
-
-def PPRSpillFillRI : RegInfoByHwMode<
- [DefaultMode, SMEWithZPRPredicateSpills],
- [RegInfo<16,16,16>, RegInfo<16,128,128>]>;
-
class PPRClass<int firstreg, int lastreg, int step = 1> : RegisterClass<"AArch64",
[ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
(sequence "P%u", firstreg, lastreg, step)> {
- let RegInfos = PPRSpillFillRI;
+ let Size = 16;
}
def PPR : PPRClass<0, 15> {
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 98e0a11..12ddf47 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -86,11 +86,6 @@ static cl::alias AArch64StreamingStackHazardSize(
cl::desc("alias for -aarch64-streaming-hazard-size"),
cl::aliasopt(AArch64StreamingHazardSize));
-static cl::opt<bool> EnableZPRPredicateSpills(
- "aarch64-enable-zpr-predicate-spills", cl::init(false), cl::Hidden,
- cl::desc(
- "Enables spilling/reloading SVE predicates as data vectors (ZPRs)"));
-
static cl::opt<unsigned>
VScaleForTuningOpt("sve-vscale-for-tuning", cl::Hidden,
cl::desc("Force a vscale for tuning factor for SVE"));
@@ -426,20 +421,6 @@ AArch64Subtarget::AArch64Subtarget(const Triple &TT, StringRef CPU,
EnableSubregLiveness = EnableSubregLivenessTracking.getValue();
}
-unsigned AArch64Subtarget::getHwModeSet() const {
- AArch64HwModeBits Modes = AArch64HwModeBits::DefaultMode;
-
- // Use a special hardware mode in streaming[-compatible] functions with
- // aarch64-enable-zpr-predicate-spills. This changes the spill size (and
- // alignment) for the predicate register class.
- if (EnableZPRPredicateSpills.getValue() &&
- (isStreaming() || isStreamingCompatible())) {
- Modes |= AArch64HwModeBits::SMEWithZPRPredicateSpills;
- }
-
- return to_underlying(Modes);
-}
-
const CallLowering *AArch64Subtarget::getCallLowering() const {
return CallLoweringInfo.get();
}
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 671df35..8974965 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -130,8 +130,6 @@ public:
bool IsStreaming = false, bool IsStreamingCompatible = false,
bool HasMinSize = false);
- virtual unsigned getHwModeSet() const override;
-
// Getters for SubtargetFeatures defined in tablegen
#define GET_SUBTARGETINFO_MACRO(ATTRIBUTE, DEFAULT, GETTER) \
bool GETTER() const { return ATTRIBUTE; }
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index be44b8f..33f35ad 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -58,20 +58,6 @@ def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO :
let hasSideEffects = 0;
}
-def SPILL_PPR_TO_ZPR_SLOT_PSEUDO :
- Pseudo<(outs), (ins PPRorPNRAny:$Pt, GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]>
-{
- let mayStore = 1;
- let hasSideEffects = 0;
-}
-
-def FILL_PPR_FROM_ZPR_SLOT_PSEUDO :
- Pseudo<(outs PPRorPNRAny:$Pt), (ins GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]>
-{
- let mayLoad = 1;
- let hasSideEffects = 0;
-}
-
def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
// SME ZA loads and stores
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 9446144..6b3c151 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -1411,20 +1411,6 @@ def FeatureGloballyAddressableScratch : SubtargetFeature<
"FLAT instructions can access scratch memory for any thread in any wave"
>;
-// FIXME: Remove after all users are migrated to attribute.
-def FeatureDynamicVGPR : SubtargetFeature <"dynamic-vgpr",
- "DynamicVGPR",
- "true",
- "Enable dynamic VGPR mode"
->;
-
-// FIXME: Remove after all users are migrated to attribute.
-def FeatureDynamicVGPRBlockSize32 : SubtargetFeature<"dynamic-vgpr-block-size-32",
- "DynamicVGPRBlockSize32",
- "true",
- "Use a block size of 32 for dynamic VGPR allocation (default is 16)"
->;
-
// Enable the use of SCRATCH_STORE/LOAD_BLOCK instructions for saving and
// restoring the callee-saved registers.
def FeatureUseBlockVGPROpsForCSR : SubtargetFeature<"block-vgpr-csr",
@@ -1462,6 +1448,12 @@ def Feature45BitNumRecordsBufferResource : SubtargetFeature< "45-bit-num-records
"The buffer resource (V#) supports 45-bit num_records"
>;
+def FeatureCluster : SubtargetFeature< "cluster",
+ "HasCluster",
+ "true",
+ "Has cluster support"
+>;
+
// Dummy feature used to disable assembler instructions.
def FeatureDisable : SubtargetFeature<"",
"FeatureDisable","true",
@@ -2128,6 +2120,7 @@ def FeatureISAVersion12_50 : FeatureSet<
Feature45BitNumRecordsBufferResource,
FeatureSupportsXNACK,
FeatureXNACK,
+ FeatureCluster,
]>;
def FeatureISAVersion12_51 : FeatureSet<
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index a54d665..879bf5a 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -288,6 +288,8 @@ protected:
bool Has45BitNumRecordsBufferResource = false;
+ bool HasCluster = false;
+
// Dummy feature to use for assembler in tablegen.
bool FeatureDisable = false;
@@ -1837,7 +1839,7 @@ public:
}
/// \returns true if the subtarget supports clusters of workgroups.
- bool hasClusters() const { return GFX1250Insts; }
+ bool hasClusters() const { return HasCluster; }
/// \returns true if the subtarget requires a wait for xcnt before atomic
/// flat/global stores & rmw.
diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
index 20fa141..f7f4d46 100644
--- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
@@ -1353,11 +1353,6 @@ unsigned getVGPRAllocGranule(const MCSubtargetInfo *STI,
if (DynamicVGPRBlockSize != 0)
return DynamicVGPRBlockSize;
- // Temporarily check the subtarget feature, until we fully switch to using
- // attributes.
- if (STI->getFeatureBits().test(FeatureDynamicVGPR))
- return STI->getFeatureBits().test(FeatureDynamicVGPRBlockSize32) ? 32 : 16;
-
bool IsWave32 = EnableWavefrontSize32
? *EnableWavefrontSize32
: STI->getFeatureBits().test(FeatureWavefrontSize32);
@@ -1412,10 +1407,7 @@ unsigned getAddressableNumVGPRs(const MCSubtargetInfo *STI,
if (Features.test(FeatureGFX90AInsts))
return 512;
- // Temporarily check the subtarget feature, until we fully switch to using
- // attributes.
- if (DynamicVGPRBlockSize != 0 ||
- STI->getFeatureBits().test(FeatureDynamicVGPR))
+ if (DynamicVGPRBlockSize != 0)
// On GFX12 we can allocate at most 8 blocks of VGPRs.
return 8 * getVGPRAllocGranule(STI, DynamicVGPRBlockSize);
return getAddressableNumArchVGPRs(STI);
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09..77913f2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
+ case NVPTX::PTXCvtMode::RS:
+ O << ".rs";
+ return;
}
}
llvm_unreachable("Invalid conversion modifier");
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03..1e0f747 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -207,6 +207,7 @@ enum CvtMode {
RM,
RP,
RNA,
+ RS,
BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8c21746..bc047a4a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1096,9 +1096,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Enable custom lowering for the following:
// * MVT::i128 - clusterlaunchcontrol
// * MVT::i32 - prmt
+ // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
// * MVT::Other - internal.addrspace.wrap
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
- Custom);
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN,
+ {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
}
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1181,6 +1182,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
MAKE_CASE(
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
+ MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
}
return nullptr;
@@ -2903,6 +2909,61 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}
+static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+ SDLoc DL(N);
+ SDValue F32Vec = N->getOperand(1);
+ SDValue RBits = N->getOperand(2);
+
+ unsigned IntrinsicID = N->getConstantOperandVal(0);
+
+ // Extract the 4 float elements from the vector
+ SmallVector<SDValue, 6> Ops;
+ for (unsigned i = 0; i < 4; ++i)
+ Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(i, DL)));
+
+ using NVPTX::PTXCvtMode::CvtMode;
+
+ auto [OpCode, RetTy, CvtModeFlag] =
+ [&]() -> std::tuple<NVPTXISD::NodeType, MVT::SimpleValueType, uint32_t> {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
+ return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
+ default:
+ llvm_unreachable("unsupported/unhandled intrinsic");
+ }
+ }();
+
+ Ops.push_back(RBits);
+ Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32));
+
+ return DAG.getNode(OpCode, DL, RetTy, Ops);
+}
+
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
const unsigned Mode = [&]() {
switch (Op->getConstantOperandVal(0)) {
@@ -2972,6 +3033,17 @@ static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
return LowerClusterLaunchControlQueryCancel(Op, DAG);
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+ return lowerCvtRSIntrinsics(Op, DAG);
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 769d2fe..63fa0bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -79,6 +79,11 @@ enum NodeType : unsigned {
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
+ CVT_E4M3X4_F32X4_RS_SF,
+ CVT_E5M2X4_F32X4_RS_SF,
+ CVT_E2M3X4_F32X4_RS_SF,
+ CVT_E3M2X4_F32X4_RS_SF,
+ CVT_E2M1X4_F32X4_RS_SF,
FIRST_MEMORY_OPCODE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4cacee2..6c14cf0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -34,7 +34,8 @@ def CvtRN : PatLeaf<(i32 0x5)>;
def CvtRZ : PatLeaf<(i32 0x6)>;
def CvtRM : PatLeaf<(i32 0x7)>;
def CvtRP : PatLeaf<(i32 0x8)>;
-def CvtRNA : PatLeaf<(i32 0x9)>;
+def CvtRNA : PatLeaf<(i32 0x9)>;
+def CvtRS : PatLeaf<(i32 0xA)>;
def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
def CvtRNI_FTZ : PatLeaf<(i32 0x11)>;
@@ -50,8 +51,9 @@ def CvtSAT : PatLeaf<(i32 0x20)>;
def CvtSAT_FTZ : PatLeaf<(i32 0x30)>;
def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
-def CvtRN_RELU : PatLeaf<(i32 0x45)>;
-def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
+def CvtRN_RELU : PatLeaf<(i32 0x45)>;
+def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
+def CvtRS_RELU : PatLeaf<(i32 0x4A)>;
def CvtMode : Operand<i32> {
let PrintMethod = "printCvtMode";
@@ -133,6 +135,11 @@ def hasSM100a : Predicate<"Subtarget->getSmVersion() == 100 && Subtarget->hasArc
def hasSM101a : Predicate<"Subtarget->getSmVersion() == 101 && Subtarget->hasArchAccelFeatures()">;
def hasSM120a : Predicate<"Subtarget->getSmVersion() == 120 && Subtarget->hasArchAccelFeatures()">;
+def hasSM100aOrSM103a :
+ Predicate<"(Subtarget->getSmVersion() == 100 || " #
+ "Subtarget->getSmVersion() == 103) " #
+ "&& Subtarget->hasArchAccelFeatures()">;
+
// non-sync shfl instructions are not available on sm_70+ in PTX6.4+
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;
@@ -593,6 +600,23 @@ let hasSideEffects = false in {
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", B32>;
+
+ multiclass CVT_FROM_FLOAT_V2_RS<string FromName, RegisterClass RC> {
+ def _f32_rs :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3),
+ (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}." # FromName # ".f32">;
+
+ def _f32_rs_sf :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3),
+ (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # FromName # ".f32">;
+ }
+
+ defm CVT_f16x2 : CVT_FROM_FLOAT_V2_RS<"f16x2", B32>;
+ defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_RS<"bf16x2", B32>;
// FP8 conversions.
multiclass CVT_TO_F8X2<string F8Name> {
@@ -619,6 +643,15 @@ let hasSideEffects = false in {
def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
+
+ class CVT_TO_FP8X4<string F8Name> :
+ NVPTXInst<(outs B32:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # F8Name #
+ "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
+
+ def CVT_e4m3x4_f32x4_rs_sf : CVT_TO_FP8X4<"e4m3">;
+ def CVT_e5m2x4_f32x4_rs_sf : CVT_TO_FP8X4<"e5m2">;
// Float to TF32 conversions
multiclass CVT_TO_TF32<string Modifier, list<Predicate> Preds = [hasPTX<78>, hasSM<90>]> {
@@ -652,6 +685,15 @@ let hasSideEffects = false in {
"cvt${mode:base}${mode:relu}.f16x2." # type>;
}
+ class CVT_TO_FP6X4<string F6Name> :
+ NVPTXInst<(outs B32:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # F6Name #
+ "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
+
+ def CVT_e2m3x4_f32x4_rs_sf : CVT_TO_FP6X4<"e2m3">;
+ def CVT_e3m2x4_f32x4_rs_sf : CVT_TO_FP6X4<"e3m2">;
+
// FP4 conversions.
def CVT_e2m1x2_f32_sf : NVPTXInst<(outs B16:$dst),
(ins B32:$src1, B32:$src2, CvtMode:$mode),
@@ -668,6 +710,12 @@ let hasSideEffects = false in {
"cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
"cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
"}}"), []>;
+
+ def CVT_e2m1x4_f32x4_rs_sf :
+ NVPTXInst<(outs B16:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite.e2m1x4.f32 \t" #
+ "$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
// UE8M0x2 conversions.
class CVT_f32_to_ue8m0x2<string sat = ""> :
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index e91171c..a8b854f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1782,11 +1782,32 @@ def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C
def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_relu f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>;
+}
+
def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN)>;
def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_relu f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs $a, $b, $c, CvtRS_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>;
+}
def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>;
@@ -1929,6 +1950,52 @@ let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
(CVT_bf16x2_ue8m0x2 $a)>;
}
+def SDT_CVT_F32X4_TO_FPX4_RS_VEC :
+ SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
+ SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
+
+def SDT_CVT_F32X4_TO_FPX4_RS_INT :
+ SDTypeProfile<1, 6, [SDTCisInt<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
+ SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
+
+class CVT_F32X4_TO_FPX4_RS_SF_NODE<string FPName, SDTypeProfile SDT> :
+ SDNode<"NVPTXISD::CVT_" # FPName # "X4_F32X4_RS_SF", SDT, []>;
+
+multiclass CVT_F32X4_TO_FPX4_RS_SF_VEC<string FPName, VTVec RetTy> {
+ def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
+ SDT_CVT_F32X4_TO_FPX4_RS_VEC>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
+ (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf")
+ $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
+
+ def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
+ SDT_CVT_F32X4_TO_FPX4_RS_VEC>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
+ (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf")
+ $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
+}
+
+// RS rounding mode conversions
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+// FP8x4 conversions
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e4m3", v4i8>;
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e5m2", v4i8>;
+
+// FP6x4 conversions
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e2m3", v4i8>;
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e3m2", v4i8>;
+
+// FP4x4 conversions
+def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
+ SDT_CVT_F32X4_TO_FPX4_RS_INT>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
+ (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
+def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
+ SDT_CVT_F32X4_TO_FPX4_RS_INT>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
+ (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
+}
+
//
// FNS
//
@@ -4461,6 +4528,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!eq(ptx_elt_type, "e2m1"),
!ne(kind, "")) : [hasSM120a, hasPTX<87>],
+ !and(!or(!eq(ptx_elt_type,"e4m3"),
+ !eq(ptx_elt_type,"e5m2")),
+ !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
@@ -4476,6 +4547,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!and(!eq(geom, "m8n8k4"),
!eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
+ !and(!or(!eq(geom, "m16n8k4"),
+ !eq(geom, "m16n8k8"),
+ !eq(geom, "m16n8k16")),
+ !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@@ -4760,8 +4836,8 @@ defset list<WMMA_INSTR> WMMAs = {
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite, string b1op>
- : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+ string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4852,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragA.geom
# "." # ALayout
# "." # BLayout
+ # !if(!ne(Kind, ""), "." # Kind, "")
# !if(Satfinite, ".satfinite", "")
# TypeList
# b1op # "\n\t\t"
@@ -4792,13 +4869,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : MMA<WMMA_REGINFO<op[0], "mma">,
- WMMA_REGINFO<op[1], "mma">,
- WMMA_REGINFO<op[2], "mma">,
- WMMA_REGINFO<op[3], "mma">,
- layout_a, layout_b, satf, b1op>;
- }
+ foreach kind = ["", "kind::f8f6f4"] in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+ def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+ WMMA_REGINFO<op[1], "mma", "", kind>,
+ WMMA_REGINFO<op[2], "mma", "", kind>,
+ WMMA_REGINFO<op[3], "mma", "", kind>,
+ layout_a, layout_b, satf, b1op, kind>;
+ }
+ } // kind
} // b1op
} // op
} // satf
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 0afec42..989950f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -307,6 +307,10 @@ private:
bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectCounterHandleFromBinding(Register &ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectImageWriteIntrinsic(MachineInstr &I) const;
@@ -314,6 +318,8 @@ private:
MachineInstr &I) const;
bool selectModf(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
bool selectFrexp(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
// Utilities
@@ -3443,6 +3449,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_resource_handlefrombinding: {
return selectHandleFromBinding(ResVReg, ResType, I);
}
+ case Intrinsic::spv_resource_counterhandlefrombinding:
+ return selectCounterHandleFromBinding(ResVReg, ResType, I);
+ case Intrinsic::spv_resource_updatecounter:
+ return selectUpdateCounter(ResVReg, ResType, I);
case Intrinsic::spv_resource_store_typedbuffer: {
return selectImageWriteIntrinsic(I);
}
@@ -3478,6 +3488,130 @@ bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg,
*cast<GIntrinsic>(&I), I);
}
+bool SPIRVInstructionSelector::selectCounterHandleFromBinding(
+ Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+ auto &Intr = cast<GIntrinsic>(I);
+ assert(Intr.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefrombinding);
+
+ // Extract information from the intrinsic call.
+ Register MainHandleReg = Intr.getOperand(2).getReg();
+ auto *MainHandleDef = cast<GIntrinsic>(getVRegDef(*MRI, MainHandleReg));
+ assert(MainHandleDef->getIntrinsicID() ==
+ Intrinsic::spv_resource_handlefrombinding);
+
+ uint32_t Set = getIConstVal(Intr.getOperand(4).getReg(), MRI);
+ uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI);
+ uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI);
+ Register IndexReg = MainHandleDef->getOperand(5).getReg();
+ const bool IsNonUniform = false;
+ std::string CounterName =
+ getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) +
+ ".counter";
+
+ // Create the counter variable.
+ MachineIRBuilder MIRBuilder(I);
+ Register CounterVarReg = buildPointerToResource(
+ GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set,
+ Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder);
+
+ return BuildCOPY(ResVReg, CounterVarReg, I);
+}
+
+bool SPIRVInstructionSelector::selectUpdateCounter(Register &ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ auto &Intr = cast<GIntrinsic>(I);
+ assert(Intr.getIntrinsicID() == Intrinsic::spv_resource_updatecounter);
+
+ Register CounterHandleReg = Intr.getOperand(2).getReg();
+ Register IncrReg = Intr.getOperand(3).getReg();
+
+ // The counter handle is a pointer to the counter variable (which is a struct
+ // containing an i32). We need to get a pointer to that i32 member to do the
+ // atomic operation.
+#ifndef NDEBUG
+ SPIRVType *CounterVarType = GR.getSPIRVTypeForVReg(CounterHandleReg);
+ SPIRVType *CounterVarPointeeType = GR.getPointeeType(CounterVarType);
+ assert(CounterVarPointeeType &&
+ CounterVarPointeeType->getOpcode() == SPIRV::OpTypeStruct &&
+ "Counter variable must be a struct");
+ assert(GR.getPointerStorageClass(CounterVarType) ==
+ SPIRV::StorageClass::StorageBuffer &&
+ "Counter variable must be in the storage buffer storage class");
+ assert(CounterVarPointeeType->getNumOperands() == 2 &&
+ "Counter variable must have exactly 1 member in the struct");
+ const SPIRVType *MemberType =
+ GR.getSPIRVTypeForVReg(CounterVarPointeeType->getOperand(1).getReg());
+ assert(MemberType->getOpcode() == SPIRV::OpTypeInt &&
+ "Counter variable struct must have a single i32 member");
+#endif
+
+ // The struct has a single i32 member.
+ MachineIRBuilder MIRBuilder(I);
+ const Type *LLVMIntType =
+ Type::getInt32Ty(I.getMF()->getFunction().getContext());
+
+ SPIRVType *IntPtrType = GR.getOrCreateSPIRVPointerType(
+ LLVMIntType, MIRBuilder, SPIRV::StorageClass::StorageBuffer);
+
+ auto Zero = buildI32Constant(0, I);
+ if (!Zero.second)
+ return false;
+
+ Register PtrToCounter =
+ MRI->createVirtualRegister(GR.getRegClass(IntPtrType));
+ if (!BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpAccessChain))
+ .addDef(PtrToCounter)
+ .addUse(GR.getSPIRVTypeID(IntPtrType))
+ .addUse(CounterHandleReg)
+ .addUse(Zero.first)
+ .constrainAllUses(TII, TRI, RBI)) {
+ return false;
+ }
+
+ // For UAV/SSBO counters, the scope is Device. The counter variable is not
+ // used as a flag. So the memory semantics can be None.
+ auto Scope = buildI32Constant(SPIRV::Scope::Device, I);
+ if (!Scope.second)
+ return false;
+ auto Semantics = buildI32Constant(SPIRV::MemorySemantics::None, I);
+ if (!Semantics.second)
+ return false;
+
+ int64_t IncrVal = getIConstValSext(IncrReg, MRI);
+ auto Incr = buildI32Constant(static_cast<uint32_t>(IncrVal), I);
+ if (!Incr.second)
+ return false;
+
+ Register AtomicRes = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpAtomicIAdd))
+ .addDef(AtomicRes)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(PtrToCounter)
+ .addUse(Scope.first)
+ .addUse(Semantics.first)
+ .addUse(Incr.first)
+ .constrainAllUses(TII, TRI, RBI)) {
+ return false;
+ }
+ if (IncrVal >= 0) {
+ return BuildCOPY(ResVReg, AtomicRes, I);
+ }
+
+ // In HLSL, IncrementCounter returns the value *before* the increment, while
+ // DecrementCounter returns the value *after* the decrement. Both are lowered
+ // to the same atomic intrinsic which returns the value *before* the
+ // operation. So for decrements (negative IncrVal), we must subtract the
+ // increment value from the result to get the post-decrement value.
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(AtomicRes)
+ .addUse(Incr.first)
+ .constrainAllUses(TII, TRI, RBI);
+}
bool SPIRVInstructionSelector::selectReadImageIntrinsic(
Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
index 205895e..fc14a03 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
@@ -39,6 +39,10 @@ private:
void collectBindingInfo(Module &M);
uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
void replaceImplicitBindingCalls(Module &M);
+ void replaceResourceHandleCall(Module &M, CallInst *OldCI,
+ uint32_t NewBinding);
+ void replaceCounterHandleCall(Module &M, CallInst *OldCI,
+ uint32_t NewBinding);
void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls);
// A map from descriptor set to a bit vector of used binding numbers.
@@ -56,64 +60,93 @@ struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
: UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
}
+ void addBinding(uint32_t DescSet, uint32_t Binding) {
+ if (UsedBindings.size() <= DescSet) {
+ UsedBindings.resize(DescSet + 1);
+ UsedBindings[DescSet].resize(64);
+ }
+ if (UsedBindings[DescSet].size() <= Binding) {
+ UsedBindings[DescSet].resize(2 * Binding + 1);
+ }
+ UsedBindings[DescSet].set(Binding);
+ }
+
void visitCallInst(CallInst &CI) {
if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
const uint32_t DescSet =
cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
const uint32_t Binding =
cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
-
- if (UsedBindings.size() <= DescSet) {
- UsedBindings.resize(DescSet + 1);
- UsedBindings[DescSet].resize(64);
- }
- if (UsedBindings[DescSet].size() <= Binding) {
- UsedBindings[DescSet].resize(2 * Binding + 1);
- }
- UsedBindings[DescSet].set(Binding);
+ addBinding(DescSet, Binding);
} else if (CI.getIntrinsicID() ==
Intrinsic::spv_resource_handlefromimplicitbinding) {
ImplicitBindingCalls.push_back(&CI);
+ } else if (CI.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefrombinding) {
+ const uint32_t DescSet =
+ cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue();
+ const uint32_t Binding =
+ cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
+ addBinding(DescSet, Binding);
+ } else if (CI.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
+ ImplicitBindingCalls.push_back(&CI);
}
}
};
+static uint32_t getOrderId(const CallInst *CI) {
+ uint32_t OrderIdArgIdx = 0;
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::spv_resource_handlefromimplicitbinding:
+ OrderIdArgIdx = 0;
+ break;
+ case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
+ OrderIdArgIdx = 1;
+ break;
+ default:
+ llvm_unreachable("CallInst is not an implicit binding intrinsic");
+ }
+ return cast<ConstantInt>(CI->getArgOperand(OrderIdArgIdx))->getZExtValue();
+}
+
+static uint32_t getDescSet(const CallInst *CI) {
+ uint32_t DescSetArgIdx;
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::spv_resource_handlefromimplicitbinding:
+ case Intrinsic::spv_resource_handlefrombinding:
+ DescSetArgIdx = 1;
+ break;
+ case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
+ case Intrinsic::spv_resource_counterhandlefrombinding:
+ DescSetArgIdx = 2;
+ break;
+ default:
+ llvm_unreachable("CallInst is not an implicit binding intrinsic");
+ }
+ return cast<ConstantInt>(CI->getArgOperand(DescSetArgIdx))->getZExtValue();
+}
+
void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
InfoCollector.visit(M);
// Sort the collected calls by their order ID.
- std::sort(
- ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
- [](const CallInst *A, const CallInst *B) {
- const uint32_t OrderIdArgIdx = 0;
- const uint32_t OrderA =
- cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue();
- const uint32_t OrderB =
- cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue();
- return OrderA < OrderB;
- });
+ std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
+ [](const CallInst *A, const CallInst *B) {
+ return getOrderId(A) < getOrderId(B);
+ });
}
void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource(
SmallVectorImpl<CallInst *> &Calls) {
// Check that the order Id is unique per resource.
for (uint32_t i = 1; i < Calls.size(); ++i) {
- const uint32_t OrderIdArgIdx = 0;
- const uint32_t DescSetArgIdx = 1;
- const uint32_t OrderA =
- cast<ConstantInt>(Calls[i - 1]->getArgOperand(OrderIdArgIdx))
- ->getZExtValue();
- const uint32_t OrderB =
- cast<ConstantInt>(Calls[i]->getArgOperand(OrderIdArgIdx))
- ->getZExtValue();
+ const uint32_t OrderA = getOrderId(Calls[i - 1]);
+ const uint32_t OrderB = getOrderId(Calls[i]);
if (OrderA == OrderB) {
- const uint32_t DescSetA =
- cast<ConstantInt>(Calls[i - 1]->getArgOperand(DescSetArgIdx))
- ->getZExtValue();
- const uint32_t DescSetB =
- cast<ConstantInt>(Calls[i]->getArgOperand(DescSetArgIdx))
- ->getZExtValue();
+ const uint32_t DescSetA = getDescSet(Calls[i - 1]);
+ const uint32_t DescSetB = getDescSet(Calls[i]);
if (DescSetA != DescSetB) {
report_fatal_error("Implicit binding calls with the same order ID must "
"have the same descriptor set");
@@ -144,36 +177,26 @@ void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
uint32_t lastBindingNumber = -1;
for (CallInst *OldCI : ImplicitBindingCalls) {
- IRBuilder<> Builder(OldCI);
- const uint32_t OrderId =
- cast<ConstantInt>(OldCI->getArgOperand(0))->getZExtValue();
- const uint32_t DescSet =
- cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
-
- // Reuse an existing binding for this order ID, if one was already assigned.
- // Otherwise, assign a new binding.
- const uint32_t NewBinding = (lastOrderId == OrderId)
- ? lastBindingNumber
- : getAndReserveFirstUnusedBinding(DescSet);
- lastOrderId = OrderId;
- lastBindingNumber = NewBinding;
-
- SmallVector<Value *, 8> Args;
- Args.push_back(Builder.getInt32(DescSet));
- Args.push_back(Builder.getInt32(NewBinding));
-
- // Copy the remaining arguments from the old call.
- for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
- Args.push_back(OldCI->getArgOperand(i));
+ const uint32_t OrderId = getOrderId(OldCI);
+ uint32_t BindingNumber;
+ if (OrderId == lastOrderId) {
+ BindingNumber = lastBindingNumber;
+ } else {
+ const uint32_t DescSet = getDescSet(OldCI);
+ BindingNumber = getAndReserveFirstUnusedBinding(DescSet);
}
- Function *NewFunc = Intrinsic::getOrInsertDeclaration(
- &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
- CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
- NewCI->setCallingConv(OldCI->getCallingConv());
-
- OldCI->replaceAllUsesWith(NewCI);
- OldCI->eraseFromParent();
+ if (OldCI->getIntrinsicID() ==
+ Intrinsic::spv_resource_handlefromimplicitbinding) {
+ replaceResourceHandleCall(M, OldCI, BindingNumber);
+ } else {
+ assert(OldCI->getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
+ "Unexpected implicit binding intrinsic");
+ replaceCounterHandleCall(M, OldCI, BindingNumber);
+ }
+ lastOrderId = OrderId;
+ lastBindingNumber = BindingNumber;
}
}
@@ -196,4 +219,49 @@ INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",
ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
return new SPIRVLegalizeImplicitBinding();
-} \ No newline at end of file
+}
+
+void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall(
+ Module &M, CallInst *OldCI, uint32_t NewBinding) {
+ IRBuilder<> Builder(OldCI);
+ const uint32_t DescSet =
+ cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
+
+ SmallVector<Value *, 8> Args;
+ Args.push_back(Builder.getInt32(DescSet));
+ Args.push_back(Builder.getInt32(NewBinding));
+
+ // Copy the remaining arguments from the old call.
+ for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
+ Args.push_back(OldCI->getArgOperand(i));
+ }
+
+ Function *NewFunc = Intrinsic::getOrInsertDeclaration(
+ &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
+ CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
+ NewCI->setCallingConv(OldCI->getCallingConv());
+
+ OldCI->replaceAllUsesWith(NewCI);
+ OldCI->eraseFromParent();
+}
+
+void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall(
+ Module &M, CallInst *OldCI, uint32_t NewBinding) {
+ IRBuilder<> Builder(OldCI);
+ const uint32_t DescSet =
+ cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue();
+
+ SmallVector<Value *, 8> Args;
+ Args.push_back(OldCI->getArgOperand(0));
+ Args.push_back(Builder.getInt32(NewBinding));
+ Args.push_back(Builder.getInt32(DescSet));
+
+ Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()};
+ Function *NewFunc = Intrinsic::getOrInsertDeclaration(
+ &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys);
+ CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
+ NewCI->setCallingConv(OldCI->getCallingConv());
+
+ OldCI->replaceAllUsesWith(NewCI);
+ OldCI->eraseFromParent();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 327c011..1d47c89 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -385,6 +385,12 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}
+int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) {
+ const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
+ assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
+ return MI->getOperand(1).getCImm()->getSExtValue();
+}
+
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
return GI->is(IntrinsicID);
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 409a0fd..5777a24 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -289,6 +289,9 @@ MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
// Get constant integer value of the given ConstReg.
uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI);
+// Get constant integer value of the given ConstReg, sign-extended.
+int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI);
+
// Check if MI is a SPIR-V specific intrinsic call.
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
// Check if it's a SPIR-V specific intrinsic call.
diff --git a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
index 3090ad3..27fba34 100644
--- a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
+++ b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
@@ -407,6 +407,7 @@ bool X86InstructionSelector::select(MachineInstr &I) {
case TargetOpcode::G_TRUNC:
return selectTruncOrPtrToInt(I, MRI, MF);
case TargetOpcode::G_INTTOPTR:
+ case TargetOpcode::G_FREEZE:
return selectCopy(I, MRI);
case TargetOpcode::G_ZEXT:
return selectZext(I, MRI, MF);
diff --git a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
index e7709ef..11ef721 100644
--- a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
+++ b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
@@ -89,9 +89,29 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
// 32/64-bits needs support for s64/s128 to handle cases:
// s64 = EXTEND (G_IMPLICIT_DEF s32) -> s64 = G_IMPLICIT_DEF
// s128 = EXTEND (G_IMPLICIT_DEF s32/s64) -> s128 = G_IMPLICIT_DEF
- getActionDefinitionsBuilder(G_IMPLICIT_DEF)
+ getActionDefinitionsBuilder(
+ {G_IMPLICIT_DEF, G_PHI, G_FREEZE, G_CONSTANT_FOLD_BARRIER})
.legalFor({p0, s1, s8, s16, s32, s64})
- .legalFor(Is64Bit, {s128});
+ .legalFor(UseX87, {s80})
+ .legalFor(Is64Bit, {s128})
+ .legalFor(HasSSE2, {v16s8, v8s16, v4s32, v2s64})
+ .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64})
+ .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64})
+ .widenScalarOrEltToNextPow2(0, /*Min=*/8)
+ .clampScalarOrElt(0, s8, sMaxScalar)
+ .moreElementsToNextPow2(0)
+ .clampMinNumElements(0, s8, 16)
+ .clampMinNumElements(0, s16, 8)
+ .clampMinNumElements(0, s32, 4)
+ .clampMinNumElements(0, s64, 2)
+ .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16))
+ .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8))
+ .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4))
+ .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 4 : 2))
+ .clampMaxNumElements(0, p0,
+ Is64Bit ? s64MaxVector.getNumElements()
+ : s32MaxVector.getNumElements())
+ .scalarizeIf(scalarOrEltWiderThan(0, 64), 0);
getActionDefinitionsBuilder(G_CONSTANT)
.legalFor({p0, s8, s16, s32})
@@ -289,26 +309,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
.clampScalar(1, s16, sMaxScalar)
.scalarSameSizeAs(0, 1);
- // control flow
- getActionDefinitionsBuilder(G_PHI)
- .legalFor({s8, s16, s32, p0})
- .legalFor(UseX87, {s80})
- .legalFor(Is64Bit, {s64})
- .legalFor(HasSSE1, {v16s8, v8s16, v4s32, v2s64})
- .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64})
- .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64})
- .clampMinNumElements(0, s8, 16)
- .clampMinNumElements(0, s16, 8)
- .clampMinNumElements(0, s32, 4)
- .clampMinNumElements(0, s64, 2)
- .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16))
- .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8))
- .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4))
- .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 4 : 2))
- .widenScalarToNextPow2(0, /*Min=*/32)
- .clampScalar(0, s8, sMaxScalar)
- .scalarize(0);
-
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
// pointer handling
@@ -592,11 +592,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
.minScalar(0, LLT::scalar(32))
.libcall();
- getActionDefinitionsBuilder({G_FREEZE, G_CONSTANT_FOLD_BARRIER})
- .legalFor({s8, s16, s32, s64, p0})
- .widenScalarToNextPow2(0, /*Min=*/8)
- .clampScalar(0, s8, sMaxScalar);
-
getLegacyLegalizerInfo().computeTables();
verify(*STI.getInstrInfo());
}
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 564810c..83bd6ac 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -662,6 +662,7 @@ def VINSERTPSZrri : AVX512AIi8<0x21, MRMSrcReg, (outs VR128X:$dst),
"vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
[(set VR128X:$dst, (X86insertps VR128X:$src1, VR128X:$src2, timm:$src3))]>,
EVEX, VVVV, Sched<[SchedWriteFShuffle.XMM]>;
+let mayLoad = 1 in
def VINSERTPSZrmi : AVX512AIi8<0x21, MRMSrcMem, (outs VR128X:$dst),
(ins VR128X:$src1, f32mem:$src2, u8imm:$src3),
"vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
@@ -1293,6 +1294,7 @@ multiclass avx512_subvec_broadcast_rm<bits<8> opc, string OpcodeStr,
SDPatternOperator OpNode,
X86VectorVTInfo _Dst,
X86VectorVTInfo _Src> {
+ let hasSideEffects = 0, mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.MemOp:$src), OpcodeStr, "$src", "$src",
(_Dst.VT (OpNode addr:$src))>,
@@ -1748,6 +1750,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in {
(_.VT (X86VPermt2 _.RC:$src1, IdxVT.RC:$src2, _.RC:$src3)), 1>,
EVEX, VVVV, AVX5128IBase, Sched<[sched]>;
+ let hasSideEffects = 0, mayLoad = 1 in
defm rm: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins IdxVT.RC:$src2, _.MemOp:$src3),
OpcodeStr, "$src3, $src2", "$src2, $src3",
@@ -1759,7 +1762,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in {
multiclass avx512_perm_t_mb<bits<8> opc, string OpcodeStr,
X86FoldableSchedWrite sched,
X86VectorVTInfo _, X86VectorVTInfo IdxVT> {
- let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in
+ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain, hasSideEffects = 0, mayLoad = 1 in
defm rmb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins IdxVT.RC:$src2, _.ScalarMemOp:$src3),
OpcodeStr, !strconcat("${src3}", _.BroadcastStr,", $src2"),
@@ -1987,6 +1990,7 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
_.FRC:$src2,
timm:$cc))]>,
EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rmi : AVX512Ii8<0xC2, MRMSrcMem,
(outs _.KRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2, u8imm:$cc),
@@ -2145,6 +2149,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag,
(_.VT _.RC:$src2),
cond)))]>,
EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
def rmi : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc),
!strconcat("vpcmp", Suffix,
@@ -2167,6 +2172,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag,
(_.VT _.RC:$src2),
cond))))]>,
EVEX, VVVV, EVEX_K, Sched<[sched]>;
+ let mayLoad = 1 in
def rmik : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.KRCWM:$mask, _.RC:$src1, _.MemOp:$src2,
u8imm:$cc),
@@ -2198,6 +2204,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag,
PatFrag Frag_su, X86FoldableSchedWrite sched,
X86VectorVTInfo _, string Name> :
avx512_icmp_cc<opc, Suffix, Frag, Frag_su, sched, _, Name> {
+ let mayLoad = 1 in {
def rmbi : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2,
u8imm:$cc),
@@ -2221,6 +2228,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag,
(_.BroadcastLdFrag addr:$src2),
cond))))]>,
EVEX, VVVV, EVEX_K, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
def : Pat<(_.KVT (Frag:$cc (_.BroadcastLdFrag addr:$src2),
(_.VT _.RC:$src1), cond)),
@@ -2305,6 +2313,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
(X86cmpm_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc),
1>, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_cmp<0xC2, MRMSrcMem, _,
(outs _.KRC:$dst),(ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc),
"vcmp"#_.Suffix,
@@ -2329,6 +2338,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
timm:$cc)>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
// Patterns for selecting with loads in other operand.
def : Pat<(X86any_cmpm (_.LdFrag addr:$src2), (_.VT _.RC:$src1),
@@ -3771,6 +3781,7 @@ def VMOVDI2PDIZrr : AVX512BI<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src
[(set VR128X:$dst,
(v4i32 (scalar_to_vector GR32:$src)))]>,
EVEX, Sched<[WriteVecMoveFromGpr]>;
+let mayLoad = 1 in
def VMOVDI2PDIZrm : AVX512BI<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i32mem:$src),
"vmovd\t{$src, $dst|$dst, $src}",
[(set VR128X:$dst,
@@ -3874,7 +3885,7 @@ def VMOVSS2DIZrr : AVX512BI<0x7E, MRMDestReg, (outs GR32:$dst),
// Move Quadword Int to Packed Quadword Int
//
-let ExeDomain = SSEPackedInt in {
+let ExeDomain = SSEPackedInt, mayLoad = 1, hasSideEffects = 0 in {
def VMOVQI2PQIZrm : AVX512XSI<0x7E, MRMSrcMem, (outs VR128X:$dst),
(ins i64mem:$src),
"vmovq\t{$src, $dst|$dst, $src}",
@@ -3930,13 +3941,13 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
(_.VT (OpNode _.RC:$src1, _.RC:$src2)),
(_.VT _.RC:$src0))))],
_.ExeDomain>, EVEX, VVVV, EVEX_K, Sched<[SchedWriteFShuffle.XMM]>;
- let canFoldAsLoad = 1, isReMaterializable = 1 in {
+ let canFoldAsLoad = 1, isReMaterializable = 1, mayLoad = 1, hasSideEffects = 0 in {
def rm : AVX512PI<0x10, MRMSrcMem, (outs _.RC:$dst), (ins _.ScalarMemOp:$src),
!strconcat(asm, "\t{$src, $dst|$dst, $src}"),
[(set _.RC:$dst, (_.VT (vzload_frag addr:$src)))],
_.ExeDomain>, EVEX, Sched<[WriteFLoad]>;
// _alt version uses FR32/FR64 register class.
- let isCodeGenOnly = 1 in
+ let isCodeGenOnly = 1, mayLoad = 1, hasSideEffects = 0 in
def rm_alt : AVX512PI<0x10, MRMSrcMem, (outs _.FRC:$dst), (ins _.ScalarMemOp:$src),
!strconcat(asm, "\t{$src, $dst|$dst, $src}"),
[(set _.FRC:$dst, (_.ScalarLdFrag addr:$src))],
@@ -4557,6 +4568,7 @@ let Predicates = [HasAVX512] in {
// AVX-512 - Non-temporals
//===----------------------------------------------------------------------===//
+let mayLoad = 1, hasSideEffects = 0 in {
def VMOVNTDQAZrm : AVX512PI<0x2A, MRMSrcMem, (outs VR512:$dst),
(ins i512mem:$src), "vmovntdqa\t{$src, $dst|$dst, $src}",
[], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.ZMM.RM]>,
@@ -4575,11 +4587,12 @@ let Predicates = [HasVLX] in {
[], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.XMM.RM]>,
EVEX, T8, PD, EVEX_V128, EVEX_CD8<64, CD8VF>;
}
+}
multiclass avx512_movnt<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
X86SchedWriteMoveLS Sched,
PatFrag st_frag = alignednontemporalstore> {
- let SchedRW = [Sched.MR], AddedComplexity = 400 in
+ let mayStore = 1, SchedRW = [Sched.MR], AddedComplexity = 400 in
def mr : AVX512PI<opc, MRMDestMem, (outs), (ins _.MemOp:$dst, _.RC:$src),
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
[(st_frag (_.VT _.RC:$src), addr:$dst)],
@@ -4682,6 +4695,7 @@ multiclass avx512_binop_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
IsCommutable, IsCommutable>, AVX512BIBase, EVEX, VVVV,
Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -4694,6 +4708,7 @@ multiclass avx512_binop_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _, X86FoldableSchedWrite sched,
bit IsCommutable = 0> :
avx512_binop_rm<opc, OpcodeStr, OpNode, _, sched, IsCommutable> {
+ let mayLoad = 1, hasSideEffects = 0 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr,
"${src2}"#_.BroadcastStr#", $src1",
@@ -4811,6 +4826,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr,
(_Src.VT _Src.RC:$src2))),
IsCommutable>,
AVX512BIBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in {
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -4828,6 +4844,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr,
(_Brdct.VT (_Brdct.BroadcastLdFrag addr:$src2)))))>,
AVX512BIBase, EVEX, VVVV, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
defm VPADD : avx512_binop_rm_vl_all<0xFC, 0xFD, 0xFE, 0xD4, "vpadd", add,
@@ -4893,6 +4910,7 @@ defm VPMULTISHIFTQB : avx512_binop_all<0x83, "vpmultishiftqb", SchedWriteVecALU,
multiclass avx512_packs_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _Src, X86VectorVTInfo _Dst,
X86FoldableSchedWrite sched> {
+ let mayLoad = 1, hasSideEffects = 0 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.ScalarMemOp:$src2),
OpcodeStr,
@@ -4916,6 +4934,7 @@ multiclass avx512_packs_rm<bits<8> opc, string OpcodeStr,
(_Src.VT _Src.RC:$src2))),
IsCommutable, IsCommutable>,
EVEX_CD8<_Src.EltSize, CD8VF>, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5370,6 +5389,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5384,6 +5404,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5414,6 +5435,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">,
Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5430,6 +5452,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5509,6 +5532,7 @@ multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
Sched<[sched]> {
let isCommutable = 1;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5737,6 +5761,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, _.RC:$src2))>,
EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm: AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr#_.Suffix,
"$src2, $src1", "$src1, $src2",
@@ -5749,6 +5774,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode,
(OpNode _.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2)))>,
EVEX, VVVV, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode,
@@ -5759,6 +5785,7 @@ multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, _.RC:$src2))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm: AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr#_.Suffix,
"$src2, $src1", "$src1, $src2",
@@ -5916,6 +5943,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (i8 timm:$src2)))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm mi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst),
(ins _.MemOp:$src1, u8imm:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5928,7 +5956,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
multiclass avx512_shift_rmbi<bits<8> opc, Format ImmFormM,
string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> {
- let ExeDomain = _.ExeDomain in
+ let ExeDomain = _.ExeDomain, mayLoad = 1 in
defm mbi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src1, u8imm:$src2), OpcodeStr,
"$src2, ${src1}"#_.BroadcastStr, "${src1}"#_.BroadcastStr#", $src2",
@@ -5946,6 +5974,7 @@ multiclass avx512_shift_rrm<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (SrcVT VR128X:$src2)))>,
AVX512BIBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, i128mem:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6095,6 +6124,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (_.VT _.RC:$src2)))>,
AVX5128IBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6107,7 +6137,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode,
multiclass avx512_var_shift_mb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> {
- let ExeDomain = _.ExeDomain in
+ let ExeDomain = _.ExeDomain, mayLoad = 1 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr,
"${src2}"#_.BroadcastStr#", $src1",
@@ -6372,6 +6402,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode,
(_.VT (OpNode _.RC:$src1,
(Ctrl.VT Ctrl.RC:$src2)))>,
T8, PD, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm: AVX512_maskable<OpcVar, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, Ctrl.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6389,6 +6420,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode,
(Ctrl.VT (Ctrl.BroadcastLdFrag addr:$src2))))>,
T8, PD, EVEX, VVVV, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
multiclass avx512_permil_vec_common<string OpcodeStr, bits<8> OpcVar,
@@ -7258,6 +7290,7 @@ let ExeDomain = DstVT.ExeDomain, Uses = _Uses,
(OpNode (DstVT.VT DstVT.RC:$src1), SrcRC:$src2))]>,
EVEX, VVVV, Sched<[sched, ReadDefault, ReadInt2Fpu]>;
+ let mayLoad = 1 in
def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst),
(ins DstVT.RC:$src1, x86memop:$src2),
asm#"{"#mem#"}\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -7400,6 +7433,7 @@ multiclass avx512_cvt_s_int_round<bits<8> opc, X86VectorVTInfo SrcVT,
[(set DstVT.RC:$dst, (OpNodeRnd (SrcVT.VT SrcVT.RC:$src),(i32 timm:$rc)))]>,
EVEX, VEX_LIG, EVEX_B, EVEX_RC,
Sched<[sched]>;
+ let mayLoad = 1 in
def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode
@@ -7451,6 +7485,7 @@ multiclass avx512_cvt_s<bits<8> opc, string asm, X86VectorVTInfo SrcVT,
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode SrcVT.FRC:$src))]>,
EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rm : AVX512<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.ScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode (SrcVT.ScalarLdFrag addr:$src)))]>,
@@ -7572,6 +7607,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set _DstRC.RC:$dst, (OpNode _SrcRC.FRC:$src))]>,
EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rm : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.ScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set _DstRC.RC:$dst, (OpNode (_SrcRC.ScalarLdFrag addr:$src)))]>,
@@ -7587,6 +7623,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
!strconcat(asm,"\t{{sae}, $src, $dst|$dst, $src, {sae}}"),
[(set _DstRC.RC:$dst, (OpNodeSAE (_SrcRC.VT _SrcRC.RC:$src)))]>,
EVEX, VEX_LIG, EVEX_B, Sched<[sched]>;
+ let mayLoad = 1 in
def rm_Int : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst),
(ins _SrcRC.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
@@ -7644,6 +7681,7 @@ multiclass avx512_cvt_fp_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInfo _
(_.VT (OpNode (_.VT _.RC:$src1),
(_Src.VT _Src.RC:$src2))), "_Int">,
EVEX, VVVV, VEX_LIG, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _Src.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -7807,6 +7845,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
_.ImmAllZerosV)>,
EVEX, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm : AVX512_maskable_cvt<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins MemOp:$src),
(ins _.RC:$src0, MaskRC:$mask, MemOp:$src),
@@ -7840,6 +7879,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
_.ImmAllZerosV)>,
EVEX, EVEX_B, Sched<[sched.Folded]>;
}
+ }
}
// Conversion with SAE - suppress all exceptions
multiclass avx512_vcvt_fp_sae<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
@@ -8944,6 +8984,7 @@ multiclass avx512_cvtph2ps<X86VectorVTInfo _dest, X86VectorVTInfo _src,
(X86any_cvtph2ps (_src.VT _src.RC:$src)),
(X86cvtph2ps (_src.VT _src.RC:$src))>,
T8, PD, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_split<0x13, MRMSrcMem, _dest, (outs _dest.RC:$dst),
(ins x86memop:$src), "vcvtph2ps", "$src", "$src",
(X86any_cvtph2ps (_src.VT ld_dag)),
@@ -9161,6 +9202,7 @@ multiclass avx512_fp14_s<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2))>,
EVEX, VVVV, VEX_LIG, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -9621,6 +9663,7 @@ multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr,
(i32 timm:$src3))), "_Int">, EVEX_B,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3),
OpcodeStr,
@@ -9999,6 +10042,7 @@ multiclass avx512_pmovx_common<bits<8> opc, string OpcodeStr, X86FoldableSchedWr
(DestInfo.VT (OpNode (SrcInfo.VT SrcInfo.RC:$src)))>,
EVEX, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst),
(ins x86memop:$src), OpcodeStr ,"$src", "$src",
(DestInfo.VT (LdFrag addr:$src))>,
@@ -10601,6 +10645,7 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _,
(null_frag)>, AVX5128IBase,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1), OpcodeStr, "$src1", "$src1",
(null_frag)>,
@@ -10673,6 +10718,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr,
(OpNode (_.VT _.RC:$src1), (i32 timm:$src2)),
(MaskOpNode (_.VT _.RC:$src1), (i32 timm:$src2))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_split<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1, i32u8imm:$src2),
OpcodeStr#_.Suffix, "$src2, $src1", "$src1, $src2",
@@ -10691,6 +10737,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr,
(i32 timm:$src2))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
//handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm),{sae}
@@ -10739,6 +10786,7 @@ multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src2),
(i32 timm:$src3))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, i32u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10755,6 +10803,7 @@ multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(i32 timm:$src3))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
//handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm)
@@ -10770,6 +10819,7 @@ multiclass avx512_3Op_rm_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode,
(SrcInfo.VT SrcInfo.RC:$src2),
(i8 timm:$src3)))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst),
(ins SrcInfo.RC:$src1, SrcInfo.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10788,7 +10838,7 @@ multiclass avx512_3Op_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _>:
avx512_3Op_rm_imm8<opc, OpcodeStr, OpNode, sched, _, _>{
- let ExeDomain = _.ExeDomain, ImmT = Imm8 in
+ let ExeDomain = _.ExeDomain, ImmT = Imm8, mayLoad = 1 in
defm rmbi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, ${src2}"#_.BroadcastStr#", $src1",
@@ -10811,6 +10861,7 @@ multiclass avx512_fp_scalar_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src2),
(i32 timm:$src3))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10979,6 +11030,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr,
(CastInfo.VT (X86Shuf128 _.RC:$src1, _.RC:$src2,
(i8 timm:$src3)))))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -11000,6 +11052,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr,
(i8 timm:$src3)))))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_shuff_packed_128<string OpcodeStr, X86FoldableSchedWrite sched,
@@ -11031,6 +11084,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr,
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
(_.VT (X86VAlign _.RC:$src1, _.RC:$src2, (i8 timm:$src3)))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -11048,6 +11102,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr,
(i8 timm:$src3))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_valign_common<string OpcodeStr, X86SchedWriteWidths sched,
@@ -11202,6 +11257,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT (OpNode (_.VT _.RC:$src1)))>, EVEX, AVX5128IBase,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1), OpcodeStr,
"$src1", "$src1",
@@ -11214,6 +11270,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
multiclass avx512_unary_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> :
avx512_unary_rm<opc, OpcodeStr, OpNode, sched, _> {
+ let mayLoad = 1 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src1), OpcodeStr,
"${src1}"#_.BroadcastStr,
@@ -11368,6 +11425,7 @@ multiclass avx512_movddup_128<bits<8> opc, string OpcodeStr,
(ins _.RC:$src), OpcodeStr, "$src", "$src",
(_.VT (X86VBroadcast (_.VT _.RC:$src)))>, EVEX,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src), OpcodeStr, "$src", "$src",
(_.VT (_.BroadcastLdFrag addr:$src))>,
@@ -11513,6 +11571,7 @@ defm VPEXTRQZ : avx512_extract_elt_dq<"vpextrq", v2i64x_info, GR64>, REX_W;
multiclass avx512_insert_elt_m<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _, PatFrag LdFrag,
SDPatternOperator immoperator> {
+ let mayLoad = 1 in
def rmi : AVX512Ii8<opc, MRMSrcMem, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3),
OpcodeStr#"\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
@@ -11650,6 +11709,7 @@ multiclass avx512_psadbw_packed<bits<8> opc, SDNode OpNode,
(OpNode (_src.VT _src.RC:$src1),
(_src.VT _src.RC:$src2))))]>,
Sched<[sched]>;
+ let mayLoad = 1 in
def rm : AVX512BI<opc, MRMSrcMem,
(outs _dst.RC:$dst), (ins _src.RC:$src1, _src.MemOp:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
@@ -11751,6 +11811,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src3),
(i8 timm:$src4)), 1, 1>,
AVX512AIi8Base, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.MemOp:$src3, u8imm:$src4),
OpcodeStr, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -11770,6 +11831,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
(i8 timm:$src4)), 1, 0>, EVEX_B,
AVX512AIi8Base, EVEX, VVVV, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}// Constraints = "$src1 = $dst"
// Additional patterns for matching passthru operand in other positions.
@@ -12016,6 +12078,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr,
(_.VT _.RC:$src2),
(TblVT.VT _.RC:$src3),
(i32 timm:$src4))>, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.MemOp:$src3, i32u8imm:$src4),
OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -12033,6 +12096,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr,
(TblVT.VT (TblVT.BroadcastLdFrag addr:$src3)),
(i32 timm:$src4))>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
} // Constraints = "$src1 = $dst"
}
@@ -12075,6 +12139,7 @@ multiclass avx512_fixupimm_scalar<bits<8> opc, string OpcodeStr,
(_src3VT.VT _src3VT.RC:$src3),
(i32 timm:$src4))>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_3src_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.ScalarMemOp:$src3, i32u8imm:$src4),
OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -12417,6 +12482,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode,
VTI.RC:$src2, VTI.RC:$src3)),
IsCommutable, IsCommutable>,
EVEX, VVVV, T8, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm : AVX512_maskable_3src<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst),
(ins VTI.RC:$src2, VTI.MemOp:$src3), OpStr,
"$src3, $src2", "$src2, $src3",
@@ -12435,6 +12501,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode,
T8, Sched<[sched.Folded, sched.ReadAfterFold,
sched.ReadAfterFold]>;
}
+ }
}
multiclass VNNI_common<bits<8> Op, string OpStr, SDNode OpNode,
@@ -12508,6 +12575,7 @@ multiclass VPSHUFBITQMB_rm<X86FoldableSchedWrite sched, X86VectorVTInfo VTI> {
(X86Vpshufbitqmb_su (VTI.VT VTI.RC:$src1),
(VTI.VT VTI.RC:$src2))>, EVEX, VVVV, T8, PD,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_cmp<0x8F, MRMSrcMem, VTI, (outs VTI.KRC:$dst),
(ins VTI.RC:$src1, VTI.MemOp:$src2),
"vpshufbitqmb",
@@ -12557,7 +12625,7 @@ multiclass GF2P8AFFINE_avx512_rmb_imm<bits<8> Op, string OpStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo VTI,
X86VectorVTInfo BcstVTI>
: avx512_3Op_rm_imm8<Op, OpStr, OpNode, sched, VTI, VTI> {
- let ExeDomain = VTI.ExeDomain in
+ let ExeDomain = VTI.ExeDomain, mayLoad = 1 in
defm rmbi : AVX512_maskable<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst),
(ins VTI.RC:$src1, BcstVTI.ScalarMemOp:$src2, u8imm:$src3),
OpStr, "$src3, ${src2}"#BcstVTI.BroadcastStr#", $src1",
@@ -12660,6 +12728,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf
_.RC:$src1, (_.VT _.RC:$src2)))]>,
EVEX, VVVV, T8, XD, Sched<[sched]>;
+ let mayLoad = 1 in {
def rm : I<0x68, MRMSrcMem,
(outs _.KRPC:$dst),
(ins _.RC:$src1, _.MemOp:$src2),
@@ -12679,6 +12748,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf
_.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2))))]>,
EVEX, VVVV, T8, XD, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
multiclass avx512_vp2intersect<X86SchedWriteWidths sched, AVX512VLVectorVTInfo _> {
@@ -12882,6 +12952,7 @@ let Predicates = [HasFP16] in {
// Move word ( r/m16) to Packed word
def VMOVW2SHrr : AVX512<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src),
"vmovw\t{$src, $dst|$dst, $src}", []>, T_MAP5, PD, EVEX, Sched<[WriteVecMoveFromGpr]>;
+let mayLoad = 1 in
def VMOVWrm : AVX512<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i16mem:$src),
"vmovw\t{$src, $dst|$dst, $src}",
[(set VR128X:$dst,
@@ -13607,6 +13678,7 @@ multiclass avx512_cfmbinop_sh_common<bits<8> opc, string OpcodeStr, SDNode OpNod
(v4f32 (OpNode VR128X:$src1, VR128X:$src2)),
IsCommutable, IsCommutable, IsCommutable,
X86selects, "@earlyclobber $dst">, Sched<[WriteFMAX]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, f32x_info, (outs VR128X:$dst),
(ins VR128X:$src1, ssmem:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp
index 34b09b1..b906690 100644
--- a/llvm/lib/TargetParser/TargetParser.cpp
+++ b/llvm/lib/TargetParser/TargetParser.cpp
@@ -444,6 +444,7 @@ static void fillAMDGCNFeatureMap(StringRef GPU, const Triple &T,
Features["atomic-fmin-fmax-global-f32"] = true;
Features["atomic-fmin-fmax-global-f64"] = true;
Features["wavefrontsize32"] = true;
+ Features["cluster"] = true;
break;
case GK_GFX1201:
case GK_GFX1200:
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index af216cd..9693ae6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -317,24 +317,29 @@ static Value *simplifyInstruction(SCCPSolver &Solver,
// Early exit if we know nothing about X.
if (LRange.isFullSet())
return nullptr;
- // We are allowed to refine the comparison to either true or false for out
- // of range inputs. Here we refine the comparison to true, i.e. we relax
- // the range check.
- auto NewCR = CR->exactUnionWith(LRange.inverse());
- // TODO: Check if we can narrow the range check to an equality test.
- // E.g, for X in [0, 4), X - 3 u< 2 -> X == 3
- if (!NewCR)
+ auto ConvertCRToICmp =
+ [&](const std::optional<ConstantRange> &NewCR) -> Value * {
+ ICmpInst::Predicate Pred;
+ APInt RHS;
+ // Check if we can represent NewCR as an icmp predicate.
+ if (NewCR && NewCR->getEquivalentICmp(Pred, RHS)) {
+ IRBuilder<NoFolder> Builder(&Inst);
+ Value *NewICmp =
+ Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
+ InsertedValues.insert(NewICmp);
+ return NewICmp;
+ }
return nullptr;
- ICmpInst::Predicate Pred;
- APInt RHS;
- // Check if we can represent NewCR as an icmp predicate.
- if (NewCR->getEquivalentICmp(Pred, RHS)) {
- IRBuilder<NoFolder> Builder(&Inst);
- Value *NewICmp =
- Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
- InsertedValues.insert(NewICmp);
- return NewICmp;
- }
+ };
+ // We are allowed to refine the comparison to either true or false for out
+ // of range inputs.
+ // Here we refine the comparison to false, and check if we can narrow the
+ // range check to a simpler test.
+ if (auto *V = ConvertCRToICmp(CR->exactIntersectWith(LRange)))
+ return V;
+ // Here we refine the comparison to true, i.e. we relax the range check.
+ if (auto *V = ConvertCRToICmp(CR->exactUnionWith(LRange.inverse())))
+ return V;
}
}
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 48055ad..148bfa8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5734,15 +5734,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}
-static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
+struct ContiguousCasesResult {
+ ConstantInt *Min;
+ ConstantInt *Max;
+ BasicBlock *Dest;
+ BasicBlock *OtherDest;
+ SmallVectorImpl<ConstantInt *> *Cases;
+ SmallVectorImpl<ConstantInt *> *OtherCases;
+};
+
+static std::optional<ContiguousCasesResult>
+findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
+ SmallVectorImpl<ConstantInt *> &OtherCases,
+ BasicBlock *Dest, BasicBlock *OtherDest) {
assert(Cases.size() >= 1);
array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
- for (size_t I = 1, E = Cases.size(); I != E; ++I) {
- if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
- return false;
+ const APInt &Min = Cases.back()->getValue();
+ const APInt &Max = Cases.front()->getValue();
+ APInt Offset = Max - Min;
+ size_t ContiguousOffset = Cases.size() - 1;
+ if (Offset == ContiguousOffset) {
+ return ContiguousCasesResult{
+ /*Min=*/Cases.back(),
+ /*Max=*/Cases.front(),
+ /*Dest=*/Dest,
+ /*OtherDest=*/OtherDest,
+ /*Cases=*/&Cases,
+ /*OtherCases=*/&OtherCases,
+ };
}
- return true;
+ ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
+ // If this is a wrapping contiguous range, that is, [Min, OtherMin] +
+ // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
+ // contiguous range for the other destination. N.B. If CR is not a full range,
+ // Max+1 is not equal to Min. It's not continuous in arithmetic.
+ if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
+ assert(Cases.size() >= 2);
+ auto *It =
+ std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
+ return L->getValue() != R->getValue() + 1;
+ });
+ if (It == Cases.end())
+ return std::nullopt;
+ auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
+ if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
+ Cases.size() - 2) {
+ return ContiguousCasesResult{
+ /*Min=*/cast<ConstantInt>(
+ ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
+ /*Max=*/
+ cast<ConstantInt>(
+ ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
+ /*Dest=*/OtherDest,
+ /*OtherDest=*/Dest,
+ /*Cases=*/&OtherCases,
+ /*OtherCases=*/&Cases,
+ };
+ }
+ }
+ return std::nullopt;
}
static void createUnreachableSwitchDefault(SwitchInst *Switch,
@@ -5779,7 +5830,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
bool HasDefault = !SI->defaultDestUnreachable();
auto *BB = SI->getParent();
-
// Partition the cases into two sets with different destinations.
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
BasicBlock *DestB = nullptr;
@@ -5813,37 +5863,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);
// Figure out if one of the sets of cases form a contiguous range.
- SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
- BasicBlock *ContiguousDest = nullptr;
- BasicBlock *OtherDest = nullptr;
- if (!CasesA.empty() && casesAreContiguous(CasesA)) {
- ContiguousCases = &CasesA;
- ContiguousDest = DestA;
- OtherDest = DestB;
- } else if (casesAreContiguous(CasesB)) {
- ContiguousCases = &CasesB;
- ContiguousDest = DestB;
- OtherDest = DestA;
- } else
- return false;
+ std::optional<ContiguousCasesResult> ContiguousCases;
+
+ // Only one icmp is needed when there is only one case.
+ if (!HasDefault && CasesA.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesA[0],
+ /*Max=*/CasesA[0],
+ /*Dest=*/DestA,
+ /*OtherDest=*/DestB,
+ /*Cases=*/&CasesA,
+ /*OtherCases=*/&CasesB,
+ };
+ else if (CasesB.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesB[0],
+ /*Max=*/CasesB[0],
+ /*Dest=*/DestB,
+ /*OtherDest=*/DestA,
+ /*Cases=*/&CasesB,
+ /*OtherCases=*/&CasesA,
+ };
+ // Correctness: Cases to the default destination cannot be contiguous cases.
+ else if (!HasDefault)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);
- // Start building the compare and branch.
+ if (!ContiguousCases)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);
- Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
- Constant *NumCases =
- ConstantInt::get(Offset->getType(), ContiguousCases->size());
+ if (!ContiguousCases)
+ return false;
- Value *Sub = SI->getCondition();
- if (!Offset->isNullValue())
- Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
- Value *Cmp;
+ // Start building the compare and branch.
+
+ Constant *Offset = ConstantExpr::getNeg(Min);
+ Constant *NumCases = ConstantInt::get(Offset->getType(),
+ Max->getValue() - Min->getValue() + 1);
+ BranchInst *NewBI;
+ if (NumCases->isOneValue()) {
+ assert(Max->getValue() == Min->getValue());
+ Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// If NumCases overflowed, then all possible values jump to the successor.
- if (NumCases->isNullValue() && !ContiguousCases->empty())
- Cmp = ConstantInt::getTrue(SI->getContext());
- else
- Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
- BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
+ else if (NumCases->isNullValue() && !Cases->empty()) {
+ NewBI = Builder.CreateBr(Dest);
+ } else {
+ Value *Sub = SI->getCondition();
+ if (!Offset->isNullValue())
+ Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
@@ -5853,7 +5928,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
uint64_t TrueWeight = 0;
uint64_t FalseWeight = 0;
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
- if (SI->getSuccessor(I) == ContiguousDest)
+ if (SI->getSuccessor(I) == Dest)
TrueWeight += Weights[I];
else
FalseWeight += Weights[I];
@@ -5868,15 +5943,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
}
// Prune obsolete incoming values off the successors' PHI nodes.
- for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = ContiguousCases->size();
- if (ContiguousDest == SI->getDefaultDest())
+ for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
+ unsigned PreviousEdges = Cases->size();
+ if (Dest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
+ unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 56a3d6d..e434e73 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8201,211 +8201,6 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
}
}
-/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the
-/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute
-/// the end value of the induction.
-static VPInstruction *addResumePhiRecipeForInduction(
- VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder,
- VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) {
- auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV);
- // Truncated wide inductions resume from the last lane of their vector value
- // in the last vector iteration which is handled elsewhere.
- if (WideIntOrFp && WideIntOrFp->getTruncInst())
- return nullptr;
-
- VPValue *Start = WideIV->getStartValue();
- VPValue *Step = WideIV->getStepValue();
- const InductionDescriptor &ID = WideIV->getInductionDescriptor();
- VPValue *EndValue = VectorTC;
- if (!WideIntOrFp || !WideIntOrFp->isCanonical()) {
- EndValue = VectorPHBuilder.createDerivedIV(
- ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()),
- Start, VectorTC, Step);
- }
-
- // EndValue is derived from the vector trip count (which has the same type as
- // the widest induction) and thus may be wider than the induction here.
- Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV);
- if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) {
- EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue,
- ScalarTypeOfWideIV,
- WideIV->getDebugLoc());
- }
-
- auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi(
- {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val");
- return ResumePhiRecipe;
-}
-
-/// Create resume phis in the scalar preheader for first-order recurrences,
-/// reductions and inductions, and update the VPIRInstructions wrapping the
-/// original phis in the scalar header. End values for inductions are added to
-/// \p IVEndValues.
-static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
- DenseMap<VPValue *, VPValue *> &IVEndValues) {
- VPTypeAnalysis TypeInfo(Plan);
- auto *ScalarPH = Plan.getScalarPreheader();
- auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
- VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
- VPBuilder VectorPHBuilder(
- cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
- VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
- VPBuilder ScalarPHBuilder(ScalarPH);
- for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
- auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR);
-
- // TODO: Extract final value from induction recipe initially, optimize to
- // pre-computed end value together in optimizeInductionExitUsers.
- auto *VectorPhiR =
- cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
- if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) {
- if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction(
- WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo,
- &Plan.getVectorTripCount())) {
- assert(isa<VPPhi>(ResumePhi) && "Expected a phi");
- IVEndValues[WideIVR] = ResumePhi->getOperand(0);
- ScalarPhiIRI->addOperand(ResumePhi);
- continue;
- }
- // TODO: Also handle truncated inductions here. Computing end-values
- // separately should be done as VPlan-to-VPlan optimization, after
- // legalizing all resume values to use the last lane from the loop.
- assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() &&
- "should only skip truncated wide inductions");
- continue;
- }
-
- // The backedge value provides the value to resume coming out of a loop,
- // which for FORs is a vector whose last element needs to be extracted. The
- // start value provides the value if the loop is bypassed.
- bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
- auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
- assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
- "Cannot handle loops with uncountable early exits");
- if (IsFOR)
- ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
- VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {},
- "vector.recur.extract");
- StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
- auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
- {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
- ScalarPhiIRI->addOperand(ResumePhiR);
- }
-}
-
-/// Handle users in the exit block for first order reductions in the original
-/// exit block. The penultimate value of recurrences is fed to their LCSSA phi
-/// users in the original exit block using the VPIRInstruction wrapping to the
-/// LCSSA phi.
-static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
- VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
- auto *ScalarPHVPBB = Plan.getScalarPreheader();
- auto *MiddleVPBB = Plan.getMiddleBlock();
- VPBuilder ScalarPHBuilder(ScalarPHVPBB);
- VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
-
- auto IsScalableOne = [](ElementCount VF) -> bool {
- return VF == ElementCount::getScalable(1);
- };
-
- for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) {
- auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi);
- if (!FOR)
- continue;
-
- assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
- "Cannot handle loops with uncountable early exits");
-
- // This is the second phase of vectorizing first-order recurrences, creating
- // extract for users outside the loop. An overview of the transformation is
- // described below. Suppose we have the following loop with some use after
- // the loop of the last a[i-1],
- //
- // for (int i = 0; i < n; ++i) {
- // t = a[i - 1];
- // b[i] = a[i] - t;
- // }
- // use t;
- //
- // There is a first-order recurrence on "a". For this loop, the shorthand
- // scalar IR looks like:
- //
- // scalar.ph:
- // s.init = a[-1]
- // br scalar.body
- //
- // scalar.body:
- // i = phi [0, scalar.ph], [i+1, scalar.body]
- // s1 = phi [s.init, scalar.ph], [s2, scalar.body]
- // s2 = a[i]
- // b[i] = s2 - s1
- // br cond, scalar.body, exit.block
- //
- // exit.block:
- // use = lcssa.phi [s1, scalar.body]
- //
- // In this example, s1 is a recurrence because it's value depends on the
- // previous iteration. In the first phase of vectorization, we created a
- // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts
- // for users in the scalar preheader and exit block.
- //
- // vector.ph:
- // v_init = vector(..., ..., ..., a[-1])
- // br vector.body
- //
- // vector.body
- // i = phi [0, vector.ph], [i+4, vector.body]
- // v1 = phi [v_init, vector.ph], [v2, vector.body]
- // v2 = a[i, i+1, i+2, i+3]
- // b[i] = v2 - v1
- // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2))
- // b[i, i+1, i+2, i+3] = v2 - v1
- // br cond, vector.body, middle.block
- //
- // middle.block:
- // vector.recur.extract.for.phi = v2(2)
- // vector.recur.extract = v2(3)
- // br cond, scalar.ph, exit.block
- //
- // scalar.ph:
- // scalar.recur.init = phi [vector.recur.extract, middle.block],
- // [s.init, otherwise]
- // br scalar.body
- //
- // scalar.body:
- // i = phi [0, scalar.ph], [i+1, scalar.body]
- // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body]
- // s2 = a[i]
- // b[i] = s2 - s1
- // br cond, scalar.body, exit.block
- //
- // exit.block:
- // lo = lcssa.phi [s1, scalar.body],
- // [vector.recur.extract.for.phi, middle.block]
- //
- // Now update VPIRInstructions modeling LCSSA phis in the exit block.
- // Extract the penultimate value of the recurrence and use it as operand for
- // the VPIRInstruction modeling the phi.
- for (VPUser *U : FOR->users()) {
- using namespace llvm::VPlanPatternMatch;
- if (!match(U, m_ExtractLastElement(m_Specific(FOR))))
- continue;
- // For VF vscale x 1, if vscale = 1, we are unable to extract the
- // penultimate value of the recurrence. Instead we rely on the existing
- // extract of the last element from the result of
- // VPInstruction::FirstOrderRecurrenceSplice.
- // TODO: Consider vscale_range info and UF.
- if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne,
- Range))
- return;
- VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
- VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
- {}, "vector.recur.extract.for.phi");
- cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
- }
- }
-}
-
VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
VPlanPtr Plan, VFRange &Range, LoopVersioning *LVer) {
@@ -8598,9 +8393,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
R->setOperand(1, WideIV->getStepValue());
}
- addExitUsersForFirstOrderRecurrences(*Plan, Range);
+ VPlanTransforms::runPass(
+ VPlanTransforms::addExitUsersForFirstOrderRecurrences, *Plan, Range);
DenseMap<VPValue *, VPValue *> IVEndValues;
- addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);
+ VPlanTransforms::runPass(VPlanTransforms::addScalarResumePhis, *Plan,
+ RecipeBuilder, IVEndValues);
// ---------------------------------------------------------------------------
// Transform initial VPlan: Apply previously taken decisions, in order, to
@@ -8711,7 +8508,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
DenseMap<VPValue *, VPValue *> IVEndValues;
// TODO: IVEndValues are not used yet in the native path, to optimize exit
// values.
- addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);
+ VPlanTransforms::runPass(VPlanTransforms::addScalarResumePhis, *Plan,
+ RecipeBuilder, IVEndValues);
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
return Plan;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index fedca65..91c3d42 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10620,7 +10620,8 @@ class InstructionsCompatibilityAnalysis {
/// Checks if the opcode is supported as the main opcode for copyable
/// elements.
static bool isSupportedOpcode(const unsigned Opcode) {
- return Opcode == Instruction::Add || Opcode == Instruction::LShr;
+ return Opcode == Instruction::Add || Opcode == Instruction::LShr ||
+ Opcode == Instruction::Shl;
}
/// Identifies the best candidate value, which represents main opcode
@@ -10937,6 +10938,7 @@ public:
switch (MainOpcode) {
case Instruction::Add:
case Instruction::LShr:
+ case Instruction::Shl:
VectorCost = TTI.getArithmeticInstrCost(MainOpcode, VecTy, Kind);
break;
default:
@@ -22006,6 +22008,8 @@ bool BoUpSLP::collectValuesToDemote(
return all_of(E.Scalars, [&](Value *V) {
if (isa<PoisonValue>(V))
return true;
+ if (E.isCopyableElement(V))
+ return true;
auto *I = cast<Instruction>(V);
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
return AmtKnownBits.getMaxValue().ult(BitWidth);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ca63bf3..ebf833e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -4198,3 +4198,202 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator(
MDB.createBranchWeights({1, VectorStep - 1}, /*IsExpected=*/false);
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
}
+
+/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the
+/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute
+/// the end value of the induction.
+static VPInstruction *addResumePhiRecipeForInduction(
+ VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder,
+ VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) {
+ auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV);
+ // Truncated wide inductions resume from the last lane of their vector value
+ // in the last vector iteration which is handled elsewhere.
+ if (WideIntOrFp && WideIntOrFp->getTruncInst())
+ return nullptr;
+
+ VPValue *Start = WideIV->getStartValue();
+ VPValue *Step = WideIV->getStepValue();
+ const InductionDescriptor &ID = WideIV->getInductionDescriptor();
+ VPValue *EndValue = VectorTC;
+ if (!WideIntOrFp || !WideIntOrFp->isCanonical()) {
+ EndValue = VectorPHBuilder.createDerivedIV(
+ ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()),
+ Start, VectorTC, Step);
+ }
+
+ // EndValue is derived from the vector trip count (which has the same type as
+ // the widest induction) and thus may be wider than the induction here.
+ Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV);
+ if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) {
+ EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue,
+ ScalarTypeOfWideIV,
+ WideIV->getDebugLoc());
+ }
+
+ auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi(
+ {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val");
+ return ResumePhiRecipe;
+}
+
+void VPlanTransforms::addScalarResumePhis(
+ VPlan &Plan, VPRecipeBuilder &Builder,
+ DenseMap<VPValue *, VPValue *> &IVEndValues) {
+ VPTypeAnalysis TypeInfo(Plan);
+ auto *ScalarPH = Plan.getScalarPreheader();
+ auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
+ VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
+ VPBuilder VectorPHBuilder(
+ cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
+ VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
+ VPBuilder ScalarPHBuilder(ScalarPH);
+ for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
+ auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR);
+
+ // TODO: Extract final value from induction recipe initially, optimize to
+ // pre-computed end value together in optimizeInductionExitUsers.
+ auto *VectorPhiR =
+ cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
+ if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) {
+ if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction(
+ WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo,
+ &Plan.getVectorTripCount())) {
+ assert(isa<VPPhi>(ResumePhi) && "Expected a phi");
+ IVEndValues[WideIVR] = ResumePhi->getOperand(0);
+ ScalarPhiIRI->addOperand(ResumePhi);
+ continue;
+ }
+ // TODO: Also handle truncated inductions here. Computing end-values
+ // separately should be done as VPlan-to-VPlan optimization, after
+ // legalizing all resume values to use the last lane from the loop.
+ assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() &&
+ "should only skip truncated wide inductions");
+ continue;
+ }
+
+ // The backedge value provides the value to resume coming out of a loop,
+ // which for FORs is a vector whose last element needs to be extracted. The
+ // start value provides the value if the loop is bypassed.
+ bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
+ auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
+ assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
+ "Cannot handle loops with uncountable early exits");
+ if (IsFOR)
+ ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
+ VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {},
+ "vector.recur.extract");
+ StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
+ auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
+ {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
+ ScalarPhiIRI->addOperand(ResumePhiR);
+ }
+}
+
+void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
+ VFRange &Range) {
+ VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
+ auto *ScalarPHVPBB = Plan.getScalarPreheader();
+ auto *MiddleVPBB = Plan.getMiddleBlock();
+ VPBuilder ScalarPHBuilder(ScalarPHVPBB);
+ VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
+
+ auto IsScalableOne = [](ElementCount VF) -> bool {
+ return VF == ElementCount::getScalable(1);
+ };
+
+ for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) {
+ auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi);
+ if (!FOR)
+ continue;
+
+ assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
+ "Cannot handle loops with uncountable early exits");
+
+ // This is the second phase of vectorizing first-order recurrences, creating
+ // extract for users outside the loop. An overview of the transformation is
+ // described below. Suppose we have the following loop with some use after
+ // the loop of the last a[i-1],
+ //
+ // for (int i = 0; i < n; ++i) {
+ // t = a[i - 1];
+ // b[i] = a[i] - t;
+ // }
+ // use t;
+ //
+ // There is a first-order recurrence on "a". For this loop, the shorthand
+ // scalar IR looks like:
+ //
+ // scalar.ph:
+ // s.init = a[-1]
+ // br scalar.body
+ //
+ // scalar.body:
+ // i = phi [0, scalar.ph], [i+1, scalar.body]
+ // s1 = phi [s.init, scalar.ph], [s2, scalar.body]
+ // s2 = a[i]
+ // b[i] = s2 - s1
+ // br cond, scalar.body, exit.block
+ //
+ // exit.block:
+ // use = lcssa.phi [s1, scalar.body]
+ //
+ // In this example, s1 is a recurrence because it's value depends on the
+ // previous iteration. In the first phase of vectorization, we created a
+ // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts
+ // for users in the scalar preheader and exit block.
+ //
+ // vector.ph:
+ // v_init = vector(..., ..., ..., a[-1])
+ // br vector.body
+ //
+ // vector.body
+ // i = phi [0, vector.ph], [i+4, vector.body]
+ // v1 = phi [v_init, vector.ph], [v2, vector.body]
+ // v2 = a[i, i+1, i+2, i+3]
+ // b[i] = v2 - v1
+ // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2))
+ // b[i, i+1, i+2, i+3] = v2 - v1
+ // br cond, vector.body, middle.block
+ //
+ // middle.block:
+ // vector.recur.extract.for.phi = v2(2)
+ // vector.recur.extract = v2(3)
+ // br cond, scalar.ph, exit.block
+ //
+ // scalar.ph:
+ // scalar.recur.init = phi [vector.recur.extract, middle.block],
+ // [s.init, otherwise]
+ // br scalar.body
+ //
+ // scalar.body:
+ // i = phi [0, scalar.ph], [i+1, scalar.body]
+ // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body]
+ // s2 = a[i]
+ // b[i] = s2 - s1
+ // br cond, scalar.body, exit.block
+ //
+ // exit.block:
+ // lo = lcssa.phi [s1, scalar.body],
+ // [vector.recur.extract.for.phi, middle.block]
+ //
+ // Now update VPIRInstructions modeling LCSSA phis in the exit block.
+ // Extract the penultimate value of the recurrence and use it as operand for
+ // the VPIRInstruction modeling the phi.
+ for (VPUser *U : FOR->users()) {
+ using namespace llvm::VPlanPatternMatch;
+ if (!match(U, m_ExtractLastElement(m_Specific(FOR))))
+ continue;
+ // For VF vscale x 1, if vscale = 1, we are unable to extract the
+ // penultimate value of the recurrence. Instead we rely on the existing
+ // extract of the last element from the result of
+ // VPInstruction::FirstOrderRecurrenceSplice.
+ // TODO: Consider vscale_range info and UF.
+ if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne,
+ Range))
+ return;
+ VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
+ VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
+ {}, "vector.recur.extract.for.phi");
+ cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
+ }
+ }
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 2f00e51..5a8a2bb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -363,6 +363,19 @@ struct VPlanTransforms {
static void
addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
std::optional<unsigned> VScaleForTuning);
+
+ /// Create resume phis in the scalar preheader for first-order recurrences,
+ /// reductions and inductions, and update the VPIRInstructions wrapping the
+ /// original phis in the scalar header. End values for inductions are added to
+ /// \p IVEndValues.
+ static void addScalarResumePhis(VPlan &Plan, VPRecipeBuilder &Builder,
+ DenseMap<VPValue *, VPValue *> &IVEndValues);
+
+ /// Handle users in the exit block for first order reductions in the original
+ /// exit block. The penultimate value of recurrences is fed to their LCSSA phi
+ /// users in the original exit block using the VPIRInstruction wrapping to the
+ /// LCSSA phi.
+ static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range);
};
} // namespace llvm