Diffstat (limited to 'llvm/lib')
84 files changed, 2646 insertions, 484 deletions
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp index 757f689..c4fee39 100644 --- a/llvm/lib/Analysis/InlineCost.cpp +++ b/llvm/lib/Analysis/InlineCost.cpp @@ -751,7 +751,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { if (CA.analyze().isSuccess()) { // We were able to inline the indirect call! Subtract the cost from the // threshold to get the bonus we want to apply, but don't go below zero. - Cost -= std::max(0, CA.getThreshold() - CA.getCost()); + addCost(-std::max(0, CA.getThreshold() - CA.getCost())); } } else // Otherwise simply add the cost for merely making the call. @@ -1191,7 +1191,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { // If this function uses the coldcc calling convention, prefer not to inline // it. if (F.getCallingConv() == CallingConv::Cold) - Cost += InlineConstants::ColdccPenalty; + addCost(InlineConstants::ColdccPenalty); LLVM_DEBUG(dbgs() << " Initial cost: " << Cost << "\n"); @@ -2193,7 +2193,7 @@ void InlineCostCallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) { // the cost of inlining it drops dramatically. It may seem odd to update // Cost in updateThreshold, but the bonus depends on the logic in this method. if (isSoleCallToLocalFunction(Call, F)) { - Cost -= LastCallToStaticBonus; + addCost(-LastCallToStaticBonus); StaticBonusApplied = LastCallToStaticBonus; } } diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 050c327..424a7fe 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -436,10 +436,9 @@ bool IndexedReference::delinearize(const LoopInfo &LI) { const SCEV *StepRec = AccessFnAR ? AccessFnAR->getStepRecurrence(SE) : nullptr; if (StepRec && SE.isKnownNegative(StepRec)) - AccessFn = SE.getAddRecExpr(AccessFnAR->getStart(), - SE.getNegativeSCEV(StepRec), - AccessFnAR->getLoop(), - AccessFnAR->getNoWrapFlags()); + AccessFn = SE.getAddRecExpr( + AccessFnAR->getStart(), SE.getNegativeSCEV(StepRec), + AccessFnAR->getLoop(), SCEV::NoWrapFlags::FlagAnyWrap); const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize); Subscripts.push_back(Div); Sizes.push_back(ElemSize); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index bf62623..c47a1c1 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1001,13 +1001,25 @@ InstructionCost TargetTransformInfo::getShuffleCost( TargetTransformInfo::PartialReductionExtendKind TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) { - if (isa<SExtInst>(I)) - return PR_SignExtend; - if (isa<ZExtInst>(I)) - return PR_ZeroExtend; + if (auto *Cast = dyn_cast<CastInst>(I)) + return getPartialReductionExtendKind(Cast->getOpcode()); return PR_None; } +TargetTransformInfo::PartialReductionExtendKind +TargetTransformInfo::getPartialReductionExtendKind( + Instruction::CastOps CastOpc) { + switch (CastOpc) { + case Instruction::CastOps::ZExt: + return PR_ZeroExtend; + case Instruction::CastOps::SExt: + return PR_SignExtend; + default: + return PR_None; + } + llvm_unreachable("Unhandled cast opcode"); +} + TTI::CastContextHint TargetTransformInfo::getCastContextHint(const Instruction *I) { if (!I) diff --git a/llvm/lib/CGData/CodeGenDataReader.cpp b/llvm/lib/CGData/CodeGenDataReader.cpp index b1cd939..aeb4a4d 100644 --- a/llvm/lib/CGData/CodeGenDataReader.cpp +++ b/llvm/lib/CGData/CodeGenDataReader.cpp @@ -125,7 +125,7 @@ Error 
IndexedCodeGenDataReader::read() { FunctionMapRecord.setReadStableFunctionMapNames( IndexedCodeGenDataReadFunctionMapNames); if (IndexedCodeGenDataLazyLoading) - FunctionMapRecord.lazyDeserialize(SharedDataBuffer, + FunctionMapRecord.lazyDeserialize(std::move(SharedDataBuffer), Header.StableFunctionMapOffset); else FunctionMapRecord.deserialize(Ptr); diff --git a/llvm/lib/CGData/StableFunctionMap.cpp b/llvm/lib/CGData/StableFunctionMap.cpp index 46e04bd..d0fae3a 100644 --- a/llvm/lib/CGData/StableFunctionMap.cpp +++ b/llvm/lib/CGData/StableFunctionMap.cpp @@ -137,6 +137,7 @@ size_t StableFunctionMap::size(SizeType Type) const { const StableFunctionMap::StableFunctionEntries & StableFunctionMap::at(HashFuncsMapType::key_type FunctionHash) const { auto It = HashToFuncs.find(FunctionHash); + assert(It != HashToFuncs.end() && "FunctionHash not found!"); if (isLazilyLoaded()) deserializeLazyLoadingEntry(It); return It->second.Entries; diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index fefde64f..8aa488f 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -41,6 +41,7 @@ #include "llvm/CodeGen/GCMetadataPrinter.h" #include "llvm/CodeGen/LazyMachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineBlockHashInfo.h" #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineDominators.h" @@ -184,6 +185,8 @@ static cl::opt<bool> PrintLatency( cl::desc("Print instruction latencies as verbose asm comments"), cl::Hidden, cl::init(false)); +extern cl::opt<bool> EmitBBHash; + STATISTIC(EmittedInsts, "Number of machine instrs printed"); char AsmPrinter::ID = 0; @@ -474,6 +477,8 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<GCModuleInfo>(); AU.addRequired<LazyMachineBlockFrequencyInfoPass>(); AU.addRequired<MachineBranchProbabilityInfoWrapperPass>(); + if (EmitBBHash) + AU.addRequired<MachineBlockHashInfo>(); } bool AsmPrinter::doInitialization(Module &M) { @@ -1434,14 +1439,11 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges, "BB entries info is required for BBFreq and BrProb " "features"); } - return {FuncEntryCountEnabled, - BBFreqEnabled, - BrProbEnabled, + return {FuncEntryCountEnabled, BBFreqEnabled, BrProbEnabled, MF.hasBBSections() && NumMBBSectionRanges > 1, // Use static_cast to avoid breakage of tests on windows. - static_cast<bool>(BBAddrMapSkipEmitBBEntries), - HasCalls, - false}; + static_cast<bool>(BBAddrMapSkipEmitBBEntries), HasCalls, + static_cast<bool>(EmitBBHash)}; } void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { @@ -1500,6 +1502,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { PrevMBBEndSymbol = MBBSymbol; } + auto MBHI = + Features.BBHash ? &getAnalysis<MachineBlockHashInfo>() : nullptr; + if (!Features.OmitBBEntries) { OutStreamer->AddComment("BB id"); // Emit the BB ID for this basic block. @@ -1527,6 +1532,10 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { emitLabelDifferenceAsULEB128(MBB.getEndSymbol(), CurrentLabel); // Emit the Metadata. OutStreamer->emitULEB128IntValue(getBBAddrMapMetadata(MBB)); + // Emit the Hash. 
+ if (MBHI) { + OutStreamer->emitInt64(MBHI->getMBBHash(MBB)); + } } PrevMBBEndSymbol = MBB.getEndSymbol(); } diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index b6872605..4373c53 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -108,6 +108,7 @@ add_llvm_component_library(LLVMCodeGen LowerEmuTLS.cpp MachineBasicBlock.cpp MachineBlockFrequencyInfo.cpp + MachineBlockHashInfo.cpp MachineBlockPlacement.cpp MachineBranchProbabilityInfo.cpp MachineCFGPrinter.cpp diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp index 2b5ced3..f44eb22 100644 --- a/llvm/lib/CodeGen/ExpandFp.cpp +++ b/llvm/lib/CodeGen/ExpandFp.cpp @@ -1108,8 +1108,8 @@ public: }; } // namespace -ExpandFpPass::ExpandFpPass(const TargetMachine *TM, CodeGenOptLevel OptLevel) - : TM(TM), OptLevel(OptLevel) {} +ExpandFpPass::ExpandFpPass(const TargetMachine &TM, CodeGenOptLevel OptLevel) + : TM(&TM), OptLevel(OptLevel) {} void ExpandFpPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { diff --git a/llvm/lib/CodeGen/MachineBlockHashInfo.cpp b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp new file mode 100644 index 0000000..c4d9c0f --- /dev/null +++ b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp @@ -0,0 +1,115 @@ +//===- llvm/CodeGen/MachineBlockHashInfo.cpp---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Compute the hashes of basic blocks. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MachineBlockHashInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/InitializePasses.h" +#include "llvm/Target/TargetMachine.h" + +using namespace llvm; + +uint64_t hashBlock(const MachineBasicBlock &MBB, bool HashOperands) { + uint64_t Hash = 0; + for (const MachineInstr &MI : MBB) { + if (MI.isMetaInstruction() || MI.isTerminator()) + continue; + Hash = hashing::detail::hash_16_bytes(Hash, MI.getOpcode()); + if (HashOperands) { + for (unsigned i = 0; i < MI.getNumOperands(); i++) { + Hash = + hashing::detail::hash_16_bytes(Hash, hash_value(MI.getOperand(i))); + } + } + } + return Hash; +} + +/// Fold a 64-bit integer to a 16-bit one. 
+uint16_t fold_64_to_16(const uint64_t Value) { + uint16_t Res = static_cast<uint16_t>(Value); + Res ^= static_cast<uint16_t>(Value >> 16); + Res ^= static_cast<uint16_t>(Value >> 32); + Res ^= static_cast<uint16_t>(Value >> 48); + return Res; +} + +INITIALIZE_PASS(MachineBlockHashInfo, "machine-block-hash", + "Machine Block Hash Analysis", true, true) + +char MachineBlockHashInfo::ID = 0; + +MachineBlockHashInfo::MachineBlockHashInfo() : MachineFunctionPass(ID) { + initializeMachineBlockHashInfoPass(*PassRegistry::getPassRegistry()); +} + +void MachineBlockHashInfo::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +struct CollectHashInfo { + uint64_t Offset; + uint64_t OpcodeHash; + uint64_t InstrHash; + uint64_t NeighborHash; +}; + +bool MachineBlockHashInfo::runOnMachineFunction(MachineFunction &F) { + DenseMap<const MachineBasicBlock *, CollectHashInfo> HashInfos; + uint16_t Offset = 0; + // Initialize hash components + for (const MachineBasicBlock &MBB : F) { + // offset of the machine basic block + HashInfos[&MBB].Offset = Offset; + Offset += MBB.size(); + // Hashing opcodes + HashInfos[&MBB].OpcodeHash = hashBlock(MBB, /*HashOperands=*/false); + // Hash complete instructions + HashInfos[&MBB].InstrHash = hashBlock(MBB, /*HashOperands=*/true); + } + + // Initialize neighbor hash + for (const MachineBasicBlock &MBB : F) { + uint64_t Hash = HashInfos[&MBB].OpcodeHash; + // Append hashes of successors + for (const MachineBasicBlock *SuccMBB : MBB.successors()) { + uint64_t SuccHash = HashInfos[SuccMBB].OpcodeHash; + Hash = hashing::detail::hash_16_bytes(Hash, SuccHash); + } + // Append hashes of predecessors + for (const MachineBasicBlock *PredMBB : MBB.predecessors()) { + uint64_t PredHash = HashInfos[PredMBB].OpcodeHash; + Hash = hashing::detail::hash_16_bytes(Hash, PredHash); + } + HashInfos[&MBB].NeighborHash = Hash; + } + + // Assign hashes + for (const MachineBasicBlock &MBB : F) { + const auto &HashInfo = HashInfos[&MBB]; + BlendedBlockHash BlendedHash(fold_64_to_16(HashInfo.Offset), + fold_64_to_16(HashInfo.OpcodeHash), + fold_64_to_16(HashInfo.InstrHash), + fold_64_to_16(HashInfo.NeighborHash)); + MBBHashInfo[&MBB] = BlendedHash.combine(); + } + + return false; +} + +uint64_t MachineBlockHashInfo::getMBBHash(const MachineBasicBlock &MBB) { + return MBBHashInfo[&MBB]; +} + +MachineFunctionPass *llvm::createMachineBlockHashInfoPass() { + return new MachineBlockHashInfo(); +} diff --git a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp index f54e2f2..620d3d3 100644 --- a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp +++ b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp @@ -593,7 +593,7 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const { case Intrinsic::log: Changed |= forEachCall(F, [&](CallInst *CI) { Type *Ty = CI->getArgOperand(0)->getType(); - if (!isa<ScalableVectorType>(Ty)) + if (!TM || !isa<ScalableVectorType>(Ty)) return false; const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); unsigned Op = TL->IntrinsicIDToISD(F.getIntrinsicID()); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 603dc34..9656a30 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -890,6 +890,7 @@ private: SDValue ScalarizeVecRes_UnaryOpWithExtraInput(SDNode *N); SDValue ScalarizeVecRes_INSERT_VECTOR_ELT(SDNode *N); SDValue 
ScalarizeVecRes_LOAD(LoadSDNode *N); + SDValue ScalarizeVecRes_ATOMIC_LOAD(AtomicSDNode *N); SDValue ScalarizeVecRes_SCALAR_TO_VECTOR(SDNode *N); SDValue ScalarizeVecRes_VSELECT(SDNode *N); SDValue ScalarizeVecRes_SELECT(SDNode *N); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 3b5f83f..bb4a8d9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -69,6 +69,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) { R = ScalarizeVecRes_UnaryOpWithExtraInput(N); break; case ISD::INSERT_VECTOR_ELT: R = ScalarizeVecRes_INSERT_VECTOR_ELT(N); break; + case ISD::ATOMIC_LOAD: + R = ScalarizeVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N)); + break; case ISD::LOAD: R = ScalarizeVecRes_LOAD(cast<LoadSDNode>(N));break; case ISD::SCALAR_TO_VECTOR: R = ScalarizeVecRes_SCALAR_TO_VECTOR(N); break; case ISD::SIGN_EXTEND_INREG: R = ScalarizeVecRes_InregOp(N); break; @@ -475,6 +478,18 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_INSERT_VECTOR_ELT(SDNode *N) { return Op; } +SDValue DAGTypeLegalizer::ScalarizeVecRes_ATOMIC_LOAD(AtomicSDNode *N) { + SDValue Result = DAG.getAtomicLoad( + N->getExtensionType(), SDLoc(N), N->getMemoryVT().getVectorElementType(), + N->getValueType(0).getVectorElementType(), N->getChain(), N->getBasePtr(), + N->getMemOperand()); + + // Legalize the chain result - switch anything that used the old chain to + // use the new one. + ReplaceValueWith(SDValue(N, 1), Result.getValue(1)); + return Result; +} + SDValue DAGTypeLegalizer::ScalarizeVecRes_LOAD(LoadSDNode *N) { assert(N->isUnindexed() && "Indexed vector load?"); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 90edaf3..379242e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8620,7 +8620,10 @@ SDValue SelectionDAG::getMemBasePlusOffset(SDValue Ptr, SDValue Offset, if (TLI->shouldPreservePtrArith(this->getMachineFunction().getFunction(), BasePtrVT)) return getNode(ISD::PTRADD, DL, BasePtrVT, Ptr, Offset, Flags); - return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags); + // InBounds only applies to PTRADD, don't set it if we generate ADD. + SDNodeFlags AddFlags = Flags; + AddFlags.setInBounds(false); + return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, AddFlags); } /// Returns true if memcpy source is constant data. 
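To make the hashing scheme in the new MachineBlockHashInfo.cpp above concrete: each block gets four 64-bit components (offset, opcode hash, full-instruction hash, neighbor hash), each is XOR-folded to 16 bits, and BlendedBlockHash packs them into one 64-bit value. BlendedBlockHash itself is declared in MachineBlockHashInfo.h and is not part of this diff, so the 4x16-bit packing in the standalone sketch below is an assumption for illustration only; fold64To16 mirrors the patch's fold_64_to_16 exactly.

#include <cassert>
#include <cstdint>

// XOR-fold the four 16-bit halves of a 64-bit hash into 16 bits, exactly as
// fold_64_to_16 does in the patch above.
static uint16_t fold64To16(uint64_t Value) {
  uint16_t Res = static_cast<uint16_t>(Value);
  Res ^= static_cast<uint16_t>(Value >> 16);
  Res ^= static_cast<uint16_t>(Value >> 32);
  Res ^= static_cast<uint16_t>(Value >> 48);
  return Res;
}

// Assumed layout for BlendedBlockHash::combine(): concatenate the four folded
// components, most-significant first. The real field order lives in
// MachineBlockHashInfo.h and may differ.
static uint64_t combineBlended(uint16_t Offset, uint16_t OpcodeHash,
                               uint16_t InstrHash, uint16_t NeighborHash) {
  return (uint64_t)Offset << 48 | (uint64_t)OpcodeHash << 32 |
         (uint64_t)InstrHash << 16 | (uint64_t)NeighborHash;
}

int main() {
  // Two blocks with identical code at the same offset but different CFG
  // neighborhoods still get distinct blended hashes via the neighbor component.
  uint64_t A = combineBlended(fold64To16(0), fold64To16(0x1234),
                              fold64To16(0xabcd), fold64To16(1));
  uint64_t B = combineBlended(fold64To16(0), fold64To16(0x1234),
                              fold64To16(0xabcd), fold64To16(2));
  assert(A != B);
  return 0;
}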
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index d57c5fb..bfa566a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4390,6 +4390,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) { if (NW.hasNoUnsignedWrap() || (int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap())) Flags |= SDNodeFlags::NoUnsignedWrap; + Flags.setInBounds(NW.isInBounds()); N = DAG.getMemBasePlusOffset( N, DAG.getConstant(Offset, dl, N.getValueType()), dl, Flags); @@ -4433,6 +4434,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) { if (NW.hasNoUnsignedWrap() || (Offs.isNonNegative() && NW.hasNoUnsignedSignedWrap())) Flags.setNoUnsignedWrap(true); + Flags.setInBounds(NW.isInBounds()); OffsVal = DAG.getSExtOrTrunc(OffsVal, dl, N.getValueType()); @@ -4502,6 +4504,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) { // pointer index type (add nuw). SDNodeFlags AddFlags; AddFlags.setNoUnsignedWrap(NW.hasNoUnsignedWrap()); + AddFlags.setInBounds(NW.isInBounds()); N = DAG.getMemBasePlusOffset(N, IdxN, dl, AddFlags); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 39cbfad..77377d3 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -689,6 +689,9 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const { if (getFlags().hasSameSign()) OS << " samesign"; + if (getFlags().hasInBounds()) + OS << " inbounds"; + if (getFlags().hasNonNeg()) OS << " nneg"; diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index b6169e6..10b7238 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -272,6 +272,12 @@ static cl::opt<bool> cl::desc("Split static data sections into hot and cold " "sections using profile information")); +cl::opt<bool> EmitBBHash( + "emit-bb-hash", + cl::desc( + "Emit the hash of each basic block in the SHT_LLVM_BB_ADDR_MAP section."), + cl::init(false), cl::Optional); + /// Allow standard passes to be disabled by command line options. This supports /// simple binary flags that either suppress the pass or do nothing. /// i.e. -disable-mypass=false has no effect. @@ -1281,6 +1287,8 @@ void TargetPassConfig::addMachinePasses() { // address map (or both).
if (TM->getBBSectionsType() != llvm::BasicBlockSection::None || TM->Options.BBAddrMap) { + if (EmitBBHash) + addPass(llvm::createMachineBlockHashInfoPass()); if (TM->getBBSectionsType() == llvm::BasicBlockSection::List) { addPass(llvm::createBasicBlockSectionsProfileReaderWrapperPass( TM->getBBSectionsFuncListBuf())); diff --git a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp index 6c7e27e..fa04976 100644 --- a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp @@ -247,7 +247,7 @@ public: StandardSegments(std::move(StandardSegments)), FinalizationSegments(std::move(FinalizationSegments)) {} - ~IPInFlightAlloc() { + ~IPInFlightAlloc() override { assert(!G && "InFlight alloc neither abandoned nor finalized"); } diff --git a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp index 75ae80f..4ceff48 100644 --- a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp @@ -38,7 +38,7 @@ public: MachODebugObjectSynthesizerBase(LinkGraph &G, ExecutorAddr RegisterActionAddr) : G(G), RegisterActionAddr(RegisterActionAddr) {} - virtual ~MachODebugObjectSynthesizerBase() = default; + ~MachODebugObjectSynthesizerBase() override = default; Error preserveDebugSections() { if (G.findSectionByName(SynthDebugSectionName)) { diff --git a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp index d1a6eaf..a2990ab 100644 --- a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp +++ b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp @@ -55,7 +55,7 @@ public: Plugins = Layer.Plugins; } - ~JITLinkCtx() { + ~JITLinkCtx() override { // If there is an object buffer return function then use it to // return ownership of the buffer. if (Layer.ReturnObjectBuffer && ObjBuffer) diff --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp index fd805fbf..cdde733 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp @@ -92,7 +92,7 @@ public: Name(std::move(Name)), Ctx(Ctx), Materialize(Materialize), Discard(Discard), Destroy(Destroy) {} - ~OrcCAPIMaterializationUnit() { + ~OrcCAPIMaterializationUnit() override { if (Ctx) Destroy(Ctx); } @@ -264,7 +264,7 @@ public: LLVMOrcCAPIDefinitionGeneratorTryToGenerateFunction TryToGenerate) : Dispose(Dispose), Ctx(Ctx), TryToGenerate(TryToGenerate) {} - ~CAPIDefinitionGenerator() { + ~CAPIDefinitionGenerator() override { if (Dispose) Dispose(Ctx); } diff --git a/llvm/lib/IR/Module.cpp b/llvm/lib/IR/Module.cpp index 30b5e48..e19336e 100644 --- a/llvm/lib/IR/Module.cpp +++ b/llvm/lib/IR/Module.cpp @@ -403,9 +403,14 @@ void Module::setModuleFlag(ModFlagBehavior Behavior, StringRef Key, Metadata *Val) { NamedMDNode *ModFlags = getOrInsertModuleFlagsMetadata(); // Replace the flag if it already exists. 
- for (MDNode *Flag : ModFlags->operands()) { + for (unsigned i = 0; i < ModFlags->getNumOperands(); ++i) { + MDNode *Flag = ModFlags->getOperand(i); if (cast<MDString>(Flag->getOperand(1))->getString() == Key) { - Flag->replaceOperandWith(2, Val); + Type *Int32Ty = Type::getInt32Ty(Context); + Metadata *Ops[3] = { + ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Behavior)), + MDString::get(Context, Key), Val}; + ModFlags->setOperand(i, MDNode::get(Context, Ops)); return; } } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 03da154..7917712 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -4446,10 +4446,12 @@ void Verifier::visitLoadInst(LoadInst &LI) { Check(LI.getOrdering() != AtomicOrdering::Release && LI.getOrdering() != AtomicOrdering::AcquireRelease, "Load cannot have Release ordering", &LI); - Check(ElTy->isIntOrPtrTy() || ElTy->isFloatingPointTy(), - "atomic load operand must have integer, pointer, or floating point " - "type!", + Check(ElTy->getScalarType()->isIntOrPtrTy() || + ElTy->getScalarType()->isFloatingPointTy(), + "atomic load operand must have integer, pointer, floating point, " + "or vector type!", ElTy, &LI); + checkAtomicMemAccessSize(ElTy, &LI); } else { Check(LI.getSyncScopeID() == SyncScope::System, @@ -4472,9 +4474,10 @@ void Verifier::visitStoreInst(StoreInst &SI) { Check(SI.getOrdering() != AtomicOrdering::Acquire && SI.getOrdering() != AtomicOrdering::AcquireRelease, "Store cannot have Acquire ordering", &SI); - Check(ElTy->isIntOrPtrTy() || ElTy->isFloatingPointTy(), - "atomic store operand must have integer, pointer, or floating point " - "type!", + Check(ElTy->getScalarType()->isIntOrPtrTy() || + ElTy->getScalarType()->isFloatingPointTy(), + "atomic store operand must have integer, pointer, floating point, " + "or vector type!", ElTy, &SI); checkAtomicMemAccessSize(ElTy, &SI); } else { diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp index 86780e1..9d0fa11 100644 --- a/llvm/lib/LTO/LTO.cpp +++ b/llvm/lib/LTO/LTO.cpp @@ -2224,6 +2224,7 @@ class OutOfProcessThinBackend : public CGThinBackend { ArrayRef<StringRef> DistributorArgs; SString RemoteCompiler; + ArrayRef<StringRef> RemoteCompilerPrependArgs; ArrayRef<StringRef> RemoteCompilerArgs; bool SaveTemps; @@ -2260,12 +2261,14 @@ public: bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles, StringRef LinkerOutputFile, StringRef Distributor, ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler, + ArrayRef<StringRef> RemoteCompilerPrependArgs, ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps) : CGThinBackend(Conf, CombinedIndex, ModuleToDefinedGVSummaries, AddStream, OnWrite, ShouldEmitIndexFiles, ShouldEmitImportsFiles, ThinLTOParallelism), LinkerOutputFile(LinkerOutputFile), DistributorPath(Distributor), DistributorArgs(DistributorArgs), RemoteCompiler(RemoteCompiler), + RemoteCompilerPrependArgs(RemoteCompilerPrependArgs), RemoteCompilerArgs(RemoteCompilerArgs), SaveTemps(SaveTemps) {} virtual void setup(unsigned ThinLTONumTasks, unsigned ThinLTOTaskOffset, @@ -2387,6 +2390,11 @@ public: JOS.attributeArray("args", [&]() { JOS.value(RemoteCompiler); + // Forward any supplied prepend options. 
+ if (!RemoteCompilerPrependArgs.empty()) + for (auto &A : RemoteCompilerPrependArgs) + JOS.value(A); + JOS.value("-c"); JOS.value(Saver.save("--target=" + Triple.str())); @@ -2517,6 +2525,7 @@ ThinBackend lto::createOutOfProcessThinBackend( bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles, StringRef LinkerOutputFile, StringRef Distributor, ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler, + ArrayRef<StringRef> RemoteCompilerPrependArgs, ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps) { auto Func = [=](const Config &Conf, ModuleSummaryIndex &CombinedIndex, @@ -2526,7 +2535,7 @@ ThinBackend lto::createOutOfProcessThinBackend( Conf, CombinedIndex, Parallelism, ModuleToDefinedGVSummaries, AddStream, OnWrite, ShouldEmitIndexFiles, ShouldEmitImportsFiles, LinkerOutputFile, Distributor, DistributorArgs, RemoteCompiler, - RemoteCompilerArgs, SaveTemps); + RemoteCompilerPrependArgs, RemoteCompilerArgs, SaveTemps); }; return ThinBackend(Func, Parallelism); } diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 048c58d..3c9a27a 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -669,7 +669,14 @@ void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) { FAM.registerPass([&] { return buildDefaultAAPipeline(); }); #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ - FAM.registerPass([&] { return CREATE_PASS; }); + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CREATE_PASS)>, \ + const TargetMachine &>) { \ + if (TM) \ + FAM.registerPass([&] { return CREATE_PASS; }); \ + } else { \ + FAM.registerPass([&] { return CREATE_PASS; }); \ + } #include "PassRegistry.def" for (auto &C : FunctionAnalysisRegistrationCallbacks) @@ -2038,6 +2045,14 @@ Error PassBuilder::parseModulePass(ModulePassManager &MPM, } #define FUNCTION_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CREATE_PASS)>, \ + const TargetMachine &>) { \ + if (!TM) \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ MPM.addPass(createModuleToFunctionPassAdaptor(CREATE_PASS)); \ return Error::success(); \ } @@ -2046,6 +2061,18 @@ Error PassBuilder::parseModulePass(ModulePassManager &MPM, auto Params = parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ + auto CreatePass = CREATE_PASS; \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CreatePass( \ + Params.get()))>, \ + const TargetMachine &, \ + std::remove_reference_t<decltype(Params.get())>>) { \ + if (!TM) { \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ + } \ MPM.addPass(createModuleToFunctionPassAdaptor(CREATE_PASS(Params.get()))); \ return Error::success(); \ } @@ -2152,6 +2179,14 @@ Error PassBuilder::parseCGSCCPass(CGSCCPassManager &CGPM, } #define FUNCTION_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CREATE_PASS)>, \ + const TargetMachine &>) { \ + if (!TM) \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ CGPM.addPass(createCGSCCToFunctionPassAdaptor(CREATE_PASS)); \ return Error::success(); \ } @@ -2160,6 +2195,18 @@ Error PassBuilder::parseCGSCCPass(CGSCCPassManager &CGPM, auto 
Params = parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ + auto CreatePass = CREATE_PASS; \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CreatePass( \ + Params.get()))>, \ + const TargetMachine &, \ + std::remove_reference_t<decltype(Params.get())>>) { \ + if (!TM) { \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ + } \ CGPM.addPass(createCGSCCToFunctionPassAdaptor(CREATE_PASS(Params.get()))); \ return Error::success(); \ } @@ -2239,6 +2286,14 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM, // Now expand the basic registered passes from the .inc file. #define FUNCTION_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CREATE_PASS)>, \ + const TargetMachine &>) { \ + if (!TM) \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ FPM.addPass(CREATE_PASS); \ return Error::success(); \ } @@ -2247,14 +2302,34 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM, auto Params = parsePassParameters(PARSER, Name, NAME); \ if (!Params) \ return Params.takeError(); \ + auto CreatePass = CREATE_PASS; \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CreatePass( \ + Params.get()))>, \ + const TargetMachine &, \ + std::remove_reference_t<decltype(Params.get())>>) { \ + if (!TM) { \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ + } \ FPM.addPass(CREATE_PASS(Params.get())); \ return Error::success(); \ } #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ if (Name == "require<" NAME ">") { \ + if constexpr (std::is_constructible_v< \ + std::remove_reference_t<decltype(CREATE_PASS)>, \ + const TargetMachine &>) { \ + if (!TM) \ + return make_error<StringError>( \ + formatv("pass '{0}' requires TargetMachine", Name).str(), \ + inconvertibleErrorCode()); \ + } \ FPM.addPass( \ - RequireAnalysisPass< \ - std::remove_reference_t<decltype(CREATE_PASS)>, Function>()); \ + RequireAnalysisPass<std::remove_reference_t<decltype(CREATE_PASS)>, \ + Function>()); \ return Error::success(); \ } \ if (Name == "invalidate<" NAME ">") { \ diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index a66b6e4..1853cdd 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -345,7 +345,7 @@ FUNCTION_ANALYSIS("aa", AAManager()) FUNCTION_ANALYSIS("access-info", LoopAccessAnalysis()) FUNCTION_ANALYSIS("assumptions", AssumptionAnalysis()) FUNCTION_ANALYSIS("bb-sections-profile-reader", - BasicBlockSectionsProfileReaderAnalysis(TM)) + BasicBlockSectionsProfileReaderAnalysis(*TM)) FUNCTION_ANALYSIS("block-freq", BlockFrequencyAnalysis()) FUNCTION_ANALYSIS("branch-prob", BranchProbabilityAnalysis()) FUNCTION_ANALYSIS("cycles", CycleAnalysis()) @@ -356,7 +356,7 @@ FUNCTION_ANALYSIS("domfrontier", DominanceFrontierAnalysis()) FUNCTION_ANALYSIS("domtree", DominatorTreeAnalysis()) FUNCTION_ANALYSIS("ephemerals", EphemeralValuesAnalysis()) FUNCTION_ANALYSIS("func-properties", FunctionPropertiesAnalysis()) -FUNCTION_ANALYSIS("machine-function-info", MachineFunctionAnalysis(TM)) +FUNCTION_ANALYSIS("machine-function-info", MachineFunctionAnalysis(*TM)) FUNCTION_ANALYSIS("gc-function", GCFunctionAnalysis()) 
FUNCTION_ANALYSIS("inliner-size-estimator", InlineSizeEstimatorAnalysis()) FUNCTION_ANALYSIS("last-run-tracking", LastRunTrackingAnalysis()) @@ -406,14 +406,14 @@ FUNCTION_PASS("alignment-from-assumptions", AlignmentFromAssumptionsPass()) FUNCTION_PASS("annotation-remarks", AnnotationRemarksPass()) FUNCTION_PASS("assume-builder", AssumeBuilderPass()) FUNCTION_PASS("assume-simplify", AssumeSimplifyPass()) -FUNCTION_PASS("atomic-expand", AtomicExpandPass(TM)) +FUNCTION_PASS("atomic-expand", AtomicExpandPass(*TM)) FUNCTION_PASS("bdce", BDCEPass()) FUNCTION_PASS("break-crit-edges", BreakCriticalEdgesPass()) FUNCTION_PASS("callbr-prepare", CallBrPreparePass()) FUNCTION_PASS("callsite-splitting", CallSiteSplittingPass()) FUNCTION_PASS("chr", ControlHeightReductionPass()) -FUNCTION_PASS("codegenprepare", CodeGenPreparePass(TM)) -FUNCTION_PASS("complex-deinterleaving", ComplexDeinterleavingPass(TM)) +FUNCTION_PASS("codegenprepare", CodeGenPreparePass(*TM)) +FUNCTION_PASS("complex-deinterleaving", ComplexDeinterleavingPass(*TM)) FUNCTION_PASS("consthoist", ConstantHoistingPass()) FUNCTION_PASS("constraint-elimination", ConstraintEliminationPass()) FUNCTION_PASS("coro-elide", CoroElidePass()) @@ -430,10 +430,10 @@ FUNCTION_PASS("dot-dom-only", DomOnlyPrinter()) FUNCTION_PASS("dot-post-dom", PostDomPrinter()) FUNCTION_PASS("dot-post-dom-only", PostDomOnlyPrinter()) FUNCTION_PASS("dse", DSEPass()) -FUNCTION_PASS("dwarf-eh-prepare", DwarfEHPreparePass(TM)) +FUNCTION_PASS("dwarf-eh-prepare", DwarfEHPreparePass(*TM)) FUNCTION_PASS("drop-unnecessary-assumes", DropUnnecessaryAssumesPass()) -FUNCTION_PASS("expand-large-div-rem", ExpandLargeDivRemPass(TM)) -FUNCTION_PASS("expand-memcmp", ExpandMemCmpPass(TM)) +FUNCTION_PASS("expand-large-div-rem", ExpandLargeDivRemPass(*TM)) +FUNCTION_PASS("expand-memcmp", ExpandMemCmpPass(*TM)) FUNCTION_PASS("expand-reductions", ExpandReductionsPass()) FUNCTION_PASS("extra-vector-passes", ExtraFunctionPassManager<ShouldRunExtraVectorPasses>()) @@ -446,15 +446,15 @@ FUNCTION_PASS("guard-widening", GuardWideningPass()) FUNCTION_PASS("gvn-hoist", GVNHoistPass()) FUNCTION_PASS("gvn-sink", GVNSinkPass()) FUNCTION_PASS("helloworld", HelloWorldPass()) -FUNCTION_PASS("indirectbr-expand", IndirectBrExpandPass(TM)) +FUNCTION_PASS("indirectbr-expand", IndirectBrExpandPass(*TM)) FUNCTION_PASS("infer-address-spaces", InferAddressSpacesPass()) FUNCTION_PASS("infer-alignment", InferAlignmentPass()) FUNCTION_PASS("inject-tli-mappings", InjectTLIMappings()) FUNCTION_PASS("instcount", InstCountPass()) FUNCTION_PASS("instnamer", InstructionNamerPass()) FUNCTION_PASS("instsimplify", InstSimplifyPass()) -FUNCTION_PASS("interleaved-access", InterleavedAccessPass(TM)) -FUNCTION_PASS("interleaved-load-combine", InterleavedLoadCombinePass(TM)) +FUNCTION_PASS("interleaved-access", InterleavedAccessPass(*TM)) +FUNCTION_PASS("interleaved-load-combine", InterleavedLoadCombinePass(*TM)) FUNCTION_PASS("invalidate<all>", InvalidateAllAnalysesPass()) FUNCTION_PASS("irce", IRCEPass()) FUNCTION_PASS("jump-threading", JumpThreadingPass()) @@ -533,25 +533,25 @@ FUNCTION_PASS("reassociate", ReassociatePass()) FUNCTION_PASS("redundant-dbg-inst-elim", RedundantDbgInstEliminationPass()) FUNCTION_PASS("replace-with-veclib", ReplaceWithVeclib()) FUNCTION_PASS("reg2mem", RegToMemPass()) -FUNCTION_PASS("safe-stack", SafeStackPass(TM)) +FUNCTION_PASS("safe-stack", SafeStackPass(*TM)) FUNCTION_PASS("sandbox-vectorizer", SandboxVectorizerPass()) FUNCTION_PASS("scalarize-masked-mem-intrin", ScalarizeMaskedMemIntrinPass()) 
FUNCTION_PASS("sccp", SCCPPass()) -FUNCTION_PASS("select-optimize", SelectOptimizePass(TM)) +FUNCTION_PASS("select-optimize", SelectOptimizePass(*TM)) FUNCTION_PASS("separate-const-offset-from-gep", SeparateConstOffsetFromGEPPass()) FUNCTION_PASS("sink", SinkingPass()) FUNCTION_PASS("sjlj-eh-prepare", SjLjEHPreparePass(TM)) FUNCTION_PASS("slp-vectorizer", SLPVectorizerPass()) FUNCTION_PASS("slsr", StraightLineStrengthReducePass()) -FUNCTION_PASS("stack-protector", StackProtectorPass(TM)) +FUNCTION_PASS("stack-protector", StackProtectorPass(*TM)) FUNCTION_PASS("strip-gc-relocates", StripGCRelocates()) FUNCTION_PASS("tailcallelim", TailCallElimPass()) FUNCTION_PASS("transform-warning", WarnMissedTransformationsPass()) FUNCTION_PASS("trigger-crash-function", TriggerCrashFunctionPass()) FUNCTION_PASS("trigger-verifier-error", TriggerVerifierErrorPass()) FUNCTION_PASS("tsan", ThreadSanitizerPass()) -FUNCTION_PASS("typepromotion", TypePromotionPass(TM)) +FUNCTION_PASS("typepromotion", TypePromotionPass(*TM)) FUNCTION_PASS("unify-loop-exits", UnifyLoopExitsPass()) FUNCTION_PASS("unreachableblockelim", UnreachableBlockElimPass()) FUNCTION_PASS("vector-combine", VectorCombinePass()) @@ -730,7 +730,7 @@ FUNCTION_PASS_WITH_PARAMS( FUNCTION_PASS_WITH_PARAMS( "expand-fp", "ExpandFpPass", [TM = TM](CodeGenOptLevel OL) { - return ExpandFpPass(TM, OL); + return ExpandFpPass(*TM, OL); }, parseExpandFpOptions, "O0;O1;O2;O3") diff --git a/llvm/lib/Target/AArch64/AArch64.td b/llvm/lib/Target/AArch64/AArch64.td index 86f9548..a4529a5 100644 --- a/llvm/lib/Target/AArch64/AArch64.td +++ b/llvm/lib/Target/AArch64/AArch64.td @@ -73,9 +73,16 @@ def SVEUnsupported : AArch64Unsupported { SVE2Unsupported.F); } -let F = [HasSME2p2, HasSVE2p2_or_SME2p2, HasNonStreamingSVE_or_SME2p2, - HasNonStreamingSVE2p2_or_SME2p2] in -def SME2p2Unsupported : AArch64Unsupported; +def SME2p3Unsupported : AArch64Unsupported { + let F = [HasSVE2p3_or_SME2p3, HasSVE_B16MM]; +} + +def SME2p2Unsupported : AArch64Unsupported { + let F = !listconcat([HasSME2p2, HasSVE2p2_or_SME2p2, + HasNonStreamingSVE_or_SME2p2, + HasNonStreamingSVE2p2_or_SME2p2], + SME2p3Unsupported.F); +} def SME2p1Unsupported : AArch64Unsupported { let F = !listconcat([HasSME2p1, HasSVE2p1_or_SME2p1, diff --git a/llvm/lib/Target/AArch64/AArch64Features.td b/llvm/lib/Target/AArch64/AArch64Features.td index 46f5f0c..0e94b78 100644 --- a/llvm/lib/Target/AArch64/AArch64Features.td +++ b/llvm/lib/Target/AArch64/AArch64Features.td @@ -585,6 +585,47 @@ def FeatureSME_TMOP: ExtensionWithMArch<"sme-tmop", "SME_TMOP", "FEAT_SME_TMOP", def FeatureSSVE_FEXPA : ExtensionWithMArch<"ssve-fexpa", "SSVE_FEXPA", "FEAT_SSVE_FEXPA", "Enable SVE FEXPA instruction in Streaming SVE mode", [FeatureSME2]>; +//===----------------------------------------------------------------------===// +// Armv9.7 Architecture Extensions +//===----------------------------------------------------------------------===// + +def FeatureCMH : ExtensionWithMArch<"cmh", "CMH", "FEAT_CMH", + "Enable Armv9.7-A Contention Management Hints">; + +def FeatureLSCP : ExtensionWithMArch<"lscp", "LSCP", "FEAT_LSCP", + "Enable Armv9.7-A Load-acquire and store-release pair extension">; + +def FeatureTLBID: ExtensionWithMArch<"tlbid", "TLBID", "FEAT_TLBID", + "Enable Armv9.7-A TLBI Domains extension">; + +def FeatureMPAMv2: ExtensionWithMArch<"mpamv2", "MPAMv2", "FEAT_MPAMv2", + "Enable Armv9.7-A MPAMv2 Lookaside Buffer Invalidate instructions">; + +def FeatureMTETC: ExtensionWithMArch<"mtetc", "MTETC", "FEAT_MTETC", + "Enable 
Virtual Memory Tagging Extension">; + +def FeatureGCIE: ExtensionWithMArch<"gcie", "GCIE", "FEAT_GCIE", + "Enable GICv5 (Generic Interrupt Controller) CPU Interface Extension">; + +def FeatureSVE2p3 : ExtensionWithMArch<"sve2p3", "SVE2p3", "FEAT_SVE2p3", + "Enable Armv9.7-A Scalable Vector Extension 2.3 instructions", [FeatureSVE2p2]>; + +def FeatureSME2p3 : ExtensionWithMArch<"sme2p3", "SME2p3", "FEAT_SME2p3", + "Enable Armv9.7-A Scalable Matrix Extension 2.3 instructions", [FeatureSME2p2]>; + +def FeatureSVE_B16MM : ExtensionWithMArch<"sve-b16mm", "SVE_B16MM", "FEAT_SVE_B16MM", + "Enable Armv9.7-A SVE non-widening BFloat16 matrix multiply-accumulate", [FeatureSVE]>; + +def FeatureF16MM : ExtensionWithMArch<"f16mm", "F16MM", "FEAT_F16MM", + "Enable Armv9.7-A non-widening half-precision matrix multiply-accumulate", [FeatureFullFP16]>; + +def FeatureF16F32DOT : ExtensionWithMArch<"f16f32dot", "F16F32DOT", "FEAT_F16F32DOT", + "Enable Armv9.7-A Advanced SIMD half-precision dot product accumulate to single-precision", [FeatureNEON, FeatureFullFP16]>; + +def FeatureF16F32MM : ExtensionWithMArch<"f16f32mm", "F16F32MM", "FEAT_F16F32MM", + "Enable Armv9.7-A Advanced SIMD half-precision matrix multiply-accumulate to single-precision", [FeatureNEON, FeatureFullFP16]>; + +//===----------------------------------------------------------------------===// // Other Features //===----------------------------------------------------------------------===// @@ -939,9 +980,12 @@ def HasV9_5aOps : Architecture64<9, 5, "a", "v9.5a", [HasV9_4aOps, FeatureCPA], !listconcat(HasV9_4aOps.DefaultExts, [FeatureCPA, FeatureLUT, FeatureFAMINMAX])>; def HasV9_6aOps : Architecture64<9, 6, "a", "v9.6a", - [HasV9_5aOps, FeatureCMPBR, FeatureFPRCVT, FeatureSVE2p2, FeatureLSUI, FeatureOCCMO], - !listconcat(HasV9_5aOps.DefaultExts, [FeatureCMPBR, FeatureFPRCVT, FeatureSVE2p2, + [HasV9_5aOps, FeatureCMPBR, FeatureLSUI, FeatureOCCMO], + !listconcat(HasV9_5aOps.DefaultExts, [FeatureCMPBR, FeatureLSUI, FeatureOCCMO])>; +def HasV9_7aOps : Architecture64<9, 7, "a", "v9.7a", + [HasV9_6aOps, FeatureSVE2p3, FeatureFPRCVT], + !listconcat(HasV9_6aOps.DefaultExts, [FeatureSVE2p3, FeatureFPRCVT])>; def HasV8_0rOps : Architecture64<8, 0, "r", "v8r", [ //v8.1 FeatureCRC, FeaturePAN, FeatureLSE, FeatureCONTEXTIDREL2, diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 09ce713..eab1627 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -1894,6 +1894,21 @@ def btihint_op : Operand<i32> { }]; } +def CMHPriorityHintOperand : AsmOperandClass { + let Name = "CMHPriorityHint"; + let ParserMethod = "tryParseCMHPriorityHint"; +} + +def CMHPriorityHint_op : Operand<i32> { + let ParserMatchClass = CMHPriorityHintOperand; + let PrintMethod = "printCMHPriorityHintOp"; + let MCOperandPredicate = [{ + if (!MCOp.isImm()) + return false; + return AArch64CMHPriorityHint::lookupCMHPriorityHintByEncoding(MCOp.getImm()) != nullptr; + }]; +} + class MRSI : RtSystemI<1, (outs GPR64:$Rt), (ins mrs_sysreg_op:$systemreg), "mrs", "\t$Rt, $systemreg"> { bits<16> systemreg; @@ -4636,6 +4651,48 @@ multiclass StorePairOffset<bits<2> opc, bit V, RegisterOperand regtype, GPR64sp:$Rn, 0)>; } +class BaseLoadStoreAcquirePairOffset<bits<4> opc, bit L, dag oops, dag iops, + string asm> + : I<oops, iops, asm, "\t$Rt, $Rt2, [$Rn, #0]", "", []> { + bits<5> Rt; + bits<5> Rt2; + bits<5> Rn; + let Inst{31-23} = 0b110110010; + let Inst{22} = L; + let Inst{21} = 
0b0; + let Inst{20-16} = Rt2; + let Inst{15-12} = opc; + let Inst{11-10} = 0b10; + let Inst{9-5} = Rn; + let Inst{4-0} = Rt; +} + +multiclass LoadAcquirePairOffset<bits<4> opc, string asm> { + let hasSideEffects = 0, mayStore = 0, mayLoad = 1 in + def i : BaseLoadStoreAcquirePairOffset<opc, 0b1, + (outs GPR64:$Rt, GPR64:$Rt2), + (ins GPR64sp:$Rn), asm>, + Sched<[WriteAtomic, WriteLDHi]>; + + def : InstAlias<asm # "\t$Rt, $Rt2, [$Rn]", + (!cast<Instruction>(NAME # "i") GPR64:$Rt, GPR64:$Rt2, + GPR64sp:$Rn)>; +} + + +multiclass StoreAcquirePairOffset<bits<4> opc, string asm> { + let hasSideEffects = 0, mayLoad = 0, mayStore = 1 in + def i : BaseLoadStoreAcquirePairOffset<opc, 0b0, (outs), + (ins GPR64:$Rt, GPR64:$Rt2, + GPR64sp:$Rn), + asm>, + Sched<[WriteSTP]>; + + def : InstAlias<asm # "\t$Rt, $Rt2, [$Rn]", + (!cast<Instruction>(NAME # "i") GPR64:$Rt, GPR64:$Rt2, + GPR64sp:$Rn)>; +} + // (pre-indexed) class BaseLoadStorePairPreIdx<bits<2> opc, bit V, bit L, dag oops, dag iops, string asm> @@ -6481,8 +6538,7 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm, } multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> { - - def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b", + def v16i8_v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b", V128, v8f16, v16i8, op>; } @@ -6491,6 +6547,23 @@ multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOpera V128, v4f32, v16i8, op>; } +multiclass SIMDThreeSameVectorFMLA<string asm> { + def v8f16_v8f16 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b11, 0b1101, asm, ".8h", ".8h", + V128, v8f16, v8f16, null_frag>; +} + +multiclass SIMDThreeSameVectorFMLAWiden<string asm> { + def v8f16_v4f32 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b01, 0b1101, asm, ".4s", ".8h", + V128, v4f32, v8f16, null_frag>; +} + +multiclass SIMDThreeSameVectorFDot<string asm, SDPatternOperator OpNode = null_frag> { + def v4f16_v2f32 : BaseSIMDThreeSameVectorDot<0, 0, 0b10, 0b1111, asm, ".2s", ".4h", V64, + v2f32, v4f16, OpNode>; + def v8f16_v4f32 : BaseSIMDThreeSameVectorDot<1, 0, 0b10, 0b1111, asm, ".4s", ".8h", V128, + v4f32, v8f16, OpNode>; +} + // FP8 assembly/disassembly classes //---------------------------------------------------------------------------- @@ -9112,6 +9185,13 @@ multiclass SIMDThreeSameVectorFMLIndex<bit U, bits<4> opc, string asm, V128, V128_lo, v4f32, v8f16, VectorIndexH, OpNode>; } +multiclass SIMDThreeSameVectorFDOTIndex<string asm> { + def v4f16_v2f32 : BaseSIMDThreeSameVectorIndexS<0b0, 0b0, 0b01, 0b1001, asm, ".2s", ".4h", ".2h", + V64, v2f32, v4f16, VectorIndexS, null_frag>; + def v8f16_v4f32 : BaseSIMDThreeSameVectorIndexS<0b1, 0b0, 0b01, 0b1001, asm, ".4s", ".8h",".2h", + V128, v4f32, v8f16, VectorIndexS, null_frag>; +} + //---------------------------------------------------------------------------- // FP8 Advanced SIMD vector x indexed element multiclass SIMD_FP8_Dot2_Index<string asm, SDPatternOperator op> { @@ -13227,3 +13307,34 @@ multiclass SIMDThreeSameVectorFP8MatrixMul<string asm>{ let Predicates = [HasNEON, HasF8F32MM]; } } + +//---------------------------------------------------------------------------- +// Contention Management Hints - FEAT_CMH +//---------------------------------------------------------------------------- + +class SHUHInst<string asm> : I< + (outs), + (ins CMHPriorityHint_op:$priority), + asm, "\t$priority", "", []>, Sched<[]> { + bits<1> priority; + let Inst{31-12} = 0b11010101000000110010; + let Inst{11-8} = 
0b0110; + let Inst{7-6} = 0b01; + let Inst{5} = priority; + let Inst{4-0} = 0b11111; +} + +multiclass SHUH<string asm> { + def NAME : SHUHInst<asm>; + def : InstAlias<asm, (!cast<Instruction>(NAME) 0), 1>; +} + +class STCPHInst<string asm> : I< + (outs), + (ins), + asm, "", "", []>, Sched<[]> { + let Inst{31-12} = 0b11010101000000110010; + let Inst{11-8} = 0b0110; + let Inst{7-5} = 0b100; + let Inst{4-0} = 0b11111; +} diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index f788c75..b74ca79 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -50,63 +50,44 @@ def HasV9_4a : Predicate<"Subtarget->hasV9_4aOps()">, AssemblerPredicateWithAll<(all_of HasV9_4aOps), "armv9.4a">; def HasV8_0r : Predicate<"Subtarget->hasV8_0rOps()">, AssemblerPredicateWithAll<(all_of HasV8_0rOps), "armv8-r">; - def HasEL2VMSA : Predicate<"Subtarget->hasEL2VMSA()">, - AssemblerPredicateWithAll<(all_of FeatureEL2VMSA), "el2vmsa">; - + AssemblerPredicateWithAll<(all_of FeatureEL2VMSA), "el2vmsa">; def HasEL3 : Predicate<"Subtarget->hasEL3()">, - AssemblerPredicateWithAll<(all_of FeatureEL3), "el3">; - + AssemblerPredicateWithAll<(all_of FeatureEL3), "el3">; def HasVH : Predicate<"Subtarget->hasVH()">, - AssemblerPredicateWithAll<(all_of FeatureVH), "vh">; - + AssemblerPredicateWithAll<(all_of FeatureVH), "vh">; def HasLOR : Predicate<"Subtarget->hasLOR()">, - AssemblerPredicateWithAll<(all_of FeatureLOR), "lor">; - + AssemblerPredicateWithAll<(all_of FeatureLOR), "lor">; def HasPAuth : Predicate<"Subtarget->hasPAuth()">, - AssemblerPredicateWithAll<(all_of FeaturePAuth), "pauth">; - + AssemblerPredicateWithAll<(all_of FeaturePAuth), "pauth">; def HasPAuthLR : Predicate<"Subtarget->hasPAuthLR()">, - AssemblerPredicateWithAll<(all_of FeaturePAuthLR), "pauth-lr">; - + AssemblerPredicateWithAll<(all_of FeaturePAuthLR), "pauth-lr">; def HasJS : Predicate<"Subtarget->hasJS()">, - AssemblerPredicateWithAll<(all_of FeatureJS), "jsconv">; - + AssemblerPredicateWithAll<(all_of FeatureJS), "jsconv">; def HasCCIDX : Predicate<"Subtarget->hasCCIDX()">, - AssemblerPredicateWithAll<(all_of FeatureCCIDX), "ccidx">; - -def HasComplxNum : Predicate<"Subtarget->hasComplxNum()">, - AssemblerPredicateWithAll<(all_of FeatureComplxNum), "complxnum">; - + AssemblerPredicateWithAll<(all_of FeatureCCIDX), "ccidx">; +def HasComplxNum : Predicate<"Subtarget->hasComplxNum()">, + AssemblerPredicateWithAll<(all_of FeatureComplxNum), "complxnum">; def HasNV : Predicate<"Subtarget->hasNV()">, - AssemblerPredicateWithAll<(all_of FeatureNV), "nv">; - + AssemblerPredicateWithAll<(all_of FeatureNV), "nv">; def HasMPAM : Predicate<"Subtarget->hasMPAM()">, - AssemblerPredicateWithAll<(all_of FeatureMPAM), "mpam">; - + AssemblerPredicateWithAll<(all_of FeatureMPAM), "mpam">; def HasDIT : Predicate<"Subtarget->hasDIT()">, - AssemblerPredicateWithAll<(all_of FeatureDIT), "dit">; - -def HasTRACEV8_4 : Predicate<"Subtarget->hasTRACEV8_4()">, - AssemblerPredicateWithAll<(all_of FeatureTRACEV8_4), "tracev8.4">; - + AssemblerPredicateWithAll<(all_of FeatureDIT), "dit">; +def HasTRACEV8_4 : Predicate<"Subtarget->hasTRACEV8_4()">, + AssemblerPredicateWithAll<(all_of FeatureTRACEV8_4), "tracev8.4">; def HasAM : Predicate<"Subtarget->hasAM()">, - AssemblerPredicateWithAll<(all_of FeatureAM), "am">; - + AssemblerPredicateWithAll<(all_of FeatureAM), "am">; def HasSEL2 : Predicate<"Subtarget->hasSEL2()">, - AssemblerPredicateWithAll<(all_of FeatureSEL2), "sel2">; 
- -def HasTLB_RMI : Predicate<"Subtarget->hasTLB_RMI()">, - AssemblerPredicateWithAll<(all_of FeatureTLB_RMI), "tlb-rmi">; - + AssemblerPredicateWithAll<(all_of FeatureSEL2), "sel2">; +def HasTLB_RMI : Predicate<"Subtarget->hasTLB_RMI()">, + AssemblerPredicateWithAll<(all_of FeatureTLB_RMI), "tlb-rmi">; def HasFlagM : Predicate<"Subtarget->hasFlagM()">, - AssemblerPredicateWithAll<(all_of FeatureFlagM), "flagm">; - -def HasRCPC_IMMO : Predicate<"Subtarget->hasRCPC_IMMO()">, - AssemblerPredicateWithAll<(all_of FeatureRCPC_IMMO), "rcpc-immo">; - + AssemblerPredicateWithAll<(all_of FeatureFlagM), "flagm">; +def HasRCPC_IMMO : Predicate<"Subtarget->hasRCPC_IMMO()">, + AssemblerPredicateWithAll<(all_of FeatureRCPC_IMMO), "rcpc-immo">; def HasFPARMv8 : Predicate<"Subtarget->hasFPARMv8()">, - AssemblerPredicateWithAll<(all_of FeatureFPARMv8), "fp-armv8">; + AssemblerPredicateWithAll<(all_of FeatureFPARMv8), "fp-armv8">; def HasNEON : Predicate<"Subtarget->isNeonAvailable()">, AssemblerPredicateWithAll<(all_of FeatureNEON), "neon">; def HasSM4 : Predicate<"Subtarget->hasSM4()">, @@ -149,13 +130,13 @@ def HasSVE2 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasS AssemblerPredicateWithAll<(all_of FeatureSVE2), "sve2">; def HasSVE2p1 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2p1()">, AssemblerPredicateWithAll<(all_of FeatureSVE2p1), "sve2p1">; -def HasSVEAES : Predicate<"Subtarget->hasSVEAES()">, +def HasSVEAES : Predicate<"Subtarget->hasSVEAES()">, AssemblerPredicateWithAll<(all_of FeatureSVEAES), "sve-aes">; -def HasSVESM4 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVESM4()">, +def HasSVESM4 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVESM4()">, AssemblerPredicateWithAll<(all_of FeatureSVESM4), "sve-sm4">; -def HasSVESHA3 : Predicate<"Subtarget->hasSVESHA3()">, +def HasSVESHA3 : Predicate<"Subtarget->hasSVESHA3()">, AssemblerPredicateWithAll<(all_of FeatureSVESHA3), "sve-sha3">; -def HasSVEBitPerm : Predicate<"Subtarget->hasSVEBitPerm()">, +def HasSVEBitPerm : Predicate<"Subtarget->hasSVEBitPerm()">, AssemblerPredicateWithAll<(all_of FeatureSVEBitPerm), "sve-bitperm">; def HasSMEandIsNonStreamingSafe : Predicate<"Subtarget->hasSME()">, @@ -196,7 +177,7 @@ def HasSSVE_FP8DOT2 : Predicate<"Subtarget->hasSSVE_FP8DOT2() || " "(Subtarget->hasSVE2() && Subtarget->hasFP8DOT2())">, AssemblerPredicateWithAll<(any_of FeatureSSVE_FP8DOT2, (all_of FeatureSVE2, FeatureFP8DOT2)), - "ssve-fp8dot2 or (sve2 and fp8dot2)">; + "ssve-fp8dot2 or (sve2 and fp8dot2)">; def HasFP8DOT4 : Predicate<"Subtarget->hasFP8DOT4()">, AssemblerPredicateWithAll<(all_of FeatureFP8DOT4), "fp8dot4">; def HasSSVE_FP8DOT4 : Predicate<"Subtarget->hasSSVE_FP8DOT4() || " @@ -204,43 +185,60 @@ def HasSSVE_FP8DOT4 : Predicate<"Subtarget->hasSSVE_FP8DOT4() || " AssemblerPredicateWithAll<(any_of FeatureSSVE_FP8DOT4, (all_of FeatureSVE2, FeatureFP8DOT4)), "ssve-fp8dot4 or (sve2 and fp8dot4)">; -def HasLUT : Predicate<"Subtarget->hasLUT()">, +def HasLUT : Predicate<"Subtarget->hasLUT()">, AssemblerPredicateWithAll<(all_of FeatureLUT), "lut">; -def HasSME_LUTv2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME_LUTv2()">, +def HasSME_LUTv2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME_LUTv2()">, AssemblerPredicateWithAll<(all_of FeatureSME_LUTv2), "sme-lutv2">; -def HasSMEF8F16 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F16()">, +def HasSMEF8F16 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F16()">, AssemblerPredicateWithAll<(all_of 
FeatureSMEF8F16), "sme-f8f16">; -def HasSMEF8F32 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F32()">, +def HasSMEF8F32 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F32()">, AssemblerPredicateWithAll<(all_of FeatureSMEF8F32), "sme-f8f32">; -def HasSME_MOP4 : Predicate<"(Subtarget->isStreaming() && Subtarget->hasSME_MOP4())">, +def HasSME_MOP4 : Predicate<"(Subtarget->isStreaming() && Subtarget->hasSME_MOP4())">, AssemblerPredicateWithAll<(all_of FeatureSME_MOP4), "sme-mop4">; -def HasSME_TMOP : Predicate<"(Subtarget->isStreaming() && Subtarget->hasSME_TMOP())">, +def HasSME_TMOP : Predicate<"(Subtarget->isStreaming() && Subtarget->hasSME_TMOP())">, AssemblerPredicateWithAll<(all_of FeatureSME_TMOP), "sme-tmop">; - -def HasCMPBR : Predicate<"Subtarget->hasCMPBR()">, +def HasCMPBR : Predicate<"Subtarget->hasCMPBR()">, AssemblerPredicateWithAll<(all_of FeatureCMPBR), "cmpbr">; -def HasF8F32MM : Predicate<"Subtarget->hasF8F32MM()">, +def HasF8F32MM : Predicate<"Subtarget->hasF8F32MM()">, AssemblerPredicateWithAll<(all_of FeatureF8F32MM), "f8f32mm">; -def HasF8F16MM : Predicate<"Subtarget->hasF8F16MM()">, +def HasF8F16MM : Predicate<"Subtarget->hasF8F16MM()">, AssemblerPredicateWithAll<(all_of FeatureF8F16MM), "f8f16mm">; -def HasFPRCVT : Predicate<"Subtarget->hasFPRCVT()">, +def HasFPRCVT : Predicate<"Subtarget->hasFPRCVT()">, AssemblerPredicateWithAll<(all_of FeatureFPRCVT), "fprcvt">; -def HasLSFE : Predicate<"Subtarget->hasLSFE()">, +def HasLSFE : Predicate<"Subtarget->hasLSFE()">, AssemblerPredicateWithAll<(all_of FeatureLSFE), "lsfe">; -def HasSME2p2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME2p2()">, +def HasSME2p2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME2p2()">, AssemblerPredicateWithAll<(all_of FeatureSME2p2), "sme2p2">; -def HasSVEAES2 : Predicate<"Subtarget->hasSVEAES2()">, +def HasSVEAES2 : Predicate<"Subtarget->hasSVEAES2()">, AssemblerPredicateWithAll<(all_of FeatureSVEAES2), "sve-aes2">; -def HasSVEBFSCALE : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEBFSCALE()">, +def HasSVEBFSCALE : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEBFSCALE()">, AssemblerPredicateWithAll<(all_of FeatureSVEBFSCALE), "sve-bfscale">; -def HasSVE_F16F32MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE_F16F32MM()">, +def HasSVE_F16F32MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE_F16F32MM()">, AssemblerPredicateWithAll<(all_of FeatureSVE_F16F32MM), "sve-f16f32mm">; def HasPCDPHINT : Predicate<"Subtarget->hasPCDPHINT()">, - AssemblerPredicateWithAll<(all_of FeaturePCDPHINT), "pcdphint">; + AssemblerPredicateWithAll<(all_of FeaturePCDPHINT), "pcdphint">; def HasLSUI : Predicate<"Subtarget->hasLSUI()">, - AssemblerPredicateWithAll<(all_of FeatureLSUI), "lsui">; + AssemblerPredicateWithAll<(all_of FeatureLSUI), "lsui">; def HasOCCMO : Predicate<"Subtarget->hasOCCMO()">, - AssemblerPredicateWithAll<(all_of FeatureOCCMO), "occmo">; + AssemblerPredicateWithAll<(all_of FeatureOCCMO), "occmo">; +def HasCMH : Predicate<"Subtarget->hasCMH()">, + AssemblerPredicateWithAll<(all_of FeatureCMH), "cmh">; +def HasLSCP : Predicate<"Subtarget->hasLSCP()">, + AssemblerPredicateWithAll<(all_of FeatureLSCP), "lscp">; +def HasSVE2p2 : Predicate<"Subtarget->hasSVE2p2()">, + AssemblerPredicateWithAll<(all_of FeatureSVE2p2), "sve2p2">; +def HasSVE_B16MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE_B16MM()">, + AssemblerPredicateWithAll<(all_of FeatureSVE_B16MM), 
"sve-b16mm">; +def HasF16MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasF16MM()">, + AssemblerPredicateWithAll<(all_of FeatureF16MM), "f16mm">; +def HasSVE2p3 : Predicate<"Subtarget->hasSVE2p3()">, + AssemblerPredicateWithAll<(all_of FeatureSVE2p3), "sve2p3">; +def HasSME2p3 : Predicate<"Subtarget->hasSME2p3()">, + AssemblerPredicateWithAll<(all_of FeatureSME2p3), "sme2p3">; +def HasF16F32DOT : Predicate<"Subtarget->hasF16F32DOT()">, + AssemblerPredicateWithAll<(all_of FeatureF16F32DOT), "f16f32dot">; +def HasF16F32MM : Predicate<"Subtarget->hasF16F32MM()">, + AssemblerPredicateWithAll<(all_of FeatureF16F32MM), "f16f32mm">; // A subset of SVE(2) instructions are legal in Streaming SVE execution mode, // they should be enabled if either has been specified. @@ -310,6 +308,10 @@ def HasSVE2p2_or_SME2p2 : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && (Subtarget->hasSVE2p2() || Subtarget->hasSME2p2())">, AssemblerPredicateWithAll<(any_of FeatureSME2p2, FeatureSVE2p2), "sme2p2 or sve2p2">; +def HasSVE2p3_or_SME2p3 + : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && (Subtarget->hasSVE2p3() || Subtarget->hasSME2p3())">, + AssemblerPredicateWithAll<(any_of FeatureSME2p3, FeatureSVE2p3), + "sme2p3 or sve2p3">; def HasNonStreamingSVE2p2_or_SME2p2 : Predicate<"(Subtarget->isSVEAvailable() && Subtarget->hasSVE2p2()) ||" "(Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSME2p2())">, @@ -328,100 +330,110 @@ def HasNEONandIsStreamingSafe AssemblerPredicateWithAll<(any_of FeatureNEON), "neon">; // A subset of NEON instructions are legal in Streaming SVE mode only with +sme2p2. def HasNEONandIsSME2p2StreamingSafe - : Predicate<"Subtarget->isNeonAvailable() || (Subtarget->hasNEON() && Subtarget->hasSME2p2())">, - AssemblerPredicateWithAll<(any_of FeatureNEON), "neon">; + : Predicate<"Subtarget->isNeonAvailable() || (Subtarget->hasNEON() && Subtarget->hasSME2p2())">, + AssemblerPredicateWithAll<(any_of FeatureNEON), "neon">; def HasRCPC : Predicate<"Subtarget->hasRCPC()">, AssemblerPredicateWithAll<(all_of FeatureRCPC), "rcpc">; def HasAltNZCV : Predicate<"Subtarget->hasAlternativeNZCV()">, - AssemblerPredicateWithAll<(all_of FeatureAltFPCmp), "altnzcv">; + AssemblerPredicateWithAll<(all_of FeatureAltFPCmp), "altnzcv">; def HasFRInt3264 : Predicate<"Subtarget->hasFRInt3264()">, - AssemblerPredicateWithAll<(all_of FeatureFRInt3264), "frint3264">; + AssemblerPredicateWithAll<(all_of FeatureFRInt3264), "frint3264">; def HasSB : Predicate<"Subtarget->hasSB()">, - AssemblerPredicateWithAll<(all_of FeatureSB), "sb">; -def HasPredRes : Predicate<"Subtarget->hasPredRes()">, - AssemblerPredicateWithAll<(all_of FeaturePredRes), "predres">; + AssemblerPredicateWithAll<(all_of FeatureSB), "sb">; +def HasPredRes : Predicate<"Subtarget->hasPredRes()">, + AssemblerPredicateWithAll<(all_of FeaturePredRes), "predres">; def HasCCDP : Predicate<"Subtarget->hasCCDP()">, - AssemblerPredicateWithAll<(all_of FeatureCacheDeepPersist), "ccdp">; + AssemblerPredicateWithAll<(all_of FeatureCacheDeepPersist), "ccdp">; def HasBTI : Predicate<"Subtarget->hasBTI()">, - AssemblerPredicateWithAll<(all_of FeatureBranchTargetId), "bti">; + AssemblerPredicateWithAll<(all_of FeatureBranchTargetId), "bti">; def HasMTE : Predicate<"Subtarget->hasMTE()">, - AssemblerPredicateWithAll<(all_of FeatureMTE), "mte">; + AssemblerPredicateWithAll<(all_of FeatureMTE), "mte">; def HasTME : Predicate<"Subtarget->hasTME()">, - AssemblerPredicateWithAll<(all_of FeatureTME), "tme">; + AssemblerPredicateWithAll<(all_of 
FeatureTME), "tme">; def HasETE : Predicate<"Subtarget->hasETE()">, - AssemblerPredicateWithAll<(all_of FeatureETE), "ete">; + AssemblerPredicateWithAll<(all_of FeatureETE), "ete">; def HasTRBE : Predicate<"Subtarget->hasTRBE()">, - AssemblerPredicateWithAll<(all_of FeatureTRBE), "trbe">; + AssemblerPredicateWithAll<(all_of FeatureTRBE), "trbe">; def HasBF16 : Predicate<"Subtarget->hasBF16()">, - AssemblerPredicateWithAll<(all_of FeatureBF16), "bf16">; + AssemblerPredicateWithAll<(all_of FeatureBF16), "bf16">; def HasNoBF16 : Predicate<"!Subtarget->hasBF16()">; def HasMatMulInt8 : Predicate<"Subtarget->hasMatMulInt8()">, - AssemblerPredicateWithAll<(all_of FeatureMatMulInt8), "i8mm">; + AssemblerPredicateWithAll<(all_of FeatureMatMulInt8), "i8mm">; def HasMatMulFP32 : Predicate<"Subtarget->hasMatMulFP32()">, - AssemblerPredicateWithAll<(all_of FeatureMatMulFP32), "f32mm">; + AssemblerPredicateWithAll<(all_of FeatureMatMulFP32), "f32mm">; def HasMatMulFP64 : Predicate<"Subtarget->hasMatMulFP64()">, - AssemblerPredicateWithAll<(all_of FeatureMatMulFP64), "f64mm">; + AssemblerPredicateWithAll<(all_of FeatureMatMulFP64), "f64mm">; def HasXS : Predicate<"Subtarget->hasXS()">, - AssemblerPredicateWithAll<(all_of FeatureXS), "xs">; + AssemblerPredicateWithAll<(all_of FeatureXS), "xs">; def HasWFxT : Predicate<"Subtarget->hasWFxT()">, - AssemblerPredicateWithAll<(all_of FeatureWFxT), "wfxt">; + AssemblerPredicateWithAll<(all_of FeatureWFxT), "wfxt">; def HasLS64 : Predicate<"Subtarget->hasLS64()">, - AssemblerPredicateWithAll<(all_of FeatureLS64), "ls64">; + AssemblerPredicateWithAll<(all_of FeatureLS64), "ls64">; def HasBRBE : Predicate<"Subtarget->hasBRBE()">, - AssemblerPredicateWithAll<(all_of FeatureBRBE), "brbe">; + AssemblerPredicateWithAll<(all_of FeatureBRBE), "brbe">; def HasSPE_EEF : Predicate<"Subtarget->hasSPE_EEF()">, - AssemblerPredicateWithAll<(all_of FeatureSPE_EEF), "spe-eef">; + AssemblerPredicateWithAll<(all_of FeatureSPE_EEF), "spe-eef">; def HasHBC : Predicate<"Subtarget->hasHBC()">, - AssemblerPredicateWithAll<(all_of FeatureHBC), "hbc">; + AssemblerPredicateWithAll<(all_of FeatureHBC), "hbc">; def HasMOPS : Predicate<"Subtarget->hasMOPS()">, - AssemblerPredicateWithAll<(all_of FeatureMOPS), "mops">; + AssemblerPredicateWithAll<(all_of FeatureMOPS), "mops">; def HasCLRBHB : Predicate<"Subtarget->hasCLRBHB()">, - AssemblerPredicateWithAll<(all_of FeatureCLRBHB), "clrbhb">; + AssemblerPredicateWithAll<(all_of FeatureCLRBHB), "clrbhb">; def HasSPECRES2 : Predicate<"Subtarget->hasSPECRES2()">, - AssemblerPredicateWithAll<(all_of FeatureSPECRES2), "specres2">; + AssemblerPredicateWithAll<(all_of FeatureSPECRES2), "specres2">; def HasITE : Predicate<"Subtarget->hasITE()">, - AssemblerPredicateWithAll<(all_of FeatureITE), "ite">; + AssemblerPredicateWithAll<(all_of FeatureITE), "ite">; def HasTHE : Predicate<"Subtarget->hasTHE()">, - AssemblerPredicateWithAll<(all_of FeatureTHE), "the">; + AssemblerPredicateWithAll<(all_of FeatureTHE), "the">; def HasRCPC3 : Predicate<"Subtarget->hasRCPC3()">, - AssemblerPredicateWithAll<(all_of FeatureRCPC3), "rcpc3">; + AssemblerPredicateWithAll<(all_of FeatureRCPC3), "rcpc3">; def HasLSE128 : Predicate<"Subtarget->hasLSE128()">, - AssemblerPredicateWithAll<(all_of FeatureLSE128), "lse128">; + AssemblerPredicateWithAll<(all_of FeatureLSE128), "lse128">; def HasD128 : Predicate<"Subtarget->hasD128()">, - AssemblerPredicateWithAll<(all_of FeatureD128), "d128">; + AssemblerPredicateWithAll<(all_of FeatureD128), "d128">; def HasCHK : 
Predicate<"Subtarget->hasCHK()">, - AssemblerPredicateWithAll<(all_of FeatureCHK), "chk">; + AssemblerPredicateWithAll<(all_of FeatureCHK), "chk">; def HasGCS : Predicate<"Subtarget->hasGCS()">, - AssemblerPredicateWithAll<(all_of FeatureGCS), "gcs">; + AssemblerPredicateWithAll<(all_of FeatureGCS), "gcs">; def HasCPA : Predicate<"Subtarget->hasCPA()">, - AssemblerPredicateWithAll<(all_of FeatureCPA), "cpa">; + AssemblerPredicateWithAll<(all_of FeatureCPA), "cpa">; +def HasTLBID : Predicate<"Subtarget->hasTLBID()">, + AssemblerPredicateWithAll<(all_of FeatureTLBID), "tlbid">; +def HasMPAMv2 : Predicate<"Subtarget->hasMPAMv2()">, + AssemblerPredicateWithAll<(all_of FeatureMPAMv2), "mpamv2">; +def HasMTETC : Predicate<"Subtarget->hasMTETC()">, + AssemblerPredicateWithAll<(all_of FeatureMTETC), "mtetc">; +def HasGCIE : Predicate<"Subtarget->hasGCIE()">, + AssemblerPredicateWithAll<(all_of FeatureGCIE), "gcie">; def IsLE : Predicate<"Subtarget->isLittleEndian()">; def IsBE : Predicate<"!Subtarget->isLittleEndian()">; def IsWindows : Predicate<"Subtarget->isTargetWindows()">; def UseExperimentalZeroingPseudos - : Predicate<"Subtarget->useExperimentalZeroingPseudos()">; + : Predicate<"Subtarget->useExperimentalZeroingPseudos()">; def UseAlternateSExtLoadCVTF32 - : Predicate<"Subtarget->useAlternateSExtLoadCVTF32Pattern()">; + : Predicate<"Subtarget->useAlternateSExtLoadCVTF32Pattern()">; def UseNegativeImmediates - : Predicate<"false">, AssemblerPredicate<(all_of (not FeatureNoNegativeImmediates)), - "NegativeImmediates">; + : Predicate<"false">, + AssemblerPredicate<(all_of (not FeatureNoNegativeImmediates)), + "NegativeImmediates">; -def UseScalarIncVL : Predicate<"Subtarget->useScalarIncVL()">; +def UseScalarIncVL : Predicate<"Subtarget->useScalarIncVL()">; def NoUseScalarIncVL : Predicate<"!Subtarget->useScalarIncVL()">; -def HasFastIncVL : Predicate<"!Subtarget->hasDisableFastIncVL()">; +def HasFastIncVL : Predicate<"!Subtarget->hasDisableFastIncVL()">; -def UseSVEFPLD1R : Predicate<"!Subtarget->noSVEFPLD1R()">; +def UseSVEFPLD1R : Predicate<"!Subtarget->noSVEFPLD1R()">; -def UseLDAPUR : Predicate<"!Subtarget->avoidLDAPUR()">; +def UseLDAPUR : Predicate<"!Subtarget->avoidLDAPUR()">; def AArch64LocalRecover : SDNode<"ISD::LOCAL_RECOVER", SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>, SDTCisInt<1>]>>; -def AllowMisalignedMemAccesses : Predicate<"!Subtarget->requiresStrictAlign()">; +def AllowMisalignedMemAccesses + : Predicate<"!Subtarget->requiresStrictAlign()">; def UseWzrToVecMove : Predicate<"Subtarget->useWzrToVecMove()">; @@ -3692,6 +3704,12 @@ def UDF : UDFType<0, "udf">; // Load instructions. 
//===----------------------------------------------------------------------===// +let Predicates = [HasLSCP] in { +defm LDAP : LoadAcquirePairOffset<0b0101, "ldap">; +defm LDAPP : LoadAcquirePairOffset<0b0111, "ldapp">; +defm STLP : StoreAcquirePairOffset<0b0101, "stlp">; +} + // Pair (indexed, offset) defm LDPW : LoadPairOffset<0b00, 0, GPR32z, simm7s4, "ldp">; defm LDPX : LoadPairOffset<0b10, 0, GPR64z, simm7s8, "ldp">; @@ -4005,24 +4023,20 @@ def : Pat<(i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))), (SUBREG_TO_REG (i64 0), (LDRWui GPR64sp:$Rn, uimm12s4:$offset), sub_32)>; // load zero-extended i32, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; // load zero-extended i16, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi16 (am_indexed32 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; // load zero-extended i8, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi8 (am_indexed32 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; // load zero-extended i16, bitcast to f32 -def : Pat <(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; - +def : Pat<(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), + (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; // load zero-extended i8, bitcast to f32 -def : Pat <(f32 (bitconvert (i32 (zextloadi8 (am_indexed16 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; +def : Pat<(f32 (bitconvert (i32 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
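// ---------------------------------------------------------------------------
// Editorial note - an illustrative sketch, not part of the patch: the
// corrected bitconvert patterns above pair each zero-extending load with the
// addressing mode of its actual access width, so the uimm12 scale agrees with
// the load being selected (am_indexed8/uimm12s1 with LDRBui,
// am_indexed16/uimm12s2 with LDRHui). Under the i16->f32 pattern, for
// example, IR along the lines of
//   %h = load i16, ptr %p
//   %z = zext i16 %h to i32
//   %f = bitcast i32 %z to float
// can select to a single LDRHui into the h-subregister of an FPR (widened
// with SUBREG_TO_REG), with any immediate offset encoded at halfword scale.
// ---------------------------------------------------------------------------
// Pre-fetch.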
def PRFMui : PrefetchUI<0b11, 0, 0b10, "prfm", @@ -11248,8 +11262,28 @@ let Predicates = [HasLSFE] in { def STBFMINNML : BaseAtomicFPStore<FPR16, 0b00, 0b1, 0b111, "stbfminnml">; } +let Predicates = [HasF16F32DOT] in { + defm FDOT : SIMDThreeSameVectorFDot<"fdot">; + defm FDOTlane : SIMDThreeSameVectorFDOTIndex<"fdot">; +} + +let Predicates = [HasF16MM] in + defm FMMLA : SIMDThreeSameVectorFMLA<"fmmla">; + +let Predicates = [HasF16F32MM] in + defm FMMLA : SIMDThreeSameVectorFMLAWiden<"fmmla">; + let Uses = [FPMR, FPCR] in -defm FMMLA : SIMDThreeSameVectorFP8MatrixMul<"fmmla">; + defm FMMLA : SIMDThreeSameVectorFP8MatrixMul<"fmmla">; + +//===----------------------------------------------------------------------===// +// Contention Management Hints (FEAT_CMH) +//===----------------------------------------------------------------------===// + +let Predicates = [HasCMH] in { + defm SHUH : SHUH<"shuh">; // Shared Update Hint instruction + def STCPH : STCPHInst<"stcph">; // Store Concurrent Priority Hint instruction +} include "AArch64InstrAtomics.td" include "AArch64SVEInstrInfo.td" diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td index 47144c7..cd94a25 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -1341,6 +1341,10 @@ def Z_q : RegisterOperand<ZPR, "printTypedVectorList<0,'q'>"> { let ParserMatchClass = ZPRVectorList<128, 1>; } +def ZZ_Any : RegisterOperand<ZPR2, "printTypedVectorList<0,0>"> { + let ParserMatchClass = ZPRVectorList<0, 2>; +} + def ZZ_b : RegisterOperand<ZPR2, "printTypedVectorList<0,'b'>"> { let ParserMatchClass = ZPRVectorList<8, 2>; } @@ -1361,6 +1365,10 @@ def ZZ_q : RegisterOperand<ZPR2, "printTypedVectorList<0,'q'>"> { let ParserMatchClass = ZPRVectorList<128, 2>; } +def ZZZ_Any : RegisterOperand<ZPR3, "printTypedVectorList<0,0>"> { + let ParserMatchClass = ZPRVectorList<0, 3>; +} + def ZZZ_b : RegisterOperand<ZPR3, "printTypedVectorList<0,'b'>"> { let ParserMatchClass = ZPRVectorList<8, 3>; } diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index e552afe..752b185 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -1173,3 +1173,14 @@ let Predicates = [HasSME_MOP4, HasSMEF64F64] in { defm FMOP4A : sme2_fmop4as_fp64_non_widening<0, "fmop4a", "int_aarch64_sme_mop4a">; defm FMOP4S : sme2_fmop4as_fp64_non_widening<1, "fmop4s", "int_aarch64_sme_mop4s">; } + +//===----------------------------------------------------------------------===// +// SME2.3 instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSME2p3] in { + def LUTI6_ZTZ : sme2_lut_single<"luti6">; + def LUTI6_4ZT3Z : sme2_luti6_zt_consecutive<"luti6">; + def LUTI6_S_4ZT3Z : sme2_luti6_zt_strided<"luti6">; + def LUTI6_4Z2Z2ZI : sme2_luti6_vector_vg4_consecutive<"luti6">; + def LUTI6_S_4Z2Z2ZI : sme2_luti6_vector_vg4_strided<"luti6">; +} // [HasSME2p3] diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 98a128e..3b268dc 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -2569,7 +2569,7 @@ let Predicates = [HasBF16, HasSVE_or_SME] in { } // End HasBF16, HasSVE_or_SME let Predicates = [HasBF16, HasSVE] in { - defm BFMMLA_ZZZ_HtoS : sve_fp_matrix_mla<0b01, "bfmmla", ZPR32, ZPR16, 
int_aarch64_sve_bfmmla, nxv4f32, nxv8bf16>; + defm BFMMLA_ZZZ_HtoS : sve_fp_matrix_mla<0b011, "bfmmla", ZPR32, ZPR16, int_aarch64_sve_bfmmla, nxv4f32, nxv8bf16>; } // End HasBF16, HasSVE let Predicates = [HasBF16, HasSVE_or_SME] in { @@ -3680,15 +3680,15 @@ let Predicates = [HasSVE_or_SME, HasMatMulInt8] in { } // End HasSVE_or_SME, HasMatMulInt8 let Predicates = [HasSVE, HasMatMulFP32] in { - defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0b10, "fmmla", ZPR32, ZPR32, int_aarch64_sve_fmmla, nxv4f32, nxv4f32>; + defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0b101, "fmmla", ZPR32, ZPR32, int_aarch64_sve_fmmla, nxv4f32, nxv4f32>; } // End HasSVE, HasMatMulFP32 let Predicates = [HasSVE_F16F32MM] in { - def FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b00, "fmmla", ZPR32, ZPR16>; + def FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b001, "fmmla", ZPR32, ZPR16>; } // End HasSVE_F16F32MM let Predicates = [HasSVE, HasMatMulFP64] in { - defm FMMLA_ZZZ_D : sve_fp_matrix_mla<0b11, "fmmla", ZPR64, ZPR64, int_aarch64_sve_fmmla, nxv2f64, nxv2f64>; + defm FMMLA_ZZZ_D : sve_fp_matrix_mla<0b111, "fmmla", ZPR64, ZPR64, int_aarch64_sve_fmmla, nxv2f64, nxv2f64>; defm LD1RO_B_IMM : sve_mem_ldor_si<0b00, "ld1rob", Z_b, ZPR8, nxv16i8, nxv16i1, AArch64ld1ro_z>; defm LD1RO_H_IMM : sve_mem_ldor_si<0b01, "ld1roh", Z_h, ZPR16, nxv8i16, nxv8i1, AArch64ld1ro_z>; defm LD1RO_W_IMM : sve_mem_ldor_si<0b10, "ld1row", Z_s, ZPR32, nxv4i32, nxv4i1, AArch64ld1ro_z>; @@ -4272,9 +4272,9 @@ def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$ defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>; defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>; defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>; -defm SQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrn", 0b101, int_aarch64_sve_sqrshrn_x2>; -defm UQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"uqrshrn", 0b111, int_aarch64_sve_uqrshrn_x2>; -defm SQRSHRUN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrun", 0b001, int_aarch64_sve_sqrshrun_x2>; +defm SQRSHRN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"sqrshrn", 0b101, int_aarch64_sve_sqrshrn_x2>; +defm UQRSHRN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"uqrshrn", 0b111, int_aarch64_sve_uqrshrn_x2>; +defm SQRSHRUN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"sqrshrun", 0b001, int_aarch64_sve_sqrshrun_x2>; defm WHILEGE_2PXX : sve2p1_int_while_rr_pair<"whilege", 0b000>; defm WHILEGT_2PXX : sve2p1_int_while_rr_pair<"whilegt", 0b001>; @@ -4615,6 +4615,75 @@ let Predicates = [HasSVE2p2_or_SME2p2] in { defm REVD_ZPzZ : sve_int_perm_rev_revd_z<"revd", AArch64revd_mt>; } // End HasSME2p2orSVE2p2 + +//===----------------------------------------------------------------------===// +// SME2.3 or SVE2.3 instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE2p3_or_SME2p3] in { + // SVE2 Add pairwise within quadword vector segments (unpredicated) + defm ADDQP_ZZZ : sve2_int_mul<0b110, "addqp", null_frag>; + + // SVE2 Add subtract/subtract pairwise + defm ADDSUBP_ZZZ : sve2_int_mul<0b111, "addsubp", null_frag>; + defm SUBP_ZPmZZ : sve2_int_arith_pred<0b100001, "subp", null_frag>; + + // SVE2 integer absolute difference and accumulate long + defm SABAL_ZZZ : sve2_int_two_way_absdiff_accum_long<0b0, "sabal">; + defm UABAL_ZZZ : sve2_int_two_way_absdiff_accum_long<0b1, "uabal">; + + // SVE2 integer dot product + def SDOT_ZZZ_BtoH : sve_intx_dot<0b01, 0b00000, 
0b0, "sdot", ZPR16, ZPR8>; + def UDOT_ZZZ_BtoH : sve_intx_dot<0b01, 0b00000, 0b1, "udot", ZPR16, ZPR8>; + + // SVE2 integer indexed dot product + def SDOT_ZZZI_BtoH : sve_intx_dot_by_indexed_elem_x<0b0, "sdot">; + def UDOT_ZZZI_BtoH : sve_intx_dot_by_indexed_elem_x<0b1, "udot">; + + // SVE2 fp convert, narrow and interleave to integer, rounding toward zero + defm FCVTZSN_Z2Z : sve2_fp_to_int_downcvt<"fcvtzsn", 0b0>; + defm FCVTZUN_Z2Z : sve2_fp_to_int_downcvt<"fcvtzun", 0b1>; + + // SVE2 signed/unsigned integer convert to floating-point + defm SCVTF_ZZ : sve2_int_to_fp_upcvt<"scvtf", 0b00>; + defm SCVTFLT_ZZ : sve2_int_to_fp_upcvt<"scvtflt", 0b10>; + defm UCVTF_ZZ : sve2_int_to_fp_upcvt<"ucvtf", 0b01>; + defm UCVTFLT_ZZ : sve2_int_to_fp_upcvt<"ucvtflt", 0b11>; + + // SVE2 saturating shift right narrow by immediate and interleave + defm SQRSHRN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"sqrshrn", 0b101>; + defm SQRSHRUN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"sqrshrun", 0b001>; + defm SQSHRN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"sqshrn", 0b000>; + defm SQSHRUN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"sqshrun", 0b100>; + defm UQRSHRN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"uqrshrn", 0b111>; + defm UQSHRN_Z2ZI_HtoB : sve_multi_vec_round_shift_narrow<"uqshrn", 0b010>; + defm SQSHRUN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"sqshrun", 0b100, null_frag>; + defm SQSHRN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"sqshrn", 0b000, null_frag>; + defm UQSHRN_Z2ZI_StoH : sve_multi_vec_shift_narrow<"uqshrn", 0b010, null_frag>; + + defm LUTI6_Z2ZZI : sve2_luti6_vector_index<"luti6">; +} // End HasSME2p3orSVE2p3 + +//===----------------------------------------------------------------------===// +// SVE2.3 instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE2p3] in { + def LUTI6_Z2ZZ : sve2_luti6_vector<"luti6">; +} + +//===----------------------------------------------------------------------===// +// SVE_B16MM Instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE_B16MM] in { + def BFMMLA_ZZZ_H : sve_fp_matrix_mla<0b110, "bfmmla", ZPR16, ZPR16>; +} + +//===----------------------------------------------------------------------===// +// F16MM Instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE2p2, HasF16MM] in { + def FMMLA_ZZZ_H : sve_fp_matrix_mla<0b100, "fmmla", ZPR16, ZPR16>; +} + //===----------------------------------------------------------------------===// // SME2.2 or SVE2.2 instructions - Legal in streaming mode iff target has SME2p2 //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AArch64/AArch64SystemOperands.td b/llvm/lib/Target/AArch64/AArch64SystemOperands.td index 9438917..ae46d71 100644 --- a/llvm/lib/Target/AArch64/AArch64SystemOperands.td +++ b/llvm/lib/Target/AArch64/AArch64SystemOperands.td @@ -205,6 +205,7 @@ def lookupDCByName : SearchIndex { let Key = ["Name"]; } +// Op1 CRn CRm Op2 def : DC<"ZVA", 0b011, 0b0111, 0b0100, 0b001>; def : DC<"IVAC", 0b000, 0b0111, 0b0110, 0b001>; def : DC<"ISW", 0b000, 0b0111, 0b0110, 0b010>; @@ -241,6 +242,11 @@ def : DC<"CIGDVAC", 0b011, 0b0111, 0b1110, 0b101>; def : DC<"GZVA", 0b011, 0b0111, 0b0100, 0b100>; } +let Requires = [{ {AArch64::FeatureMTETC} }] in { +def : DC<"ZGBVA", 0b011, 0b0111, 0b0100, 0b101>; +def : DC<"GBVA", 0b011, 0b0111, 0b0100, 0b111>; +} + let 
Requires = [{ {AArch64::FeatureMEC} }] in { def : DC<"CIPAE", 0b100, 0b0111, 0b1110, 0b000>; def : DC<"CIGDPAE", 0b100, 0b0111, 0b1110, 0b111>; @@ -813,11 +819,26 @@ def : BTI<"j", 0b100>; def : BTI<"jc", 0b110>; //===----------------------------------------------------------------------===// +// CMHPriority instruction options. +//===----------------------------------------------------------------------===// + +class CMHPriorityHint<string name, bits<1> encoding> : SearchableTable { + let SearchableFields = ["Name", "Encoding"]; + let EnumValueField = "Encoding"; + + string Name = name; + bits<1> Encoding; + let Encoding = encoding; +} + +def : CMHPriorityHint<"ph", 0b1>; + +//===----------------------------------------------------------------------===// // TLBI (translation lookaside buffer invalidate) instruction options. //===----------------------------------------------------------------------===// class TLBICommon<string name, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2, bit needsreg> { + bits<3> op2, bit needsreg, bit optionalreg> { string Name = name; bits<14> Encoding; let Encoding{13-11} = op1; @@ -825,24 +846,25 @@ class TLBICommon<string name, bits<3> op1, bits<4> crn, bits<4> crm, let Encoding{6-3} = crm; let Encoding{2-0} = op2; bit NeedsReg = needsreg; + bit OptionalReg = optionalreg; list<string> Requires = []; list<string> ExtraRequires = []; code RequiresStr = [{ { }] # !interleave(Requires # ExtraRequires, [{, }]) # [{ } }]; } class TLBIEntry<string name, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2, bit needsreg> - : TLBICommon<name, op1, crn, crm, op2, needsreg>; + bits<3> op2, bit needsreg, bit optionalreg> + : TLBICommon<name, op1, crn, crm, op2, needsreg, optionalreg>; class TLBIPEntry<string name, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2, bit needsreg> - : TLBICommon<name, op1, crn, crm, op2, needsreg>; + bits<3> op2, bit needsreg, bit optionalreg> + : TLBICommon<name, op1, crn, crm, op2, needsreg, optionalreg>; multiclass TLBITableBase { def NAME # Table : GenericTable { let FilterClass = NAME # "Entry"; let CppTypeName = NAME; - let Fields = ["Name", "Encoding", "NeedsReg", "RequiresStr"]; + let Fields = ["Name", "Encoding", "NeedsReg", "OptionalReg", "RequiresStr"]; let PrimaryKey = ["Encoding"]; let PrimaryKeyName = "lookup" # NAME # "ByEncoding"; } @@ -856,60 +878,60 @@ defm TLBI : TLBITableBase; defm TLBIP : TLBITableBase; multiclass TLBI<string name, bit hasTLBIP, bits<3> op1, bits<4> crn, bits<4> crm, - bits<3> op2, bit needsreg = 1> { - def : TLBIEntry<name, op1, crn, crm, op2, needsreg>; - def : TLBIEntry<!strconcat(name, "nXS"), op1, crn, crm, op2, needsreg> { + bits<3> op2, bit needsreg = 1, bit optionalreg = 0> { + def : TLBIEntry<name, op1, crn, crm, op2, needsreg, optionalreg>; + def : TLBIEntry<!strconcat(name, "nXS"), op1, crn, crm, op2, needsreg, optionalreg> { let Encoding{7} = 1; let ExtraRequires = ["AArch64::FeatureXS"]; } if !eq(hasTLBIP, true) then { - def : TLBIPEntry<name, op1, crn, crm, op2, needsreg>; - def : TLBIPEntry<!strconcat(name, "nXS"), op1, crn, crm, op2, needsreg> { + def : TLBIPEntry<name, op1, crn, crm, op2, needsreg, optionalreg>; + def : TLBIPEntry<!strconcat(name, "nXS"), op1, crn, crm, op2, needsreg, optionalreg> { let Encoding{7} = 1; let ExtraRequires = ["AArch64::FeatureXS"]; } } } -// hasTLBIP op1 CRn CRm op2 needsreg +// hasTLBIP op1 CRn CRm op2 needsreg, optreg defm : TLBI<"IPAS2E1IS", 1, 0b100, 0b1000, 0b0000, 0b001>; defm : TLBI<"IPAS2LE1IS", 1, 0b100, 0b1000, 0b0000, 0b101>; -defm : 
TLBI<"VMALLE1IS", 0, 0b000, 0b1000, 0b0011, 0b000, 0>; -defm : TLBI<"ALLE2IS", 0, 0b100, 0b1000, 0b0011, 0b000, 0>; -defm : TLBI<"ALLE3IS", 0, 0b110, 0b1000, 0b0011, 0b000, 0>; +defm : TLBI<"VMALLE1IS", 0, 0b000, 0b1000, 0b0011, 0b000, 0, 1>; +defm : TLBI<"ALLE2IS", 0, 0b100, 0b1000, 0b0011, 0b000, 0, 1>; +defm : TLBI<"ALLE3IS", 0, 0b110, 0b1000, 0b0011, 0b000, 0, 1>; defm : TLBI<"VAE1IS", 1, 0b000, 0b1000, 0b0011, 0b001>; defm : TLBI<"VAE2IS", 1, 0b100, 0b1000, 0b0011, 0b001>; defm : TLBI<"VAE3IS", 1, 0b110, 0b1000, 0b0011, 0b001>; defm : TLBI<"ASIDE1IS", 0, 0b000, 0b1000, 0b0011, 0b010>; defm : TLBI<"VAAE1IS", 1, 0b000, 0b1000, 0b0011, 0b011>; -defm : TLBI<"ALLE1IS", 0, 0b100, 0b1000, 0b0011, 0b100, 0>; +defm : TLBI<"ALLE1IS", 0, 0b100, 0b1000, 0b0011, 0b100, 0, 1>; defm : TLBI<"VALE1IS", 1, 0b000, 0b1000, 0b0011, 0b101>; defm : TLBI<"VALE2IS", 1, 0b100, 0b1000, 0b0011, 0b101>; defm : TLBI<"VALE3IS", 1, 0b110, 0b1000, 0b0011, 0b101>; -defm : TLBI<"VMALLS12E1IS", 0, 0b100, 0b1000, 0b0011, 0b110, 0>; +defm : TLBI<"VMALLS12E1IS", 0, 0b100, 0b1000, 0b0011, 0b110, 0, 1>; defm : TLBI<"VAALE1IS", 1, 0b000, 0b1000, 0b0011, 0b111>; defm : TLBI<"IPAS2E1", 1, 0b100, 0b1000, 0b0100, 0b001>; defm : TLBI<"IPAS2LE1", 1, 0b100, 0b1000, 0b0100, 0b101>; -defm : TLBI<"VMALLE1", 0, 0b000, 0b1000, 0b0111, 0b000, 0>; -defm : TLBI<"ALLE2", 0, 0b100, 0b1000, 0b0111, 0b000, 0>; -defm : TLBI<"ALLE3", 0, 0b110, 0b1000, 0b0111, 0b000, 0>; +defm : TLBI<"VMALLE1", 0, 0b000, 0b1000, 0b0111, 0b000, 0, 0>; +defm : TLBI<"ALLE2", 0, 0b100, 0b1000, 0b0111, 0b000, 0, 0>; +defm : TLBI<"ALLE3", 0, 0b110, 0b1000, 0b0111, 0b000, 0, 0>; defm : TLBI<"VAE1", 1, 0b000, 0b1000, 0b0111, 0b001>; defm : TLBI<"VAE2", 1, 0b100, 0b1000, 0b0111, 0b001>; defm : TLBI<"VAE3", 1, 0b110, 0b1000, 0b0111, 0b001>; defm : TLBI<"ASIDE1", 0, 0b000, 0b1000, 0b0111, 0b010>; defm : TLBI<"VAAE1", 1, 0b000, 0b1000, 0b0111, 0b011>; -defm : TLBI<"ALLE1", 0, 0b100, 0b1000, 0b0111, 0b100, 0>; +defm : TLBI<"ALLE1", 0, 0b100, 0b1000, 0b0111, 0b100, 0, 0>; defm : TLBI<"VALE1", 1, 0b000, 0b1000, 0b0111, 0b101>; defm : TLBI<"VALE2", 1, 0b100, 0b1000, 0b0111, 0b101>; defm : TLBI<"VALE3", 1, 0b110, 0b1000, 0b0111, 0b101>; -defm : TLBI<"VMALLS12E1", 0, 0b100, 0b1000, 0b0111, 0b110, 0>; +defm : TLBI<"VMALLS12E1", 0, 0b100, 0b1000, 0b0111, 0b110, 0, 0>; defm : TLBI<"VAALE1", 1, 0b000, 0b1000, 0b0111, 0b111>; // Armv8.4-A Translation Lookaside Buffer Instructions (TLBI) let Requires = ["AArch64::FeatureTLB_RMI"] in { // Armv8.4-A Outer Sharable TLB Maintenance instructions: -// hasTLBIP op1 CRn CRm op2 needsreg -defm : TLBI<"VMALLE1OS", 0, 0b000, 0b1000, 0b0001, 0b000, 0>; +// hasTLBIP op1 CRn CRm op2 needsreg, optreg +defm : TLBI<"VMALLE1OS", 0, 0b000, 0b1000, 0b0001, 0b000, 0, 1>; defm : TLBI<"VAE1OS", 1, 0b000, 0b1000, 0b0001, 0b001>; defm : TLBI<"ASIDE1OS", 0, 0b000, 0b1000, 0b0001, 0b010>; defm : TLBI<"VAAE1OS", 1, 0b000, 0b1000, 0b0001, 0b011>; @@ -919,15 +941,15 @@ defm : TLBI<"IPAS2E1OS", 1, 0b100, 0b1000, 0b0100, 0b000>; defm : TLBI<"IPAS2LE1OS", 1, 0b100, 0b1000, 0b0100, 0b100>; defm : TLBI<"VAE2OS", 1, 0b100, 0b1000, 0b0001, 0b001>; defm : TLBI<"VALE2OS", 1, 0b100, 0b1000, 0b0001, 0b101>; -defm : TLBI<"VMALLS12E1OS", 0, 0b100, 0b1000, 0b0001, 0b110, 0>; +defm : TLBI<"VMALLS12E1OS", 0, 0b100, 0b1000, 0b0001, 0b110, 0, 1>; defm : TLBI<"VAE3OS", 1, 0b110, 0b1000, 0b0001, 0b001>; defm : TLBI<"VALE3OS", 1, 0b110, 0b1000, 0b0001, 0b101>; -defm : TLBI<"ALLE2OS", 0, 0b100, 0b1000, 0b0001, 0b000, 0>; -defm : TLBI<"ALLE1OS", 0, 0b100, 0b1000, 0b0001, 0b100, 0>; -defm : 
TLBI<"ALLE3OS", 0, 0b110, 0b1000, 0b0001, 0b000, 0>; +defm : TLBI<"ALLE2OS", 0, 0b100, 0b1000, 0b0001, 0b000, 0, 1>; +defm : TLBI<"ALLE1OS", 0, 0b100, 0b1000, 0b0001, 0b100, 0, 1>; +defm : TLBI<"ALLE3OS", 0, 0b110, 0b1000, 0b0001, 0b000, 0, 1>; // Armv8.4-A TLB Range Maintenance instructions: -// hasTLBIP op1 CRn CRm op2 needsreg +// hasTLBIP op1 CRn CRm op2 defm : TLBI<"RVAE1", 1, 0b000, 0b1000, 0b0110, 0b001>; defm : TLBI<"RVAAE1", 1, 0b000, 0b1000, 0b0110, 0b011>; defm : TLBI<"RVALE1", 1, 0b000, 0b1000, 0b0110, 0b101>; @@ -962,18 +984,19 @@ defm : TLBI<"RVALE3OS", 1, 0b110, 0b1000, 0b0101, 0b101>; // Armv9-A Realm Management Extension TLBI Instructions let Requires = ["AArch64::FeatureRME"] in { +// hasTLBIP op1 CRn CRm op2 needsreg defm : TLBI<"RPAOS", 0, 0b110, 0b1000, 0b0100, 0b011>; defm : TLBI<"RPALOS", 0, 0b110, 0b1000, 0b0100, 0b111>; -defm : TLBI<"PAALLOS", 0, 0b110, 0b1000, 0b0001, 0b100, 0>; -defm : TLBI<"PAALL", 0, 0b110, 0b1000, 0b0111, 0b100, 0>; +defm : TLBI<"PAALLOS", 0, 0b110, 0b1000, 0b0001, 0b100, 0, 0>; +defm : TLBI<"PAALL", 0, 0b110, 0b1000, 0b0111, 0b100, 0, 0>; } // Armv9.5-A TLBI VMALL for Dirty State let Requires = ["AArch64::FeatureTLBIW"] in { -// op1, CRn, CRm, op2, needsreg -defm : TLBI<"VMALLWS2E1", 0, 0b100, 0b1000, 0b0110, 0b010, 0>; -defm : TLBI<"VMALLWS2E1IS", 0, 0b100, 0b1000, 0b0010, 0b010, 0>; -defm : TLBI<"VMALLWS2E1OS", 0, 0b100, 0b1000, 0b0101, 0b010, 0>; +// hasTLBIP op1 CRn CRm op2 needsreg, optreg +defm : TLBI<"VMALLWS2E1", 0, 0b100, 0b1000, 0b0110, 0b010, 0, 0>; +defm : TLBI<"VMALLWS2E1IS", 0, 0b100, 0b1000, 0b0010, 0b010, 0, 1>; +defm : TLBI<"VMALLWS2E1OS", 0, 0b100, 0b1000, 0b0101, 0b010, 0, 1>; } //===----------------------------------------------------------------------===// @@ -1862,13 +1885,6 @@ def : ROSysReg<"ERXPFGF_EL1", 0b11, 0b000, 0b0101, 0b0100, 0b100>; // v8.4a MPAM registers // Op0 Op1 CRn CRm Op2 -let Requires = [{ {AArch64::FeatureMPAM} }] in { -def : RWSysReg<"MPAM0_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b001>; -def : RWSysReg<"MPAM1_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b000>; -def : RWSysReg<"MPAM2_EL2", 0b11, 0b100, 0b1010, 0b0101, 0b000>; -def : RWSysReg<"MPAM3_EL3", 0b11, 0b110, 0b1010, 0b0101, 0b000>; -def : RWSysReg<"MPAM1_EL12", 0b11, 0b101, 0b1010, 0b0101, 0b000>; -def : RWSysReg<"MPAMHCR_EL2", 0b11, 0b100, 0b1010, 0b0100, 0b000>; def : RWSysReg<"MPAMVPMV_EL2", 0b11, 0b100, 0b1010, 0b0100, 0b001>; def : RWSysReg<"MPAMVPM0_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b000>; def : RWSysReg<"MPAMVPM1_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b001>; @@ -1878,8 +1894,6 @@ def : RWSysReg<"MPAMVPM4_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b100>; def : RWSysReg<"MPAMVPM5_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b101>; def : RWSysReg<"MPAMVPM6_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b110>; def : RWSysReg<"MPAMVPM7_EL2", 0b11, 0b100, 0b1010, 0b0110, 0b111>; -def : ROSysReg<"MPAMIDR_EL1", 0b11, 0b000, 0b1010, 0b0100, 0b100>; -} //FeatureMPAM // v8.4a Activity Monitor registers // Op0 Op1 CRn CRm Op2 @@ -2319,6 +2333,26 @@ def : RWSysReg<"MPAMBW0_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b101>; def : RWSysReg<"MPAMBWCAP_EL2", 0b11, 0b100, 0b1010, 0b0101, 0b110>; def : RWSysReg<"MPAMBWSM_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b111>; +// v9.7a Memory partitioning and monitoring version 2 +// (FEAT_MPAMv2) registers +// Op0 Op1 CRn CRm Op2 +// MPAM system registers that are also available for MPAMv2 +def : RWSysReg<"MPAM0_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b001>; +def : RWSysReg<"MPAM1_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b000>; +def : RWSysReg<"MPAM1_EL12", 0b11, 0b101, 
0b1010, 0b0101, 0b000>; +def : RWSysReg<"MPAM2_EL2", 0b11, 0b100, 0b1010, 0b0101, 0b000>; +def : RWSysReg<"MPAM3_EL3", 0b11, 0b110, 0b1010, 0b0101, 0b000>; +def : RWSysReg<"MPAMHCR_EL2", 0b11, 0b100, 0b1010, 0b0100, 0b000>; +def : ROSysReg<"MPAMIDR_EL1", 0b11, 0b000, 0b1010, 0b0100, 0b100>; +// Only MPAMv2 registers +def : RWSysReg<"MPAMCTL_EL1", 0b11, 0b000, 0b1010, 0b0101, 0b010>; +def : RWSysReg<"MPAMCTL_EL12", 0b11, 0b101, 0b1010, 0b0101, 0b010>; +def : RWSysReg<"MPAMCTL_EL2", 0b11, 0b100, 0b1010, 0b0101, 0b010>; +def : RWSysReg<"MPAMCTL_EL3", 0b11, 0b110, 0b1010, 0b0101, 0b010>; +def : RWSysReg<"MPAMVIDCR_EL2", 0b11, 0b100, 0b1010, 0b0111, 0b000>; +def : RWSysReg<"MPAMVIDSR_EL2", 0b11, 0b100, 0b1010, 0b0111, 0b001>; +def : RWSysReg<"MPAMVIDSR_EL3", 0b11, 0b110, 0b1010, 0b0111, 0b001>; + //===----------------------------------------------------------------------===// // FEAT_SRMASK v9.6a registers //===----------------------------------------------------------------------===// @@ -2412,3 +2446,251 @@ def : DC<"CIVAPS", 0b000, 0b0111, 0b1111, 0b001>; let Requires = [{ {AArch64::FeaturePoPS, AArch64::FeatureMTE} }] in { def : DC<"CIGDVAPS", 0b000, 0b0111, 0b1111, 0b101>; } + +// v9.7a TLBI domains system registers (MemSys) +foreach n = 0-3 in { + defvar nb = !cast<bits<3>>(n); + def : RWSysReg<"VTLBID"#n#"_EL2", 0b11, 0b100, 0b0010, 0b1000, nb>; +} + +foreach n = 0-3 in { + defvar nb = !cast<bits<3>>(n); + def : RWSysReg<"VTLBIDOS"#n#"_EL2", 0b11, 0b100, 0b0010, 0b1001, nb>; +} + +def : ROSysReg<"TLBIDIDR_EL1", 0b11, 0b000, 0b1010, 0b0100, 0b110>; + +// MPAM Lookaside Buffer Invalidate (MLBI) instructions +class MLBI<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, bit needsreg> { + string Name = name; + bits<14> Encoding; + let Encoding{13-11} = op1; + let Encoding{10-7} = crn; + let Encoding{6-3} = crm; + let Encoding{2-0} = op2; + bit NeedsReg = needsreg; + string RequiresStr = [{ {AArch64::FeatureMPAMv2} }]; +} + +def MLBITable : GenericTable { + let FilterClass = "MLBI"; + let CppTypeName = "MLBI"; + let Fields = ["Name", "Encoding", "NeedsReg", "RequiresStr"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupMLBIByEncoding"; +} + +def lookupMLBIByName : SearchIndex { + let Table = MLBITable; + let Key = ["Name"]; +} + +// Op1 CRn CRm Op2 needsReg +def : MLBI<"ALLE1", 0b100, 0b0111, 0b0000, 0b100, 0>; +def : MLBI<"VMALLE1", 0b100, 0b0111, 0b0000, 0b101, 0>; +def : MLBI<"VPIDE1", 0b100, 0b0111, 0b0000, 0b110, 1>; +def : MLBI<"VPMGE1", 0b100, 0b0111, 0b0000, 0b111, 1>; + + +// v9.7-A GICv5 (FEAT_GCIE) +// CPU Interface Registers +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"ICC_APR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b000>; +def : RWSysReg<"ICC_APR_EL3", 0b11, 0b110, 0b1100, 0b1000, 0b000>; +def : RWSysReg<"ICC_CR0_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b001>; +def : RWSysReg<"ICC_CR0_EL3", 0b11, 0b110, 0b1100, 0b1001, 0b000>; +def : ROSysReg<"ICC_DOMHPPIR_EL3", 0b11, 0b110, 0b1100, 0b1000, 0b010>; +def : ROSysReg<"ICC_HAPR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b011>; +def : ROSysReg<"ICC_HPPIR_EL1", 0b11, 0b000, 0b1100, 0b1010, 0b011>; +def : ROSysReg<"ICC_HPPIR_EL3", 0b11, 0b110, 0b1100, 0b1001, 0b001>; +def : ROSysReg<"ICC_IAFFIDR_EL1", 0b11, 0b000, 0b1100, 0b1010, 0b101>; +def : RWSysReg<"ICC_ICSR_EL1", 0b11, 0b000, 0b1100, 0b1010, 0b100>; +def : ROSysReg<"ICC_IDR0_EL1", 0b11, 0b000, 0b1100, 0b1010, 0b010>; +def : RWSysReg<"ICC_PCR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b010>; +def : RWSysReg<"ICC_PCR_EL3", 0b11, 0b110, 0b1100, 0b1000, 0b001>; + +// Virtual CPU 
Interface Registers +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"ICV_APR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b000>; +def : RWSysReg<"ICV_CR0_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b001>; +def : RWSysReg<"ICV_HAPR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b011>; +def : RWSysReg<"ICV_HPPIR_EL1", 0b11, 0b000, 0b1100, 0b1010, 0b011>; +def : RWSysReg<"ICV_PCR_EL1", 0b11, 0b001, 0b1100, 0b0000, 0b010>; + +foreach n=0-3 in { + defvar nb = !cast<bits<2>>(n); +// Op0 Op1 CRn CRm Op2 + def : RWSysReg<"ICC_PPI_DOMAINR"#n#"_EL3", 0b11, 0b110, 0b1100, 0b1000, {0b1,nb{1-0}}>; + +} + +foreach n=0-15 in{ + defvar nb = !cast<bits<4>>(n); +// Op0 Op1 CRn CRm Op2 + def : RWSysReg<"ICC_PPI_PRIORITYR"#n#"_EL1", 0b11, 0b000, 0b1100, {0b111,nb{3}}, nb{2-0}>; +} + +// PPI and Virtual PPI Registers +multiclass PPIRegisters<string prefix> { + foreach n=0-1 in { + defvar nb = !cast<bit>(n); +// Op0 Op1 CRn CRm Op2 + def : RWSysReg<prefix#"_PPI_CACTIVER"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1101, {0b00,nb}>; + def : RWSysReg<prefix#"_PPI_CPENDR"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1101, {0b10,nb}>; + def : RWSysReg<prefix#"_PPI_ENABLER"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1010, {0b11,nb}>; + def : RWSysReg<prefix#"_PPI_SACTIVER"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1101, {0b01,nb}>; + def : RWSysReg<prefix#"_PPI_SPENDR"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1101, {0b11,nb}>; + def : RWSysReg<prefix#"_PPI_HMR"#n#"_EL1", 0b11, 0b000, 0b1100, 0b1010, {0b00,nb}>; + } +} + +defm : PPIRegisters<"ICC">; // PPI Registers +defm : PPIRegisters<"ICV">; // Virtual PPI Registers + +foreach n=0-15 in { + defvar nb = !cast<bits<4>>(n); +// Op0 Op1 CRn CRm Op2 + def : RWSysReg<"ICV_PPI_PRIORITYR"#n#"_EL1", 0b11, 0b000, 0b1100, {0b111,nb{3}}, nb{2-0}>; +} + +// Hypervisor Control Registers +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"ICH_APR_EL2", 0b11, 0b100, 0b1100, 0b1000, 0b100>; +def : RWSysReg<"ICH_CONTEXTR_EL2", 0b11, 0b100, 0b1100, 0b1011, 0b110>; +def : RWSysReg<"ICH_HFGITR_EL2", 0b11, 0b100, 0b1100, 0b1001, 0b111>; +def : RWSysReg<"ICH_HFGRTR_EL2", 0b11, 0b100, 0b1100, 0b1001, 0b100>; +def : RWSysReg<"ICH_HFGWTR_EL2", 0b11, 0b100, 0b1100, 0b1001, 0b110>; +def : ROSysReg<"ICH_HPPIR_EL2", 0b11, 0b100, 0b1100, 0b1000, 0b101>; +def : RWSysReg<"ICH_VCTLR_EL2", 0b11, 0b100, 0b1100, 0b1011, 0b100>; + +foreach n=0-1 in { + defvar nb = !cast<bit>(n); +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"ICH_PPI_ACTIVER"#n#"_EL2", 0b11, 0b100, 0b1100, 0b1010, {0b11,nb}>; +def : RWSysReg<"ICH_PPI_DVIR"#n#"_EL2", 0b11, 0b100, 0b1100, 0b1010, {0b00,nb}>; +def : RWSysReg<"ICH_PPI_ENABLER"#n#"_EL2", 0b11, 0b100, 0b1100, 0b1010, {0b01,nb}>; +def : RWSysReg<"ICH_PPI_PENDR"#n#"_EL2", 0b11, 0b100, 0b1100, 0b1010, {0b10,nb}>; +} + +foreach n=0-15 in { + defvar nb = !cast<bits<4>>(n); +// Op0 Op1 CRn CRm Op2 + def : RWSysReg<"ICH_PPI_PRIORITYR"#n#"_EL2", 0b11, 0b100, 0b1100, {0b111,nb{3}}, nb{2-0}>; +} + +//===----------------------------------------------------------------------===// +// GICv5 instruction options. 
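//===----------------------------------------------------------------------===//

// Editorial worked example, not part of the patch: in the ICH_PPI_PRIORITYR
// expansion just above, each index n contributes nb{3} as the low CRm bit and
// nb{2-0} as Op2. For n = 9, nb = 0b1001, so CRm = {0b111,1} = 0b1111 and
// Op2 = 0b001, and the foreach emits the equivalent of
//   def : RWSysReg<"ICH_PPI_PRIORITYR9_EL2", 0b11, 0b100, 0b1100, 0b1111, 0b001>;
// i.e. the register is addressable as S3_4_C12_C15_1.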
+//===----------------------------------------------------------------------===// + +// GIC +class GIC<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> { + string Name = name; + bits<14> Encoding; + let Encoding{13-11} = op1; + let Encoding{10-7} = crn; + let Encoding{6-3} = crm; + let Encoding{2-0} = op2; + bit NeedsReg = 1; + string RequiresStr = [{ {AArch64::FeatureGCIE} }]; +} + +// GSB +class GSB<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> { + string Name = name; + bits<14> Encoding; + let Encoding{13-11} = op1; + let Encoding{10-7} = crn; + let Encoding{6-3} = crm; + let Encoding{2-0} = op2; + string RequiresStr = [{ {AArch64::FeatureGCIE} }]; +} + +// GICR +class GICR<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> { + string Name = name; + bits<14> Encoding; + let Encoding{13-11} = op1; + let Encoding{10-7} = crn; + let Encoding{6-3} = crm; + let Encoding{2-0} = op2; + bit NeedsReg = 1; + string RequiresStr = [{ {AArch64::FeatureGCIE} }]; +} + +def GICTable : GenericTable { + let FilterClass = "GIC"; + let CppTypeName = "GIC"; + let Fields = ["Name", "Encoding", "NeedsReg", "RequiresStr"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupGICByEncoding"; +} + +def GSBTable : GenericTable { + let FilterClass = "GSB"; + let CppTypeName = "GSB"; + let Fields = ["Name", "Encoding", "RequiresStr"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupGSBByEncoding"; +} + +def GICRTable : GenericTable { + let FilterClass = "GICR"; + let CppTypeName = "GICR"; + let Fields = ["Name", "Encoding", "NeedsReg", "RequiresStr"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupGICRByEncoding"; +} + +def lookupGICByName : SearchIndex { + let Table = GICTable; + let Key = ["Name"]; +} + +def lookupGSBByName : SearchIndex { + let Table = GSBTable; + let Key = ["Name"]; +} + +def lookupGICRByName : SearchIndex { + let Table = GICRTable; + let Key = ["Name"]; +} + +// Op1 CRn CRm Op2 +def : GSB<"sys", 0b000, 0b1100, 0b0000, 0b000>; +def : GSB<"ack", 0b000, 0b1100, 0b0000, 0b001>; + +// Op1 CRn CRm Op2 +def : GICR<"cdia", 0b000, 0b1100, 0b0011, 0b000>; +def : GICR<"cdnmia", 0b000, 0b1100, 0b0011, 0b001>; + +// Op1 CRn CRm Op2 +def : GIC<"cdaff", 0b000, 0b1100, 0b0001, 0b011>; +def : GIC<"cddi", 0b000, 0b1100, 0b0010, 0b000>; +def : GIC<"cddis", 0b000, 0b1100, 0b0001, 0b000>; +def : GIC<"cden", 0b000, 0b1100, 0b0001, 0b001>; +def : GIC<"cdeoi", 0b000, 0b1100, 0b0001, 0b111>; +def : GIC<"cdhm", 0b000, 0b1100, 0b0010, 0b001>; +def : GIC<"cdpend", 0b000, 0b1100, 0b0001, 0b100>; +def : GIC<"cdpri", 0b000, 0b1100, 0b0001, 0b010>; +def : GIC<"cdrcfg", 0b000, 0b1100, 0b0001, 0b101>; +def : GIC<"vdaff", 0b100, 0b1100, 0b0001, 0b011>; +def : GIC<"vddi", 0b100, 0b1100, 0b0010, 0b000>; +def : GIC<"vddis", 0b100, 0b1100, 0b0001, 0b000>; +def : GIC<"vden", 0b100, 0b1100, 0b0001, 0b001>; +def : GIC<"vdhm", 0b100, 0b1100, 0b0010, 0b001>; +def : GIC<"vdpend", 0b100, 0b1100, 0b0001, 0b100>; +def : GIC<"vdpri", 0b100, 0b1100, 0b0001, 0b010>; +def : GIC<"vdrcfg", 0b100, 0b1100, 0b0001, 0b101>; +def : GIC<"ldaff", 0b110, 0b1100, 0b0001, 0b011>; +def : GIC<"lddi", 0b110, 0b1100, 0b0010, 0b000>; +def : GIC<"lddis", 0b110, 0b1100, 0b0001, 0b000>; +def : GIC<"lden", 0b110, 0b1100, 0b0001, 0b001>; +def : GIC<"ldhm", 0b110, 0b1100, 0b0010, 0b001>; +def : GIC<"ldpend", 0b110, 0b1100, 0b0001, 0b100>; +def : GIC<"ldpri", 0b110, 0b1100, 0b0001, 0b010>; +def : GIC<"ldrcfg", 0b110, 0b1100, 0b0001, 0b101>; diff --git 
a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index e3370d3..2053fc4 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1577,18 +1577,26 @@ static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) { } static bool isAllActivePredicate(Value *Pred) { - // Look through convert.from.svbool(convert.to.svbool(...) chain. Value *UncastedPred; + + // Look through predicate casts that only remove lanes. if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>( - m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>( - m_Value(UncastedPred))))) - // If the predicate has the same or less lanes than the uncasted - // predicate then we know the casting has no effect. - if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <= - cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements()) - Pred = UncastedPred; + m_Value(UncastedPred)))) { + auto *OrigPredTy = cast<ScalableVectorType>(Pred->getType()); + Pred = UncastedPred; + + if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>( + m_Value(UncastedPred)))) + // If the predicate has the same number of lanes as, or fewer lanes + // than, the uncasted predicate then we know the casting has no effect. + if (OrigPredTy->getMinNumElements() <= + cast<ScalableVectorType>(UncastedPred->getType()) + ->getMinNumElements()) + Pred = UncastedPred; + } + auto *C = dyn_cast<Constant>(Pred); - return (C && C->isAllOnesValue()); + return C && C->isAllOnesValue(); } // Simplify `V` by only considering the operations that affect active lanes. diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp index 636d4f8a..6273cfc 100644 --- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -159,6 +159,7 @@ private: SMLoc getLoc() const { return getParser().getTok().getLoc(); } bool parseSysAlias(StringRef Name, SMLoc NameLoc, OperandVector &Operands); + bool parseSyslAlias(StringRef Name, SMLoc NameLoc, OperandVector &Operands); bool parseSyspAlias(StringRef Name, SMLoc NameLoc, OperandVector &Operands); void createSysAlias(uint16_t Encoding, OperandVector &Operands, SMLoc S); AArch64CC::CondCode parseCondCodeString(StringRef Cond, @@ -266,6 +267,7 @@ private: ParseStatus tryParseRPRFMOperand(OperandVector &Operands); ParseStatus tryParsePSBHint(OperandVector &Operands); ParseStatus tryParseBTIHint(OperandVector &Operands); + ParseStatus tryParseCMHPriorityHint(OperandVector &Operands); ParseStatus tryParseAdrpLabel(OperandVector &Operands); ParseStatus tryParseAdrLabel(OperandVector &Operands); template <bool AddFPZeroAsLiteral> @@ -370,6 +372,7 @@ private: k_PSBHint, k_PHint, k_BTIHint, + k_CMHPriorityHint, } Kind; SMLoc StartLoc, EndLoc; @@ -499,6 +502,11 @@ private: unsigned Length; unsigned Val; }; + struct CMHPriorityHintOp { + const char *Data; + unsigned Length; + unsigned Val; + }; struct SVCROp { const char *Data; @@ -525,6 +533,7 @@ private: struct PSBHintOp PSBHint; struct PHintOp PHint; struct BTIHintOp BTIHint; + struct CMHPriorityHintOp CMHPriorityHint; struct ShiftExtendOp ShiftExtend; struct SVCROp SVCR; }; @@ -595,6 +604,9 @@ public: case k_BTIHint: BTIHint = o.BTIHint; break; + case k_CMHPriorityHint: + CMHPriorityHint = o.CMHPriorityHint; + break; case k_ShiftExtend: ShiftExtend = o.ShiftExtend; break; @@ -769,6 +781,16 @@ 
public: return StringRef(BTIHint.Data, BTIHint.Length); } + unsigned getCMHPriorityHint() const { + assert(Kind == k_CMHPriorityHint && "Invalid access!"); + return CMHPriorityHint.Val; + } + + StringRef getCMHPriorityHintName() const { + assert(Kind == k_CMHPriorityHint && "Invalid access!"); + return StringRef(CMHPriorityHint.Data, CMHPriorityHint.Length); + } + StringRef getSVCR() const { assert(Kind == k_SVCR && "Invalid access!"); return StringRef(SVCR.Data, SVCR.Length); @@ -1511,6 +1533,7 @@ public: bool isPSBHint() const { return Kind == k_PSBHint; } bool isPHint() const { return Kind == k_PHint; } bool isBTIHint() const { return Kind == k_BTIHint; } + bool isCMHPriorityHint() const { return Kind == k_CMHPriorityHint; } bool isShiftExtend() const { return Kind == k_ShiftExtend; } bool isShifter() const { if (!isShiftExtend()) @@ -2196,6 +2219,11 @@ public: Inst.addOperand(MCOperand::createImm(getBTIHint())); } + void addCMHPriorityHintOperands(MCInst &Inst, unsigned N) const { + assert(N == 1 && "Invalid number of operands!"); + Inst.addOperand(MCOperand::createImm(getCMHPriorityHint())); + } + void addShifterOperands(MCInst &Inst, unsigned N) const { assert(N == 1 && "Invalid number of operands!"); unsigned Imm = @@ -2547,6 +2575,17 @@ public: } static std::unique_ptr<AArch64Operand> + CreateCMHPriorityHint(unsigned Val, StringRef Str, SMLoc S, MCContext &Ctx) { + auto Op = std::make_unique<AArch64Operand>(k_CMHPriorityHint, Ctx); + Op->CMHPriorityHint.Val = Val; + Op->CMHPriorityHint.Data = Str.data(); + Op->CMHPriorityHint.Length = Str.size(); + Op->StartLoc = S; + Op->EndLoc = S; + return Op; + } + + static std::unique_ptr<AArch64Operand> CreateMatrixRegister(unsigned RegNum, unsigned ElementWidth, MatrixKind Kind, SMLoc S, SMLoc E, MCContext &Ctx) { auto Op = std::make_unique<AArch64Operand>(k_MatrixRegister, Ctx); @@ -2656,6 +2695,9 @@ void AArch64Operand::print(raw_ostream &OS, const MCAsmInfo &MAI) const { case k_BTIHint: OS << getBTIHintName(); break; + case k_CMHPriorityHint: + OS << getCMHPriorityHintName(); + break; case k_MatrixRegister: OS << "<matrix " << getMatrixReg() << ">"; break; @@ -3279,6 +3321,24 @@ ParseStatus AArch64AsmParser::tryParseBTIHint(OperandVector &Operands) { return ParseStatus::Success; } +/// tryParseCMHPriorityHint - Try to parse a CMH priority hint operand +ParseStatus AArch64AsmParser::tryParseCMHPriorityHint(OperandVector &Operands) { + SMLoc S = getLoc(); + const AsmToken &Tok = getTok(); + if (Tok.isNot(AsmToken::Identifier)) + return TokError("invalid operand for instruction"); + + auto CMHPriority = + AArch64CMHPriorityHint::lookupCMHPriorityHintByName(Tok.getString()); + if (!CMHPriority) + return TokError("invalid operand for instruction"); + + Operands.push_back(AArch64Operand::CreateCMHPriorityHint( + CMHPriority->Encoding, Tok.getString(), S, getContext())); + Lex(); // Eat identifier token. + return ParseStatus::Success; +} + 
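// Editorial walk-through, not part of the patch - how a CMH priority hint
// operand flows through the parser, assuming the single table entry
// CMHPriorityHint<"ph", 0b1> from AArch64SystemOperands.td:
//   token "ph"
//     -> lookupCMHPriorityHintByName("ph") returns an entry with Encoding 0b1
//     -> CreateCMHPriorityHint(1, "ph", ...) is pushed onto Operands
//     -> addCMHPriorityHintOperands() later emits MCOperand::createImm(1),
//        which printCMHPriorityHintOp maps back to the name "ph".

/// tryParseAdrpLabel - Parse and validate a source label for the ADRP /// instruction. 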
ParseStatus AArch64AsmParser::tryParseAdrpLabel(OperandVector &Operands) { @@ -3824,6 +3884,18 @@ static const struct Extension { {"ssve-bitperm", {AArch64::FeatureSSVE_BitPerm}}, {"sme-mop4", {AArch64::FeatureSME_MOP4}}, {"sme-tmop", {AArch64::FeatureSME_TMOP}}, + {"cmh", {AArch64::FeatureCMH}}, + {"lscp", {AArch64::FeatureLSCP}}, + {"tlbid", {AArch64::FeatureTLBID}}, + {"mpamv2", {AArch64::FeatureMPAMv2}}, + {"mtetc", {AArch64::FeatureMTETC}}, + {"gcie", {AArch64::FeatureGCIE}}, + {"sme2p3", {AArch64::FeatureSME2p3}}, + {"sve2p3", {AArch64::FeatureSVE2p3}}, + {"sve-b16mm", {AArch64::FeatureSVE_B16MM}}, + {"f16mm", {AArch64::FeatureF16MM}}, + {"f16f32dot", {AArch64::FeatureF16F32DOT}}, + {"f16f32mm", {AArch64::FeatureF16F32MM}}, }; static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) { @@ -3861,6 +3933,8 @@ static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) { Str += "ARMv9.5a"; else if (FBS[AArch64::HasV9_6aOps]) Str += "ARMv9.6a"; + else if (FBS[AArch64::HasV9_7aOps]) + Str += "ARMv9.7a"; else if (FBS[AArch64::HasV8_0rOps]) Str += "ARMv8r"; else { @@ -3894,8 +3968,9 @@ void AArch64AsmParser::createSysAlias(uint16_t Encoding, OperandVector &Operands AArch64Operand::CreateImm(Expr, S, getLoc(), getContext())); } -/// parseSysAlias - The IC, DC, AT, and TLBI instructions are simple aliases for -/// the SYS instruction. Parse them specially so that we create a SYS MCInst. +/// parseSysAlias - The IC, DC, AT, TLBI, MLBI, GIC and GSB instructions +/// are simple aliases for the SYS instruction. Parse them specially so that +/// we create a SYS MCInst. bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, OperandVector &Operands) { if (Name.contains('.')) @@ -3908,6 +3983,8 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, StringRef Op = Tok.getString(); SMLoc S = Tok.getLoc(); bool ExpectRegister = true; + bool OptionalRegister = false; + bool hasAll = getSTI().hasFeature(AArch64::FeatureAll); if (Mnemonic == "ic") { const AArch64IC::IC *IC = AArch64IC::lookupICByName(Op); @@ -3950,13 +4027,50 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, return TokError(Str); } ExpectRegister = TLBI->NeedsReg; + bool hasTLBID = getSTI().hasFeature(AArch64::FeatureTLBID); + if (hasAll || hasTLBID) + OptionalRegister = TLBI->OptionalReg; createSysAlias(TLBI->Encoding, Operands, S); - } else if (Mnemonic == "cfp" || Mnemonic == "dvp" || Mnemonic == "cpp" || Mnemonic == "cosp") { + } else if (Mnemonic == "mlbi") { + const AArch64MLBI::MLBI *MLBI = AArch64MLBI::lookupMLBIByName(Op); + if (!MLBI) + return TokError("invalid operand for MLBI instruction"); + else if (!MLBI->haveFeatures(getSTI().getFeatureBits())) { + std::string Str("MLBI " + std::string(MLBI->Name) + " requires: "); + setRequiredFeatureString(MLBI->getRequiredFeatures(), Str); + return TokError(Str); + } + ExpectRegister = MLBI->NeedsReg; + createSysAlias(MLBI->Encoding, Operands, S); + } else if (Mnemonic == "gic") { + const AArch64GIC::GIC *GIC = AArch64GIC::lookupGICByName(Op); + if (!GIC) + return TokError("invalid operand for GIC instruction"); + else if (!GIC->haveFeatures(getSTI().getFeatureBits())) { + std::string Str("GIC " + std::string(GIC->Name) + " requires: "); + setRequiredFeatureString(GIC->getRequiredFeatures(), Str); + return TokError(Str); + } + ExpectRegister = true; + createSysAlias(GIC->Encoding, Operands, S); + } else if (Mnemonic == "gsb") { + const AArch64GSB::GSB *GSB = AArch64GSB::lookupGSBByName(Op); + if 
(!GSB) + return TokError("invalid operand for GSB instruction"); + else if (!GSB->haveFeatures(getSTI().getFeatureBits())) { + std::string Str("GSB " + std::string(GSB->Name) + " requires: "); + setRequiredFeatureString(GSB->getRequiredFeatures(), Str); + return TokError(Str); + } + ExpectRegister = false; + createSysAlias(GSB->Encoding, Operands, S); + } else if (Mnemonic == "cfp" || Mnemonic == "dvp" || Mnemonic == "cpp" || + Mnemonic == "cosp") { if (Op.lower() != "rctx") return TokError("invalid operand for prediction restriction instruction"); - bool hasAll = getSTI().hasFeature(AArch64::FeatureAll); bool hasPredres = hasAll || getSTI().hasFeature(AArch64::FeaturePredRes); bool hasSpecres2 = hasAll || getSTI().hasFeature(AArch64::FeatureSPECRES2); @@ -3989,10 +4103,61 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, HasRegister = true; } - if (ExpectRegister && !HasRegister) - return TokError("specified " + Mnemonic + " op requires a register"); - else if (!ExpectRegister && HasRegister) - return TokError("specified " + Mnemonic + " op does not use a register"); + if (!OptionalRegister) { + if (ExpectRegister && !HasRegister) + return TokError("specified " + Mnemonic + " op requires a register"); + else if (!ExpectRegister && HasRegister) + return TokError("specified " + Mnemonic + " op does not use a register"); + } + + if (parseToken(AsmToken::EndOfStatement, "unexpected token in argument list")) + return true; + + return false; +} + +/// parseSyslAlias - The GICR instructions are simple aliases for +/// the SYSL instruction. Parse them specially so that we create a +/// SYSL MCInst. +bool AArch64AsmParser::parseSyslAlias(StringRef Name, SMLoc NameLoc, + OperandVector &Operands) { + Mnemonic = Name; + Operands.push_back( + AArch64Operand::CreateToken("sysl", NameLoc, getContext())); + + // Now expect two operands: a register, then an identifier. + SMLoc StartLoc = getLoc(); + const AsmToken &RegTok = getTok(); + StringRef Reg = RegTok.getString(); + unsigned RegNum = matchRegisterNameAlias(Reg.lower(), RegKind::Scalar); + if (!RegNum) + return TokError("expected register operand"); + + Operands.push_back(AArch64Operand::CreateReg( + RegNum, RegKind::Scalar, StartLoc, getLoc(), getContext(), EqualsReg)); + + Lex(); // Eat register token. + if (parseToken(AsmToken::Comma)) + return true; + + // Check for identifier + const AsmToken &OperandTok = getTok(); + StringRef Op = OperandTok.getString(); + SMLoc S2 = OperandTok.getLoc(); + Lex(); // Eat identifier token. + + if (Mnemonic == "gicr") { + const AArch64GICR::GICR *GICR = AArch64GICR::lookupGICRByName(Op); + if (!GICR) + return Error(S2, "invalid operand for GICR instruction"); + else if (!GICR->haveFeatures(getSTI().getFeatureBits())) { + std::string Str("GICR " + std::string(GICR->Name) + " requires: "); + setRequiredFeatureString(GICR->getRequiredFeatures(), Str); + return Error(S2, Str); + } + createSysAlias(GICR->Encoding, Operands, S2); + } if (parseToken(AsmToken::EndOfStatement, "unexpected token in argument list")) return true; @@ -4025,7 +4190,7 @@ bool AArch64AsmParser::parseSyspAlias(StringRef Name, SMLoc NameLoc, return TokError("invalid operand for TLBIP instruction"); const AArch64TLBIP::TLBIP TLBIP( TLBIPorig->Name, TLBIPorig->Encoding | (HasnXSQualifier ? (1 << 7) : 0), - TLBIPorig->NeedsReg, + TLBIPorig->NeedsReg, TLBIPorig->OptionalReg, HasnXSQualifier ? 
TLBIPorig->FeaturesRequired | FeatureBitset({AArch64::FeatureXS}) : TLBIPorig->FeaturesRequired); @@ -4719,6 +4884,13 @@ ParseStatus AArch64AsmParser::tryParseVectorList(OperandVector &Operands, FirstReg, Count, Stride, NumElements, ElementWidth, VectorKind, S, getLoc(), getContext())); + if (getTok().is(AsmToken::LBrac)) { + ParseStatus Res = tryParseVectorIndex(Operands); + if (Res.isFailure()) + return ParseStatus::Failure; + return ParseStatus::Success; + } + return ParseStatus::Success; } @@ -5267,12 +5439,17 @@ bool AArch64AsmParser::parseInstruction(ParseInstructionInfo &Info, size_t Start = 0, Next = Name.find('.'); StringRef Head = Name.slice(Start, Next); - // IC, DC, AT, TLBI and Prediction invalidation instructions are aliases for - // the SYS instruction. + // IC, DC, AT, TLBI, MLBI, GIC, GSB and Prediction invalidation + // instructions are aliases for the SYS instruction. if (Head == "ic" || Head == "dc" || Head == "at" || Head == "tlbi" || - Head == "cfp" || Head == "dvp" || Head == "cpp" || Head == "cosp") + Head == "cfp" || Head == "dvp" || Head == "cpp" || Head == "cosp" || + Head == "mlbi" || Head == "gic" || Head == "gsb") return parseSysAlias(Head, NameLoc, Operands); + // GICR instructions are aliases for the SYSL instruction. + if (Head == "gicr") + return parseSyslAlias(Head, NameLoc, Operands); + // TLBIP instructions are aliases for the SYSP instruction. if (Head == "tlbip") return parseSyspAlias(Head, NameLoc, Operands); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp index 35bd244..5c3e26e 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp @@ -84,6 +84,12 @@ void AArch64InstPrinter::printInst(const MCInst *MI, uint64_t Address, return; } + if (Opcode == AArch64::SYSLxt) + if (printSyslAlias(MI, STI, O)) { + printAnnotation(O, Annot); + return; + } + if (Opcode == AArch64::SYSPxt || Opcode == AArch64::SYSPxt_XZR) if (printSyspAlias(MI, STI, O)) { printAnnotation(O, Annot); @@ -909,13 +915,25 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, Encoding |= CnVal << 7; Encoding |= Op1Val << 11; - bool NeedsReg; + bool NeedsReg = false; + bool OptionalReg = false; std::string Ins; std::string Name; if (CnVal == 7) { switch (CmVal) { default: return false; + // MLBI aliases + case 0: { + const AArch64MLBI::MLBI *MLBI = + AArch64MLBI::lookupMLBIByEncoding(Encoding); + if (!MLBI || !MLBI->haveFeatures(STI.getFeatureBits())) + return false; + + NeedsReg = MLBI->NeedsReg; + Ins = "mlbi\t"; + Name = std::string(MLBI->Name); + } break; // Maybe IC, maybe Prediction Restriction case 1: switch (Op1Val) { @@ -1004,19 +1022,41 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, return false; NeedsReg = TLBI->NeedsReg; + if (STI.hasFeature(AArch64::FeatureAll) || + STI.hasFeature(AArch64::FeatureTLBID)) + OptionalReg = TLBI->OptionalReg; Ins = "tlbi\t"; Name = std::string(TLBI->Name); - } - else + } else if (CnVal == 12) { + if (CmVal != 0) { + // GIC aliases + const AArch64GIC::GIC *GIC = AArch64GIC::lookupGICByEncoding(Encoding); + if (!GIC || !GIC->haveFeatures(STI.getFeatureBits())) + return false; + + NeedsReg = true; + Ins = "gic\t"; + Name = std::string(GIC->Name); + } else { + // GSB aliases + const AArch64GSB::GSB *GSB = AArch64GSB::lookupGSBByEncoding(Encoding); + if (!GSB || !GSB->haveFeatures(STI.getFeatureBits())) + return false; + + NeedsReg = false; + Ins = "gsb\t"; + Name 
= std::string(GSB->Name); + } + } else return false; StringRef Reg = getRegisterName(MI->getOperand(4).getReg()); bool NotXZR = Reg != "xzr"; - // If a mandatory is not specified in the TableGen + // If a mandatory or optional register is not specified in the TableGen // (i.e. no register operand should be present), and the register value // is not xzr/x31, then disassemble to a SYS alias instead. - if (NotXZR && !NeedsReg) + if (NotXZR && !NeedsReg && !OptionalReg) return false; std::string Str = Ins + Name; @@ -1024,12 +1064,64 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, O << '\t' << Str; - if (NeedsReg) + // For optional registers, don't print the value if it's xzr/x31 + // since this defaults to xzr/x31 if register is not specified. + if (NeedsReg || (OptionalReg && NotXZR)) O << ", " << Reg; return true; } +bool AArch64InstPrinter::printSyslAlias(const MCInst *MI, + const MCSubtargetInfo &STI, + raw_ostream &O) { +#ifndef NDEBUG + unsigned Opcode = MI->getOpcode(); + assert(Opcode == AArch64::SYSLxt && "Invalid opcode for SYSL alias!"); +#endif + + StringRef Reg = getRegisterName(MI->getOperand(0).getReg()); + const MCOperand &Op1 = MI->getOperand(1); + const MCOperand &Cn = MI->getOperand(2); + const MCOperand &Cm = MI->getOperand(3); + const MCOperand &Op2 = MI->getOperand(4); + + unsigned Op1Val = Op1.getImm(); + unsigned CnVal = Cn.getImm(); + unsigned CmVal = Cm.getImm(); + unsigned Op2Val = Op2.getImm(); + + uint16_t Encoding = Op2Val; + Encoding |= CmVal << 3; + Encoding |= CnVal << 7; + Encoding |= Op1Val << 11; + + std::string Ins; + std::string Name; + + if (CnVal == 12) { + if (CmVal == 3) { + // GICR aliases + const AArch64GICR::GICR *GICR = + AArch64GICR::lookupGICRByEncoding(Encoding); + if (!GICR || !GICR->haveFeatures(STI.getFeatureBits())) + return false; + + Ins = "gicr"; + Name = std::string(GICR->Name); + } else + return false; + } else + return false; + + llvm::transform(Name, Name.begin(), ::tolower); + + O << '\t' << Ins << '\t' << Reg.str() << ", " << Name; + + return true; +} + bool AArch64InstPrinter::printSyspAlias(const MCInst *MI, const MCSubtargetInfo &STI, raw_ostream &O) { @@ -1508,6 +1600,17 @@ void AArch64InstPrinter::printBTIHintOp(const MCInst *MI, unsigned OpNum, markup(O, Markup::Immediate) << '#' << formatImm(btihintop); } +void AArch64InstPrinter::printCMHPriorityHintOp(const MCInst *MI, + unsigned OpNum, + const MCSubtargetInfo &STI, + raw_ostream &O) { + unsigned priorityhint_op = MI->getOperand(OpNum).getImm(); + auto PHint = + AArch64CMHPriorityHint::lookupCMHPriorityHintByEncoding(priorityhint_op); + if (PHint) + O << PHint->Name; + else + markup(O, Markup::Immediate) << '#' << formatImm(priorityhint_op); +} + void AArch64InstPrinter::printFPImmOperand(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O) { diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h index 15ef2dd..307402d 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h @@ -52,6 +52,8 @@ public: protected: bool printSysAlias(const MCInst *MI, const MCSubtargetInfo &STI, raw_ostream &O); + bool printSyslAlias(const MCInst *MI, const MCSubtargetInfo &STI, + raw_ostream &O); bool printSyspAlias(const MCInst *MI, const MCSubtargetInfo &STI, raw_ostream &O); bool printRangePrefetchAlias(const MCInst *MI, const MCSubtargetInfo &STI, @@ -151,6 +153,9 @@ protected: void printBTIHintOp(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo
&STI, raw_ostream &O); + void printCMHPriorityHintOp(const MCInst *MI, unsigned OpNum, + const MCSubtargetInfo &STI, raw_ostream &O); + void printFPImmOperand(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O); diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index 33f35ad..99836ae 100644 --- a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -3920,6 +3920,78 @@ multiclass sme2_luti4_vector_vg4_index<string mnemonic> { def _S : sme2_luti4_vector_vg4_index<0b10, ZZZZ_s_mul_r, mnemonic>; } +// 8-bit Look up table +class sme2_lut_single<string asm> + : I<(outs ZPR8:$Zd), (ins ZTR:$ZTt, ZPRAny:$Zn), + asm, "\t$Zd, $ZTt, $Zn", "", []>, Sched<[]> { + bits<0> ZTt; + bits<5> Zd; + bits<5> Zn; + let Inst{31-10} = 0b1100000011001000010000; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +//===----------------------------------------------------------------------===// +// Lookup table read with 6-bit indices (8-bit) +class sme2_luti6_zt_base<RegisterOperand zd_ty, string asm> + : I<(outs zd_ty:$Zd), (ins ZTR:$ZTt, ZZZ_Any:$Zn), + asm, "\t$Zd, $ZTt, $Zn", "", []>, Sched<[]> { + bits<0> ZTt; + bits<3> Zd; + bits<3> Zn; + let Inst{31-21} = 0b11000000100; + let Inst{19-10} = 0b1010000000; + let Inst{9-7} = Zn; + let Inst{6-5} = 0b00; +} + +class sme2_luti6_zt_consecutive<string asm> + : sme2_luti6_zt_base<ZZZZ_b_mul_r, asm> { + let Inst{20} = 0; + let Inst{4-2} = Zd; + let Inst{1-0} = 0b00; +} + +class sme2_luti6_zt_strided<string asm> + : sme2_luti6_zt_base<ZZZZ_b_strided, asm> { + let Inst{20} = 1; + let Inst{4} = Zd{2}; + let Inst{3-2} = 0b00; + let Inst{1-0} = Zd{1-0}; +} + +//===----------------------------------------------------------------------===// +// Lookup table read with 6-bit indices (8-bit) +class sme2_luti6_vector_vg4_base<RegisterOperand zd_ty, string asm> + : I<(outs zd_ty:$Zd), (ins ZZ_h:$Zn, ZZ_Any:$Zm, VectorIndexD:$i1), + asm, "\t$Zd, $Zn, $Zm$i1", "", []>, Sched<[]> { + bits<3> Zd; + bits<5> Zn; + bits<5> Zm; + bits<1> i1; + let Inst{31-23} = 0b110000010; + let Inst{22} = i1; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{9-5} = Zn; +} + +class sme2_luti6_vector_vg4_consecutive<string asm> + : sme2_luti6_vector_vg4_base<ZZZZ_h_mul_r, asm> { + let Inst{15-10} = 0b111101; + let Inst{4-2} = Zd; + let Inst{1-0} = 0b00; +} + +class sme2_luti6_vector_vg4_strided<string asm> + : sme2_luti6_vector_vg4_base<ZZZZ_h_strided, asm> { + let Inst{15-10} = 0b111111; + let Inst{4} = Zd{2}; + let Inst{3-2} = 0b00; + let Inst{1-0} = Zd{1-0}; +} + //===----------------------------------------------------------------------===// // SME2 MOV class sme2_mova_vec_to_tile_vg2_multi_base<bits<2> sz, bit v, diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 3cdd505..1664f4a 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -3787,7 +3787,7 @@ multiclass sve2p1_two_way_dot_vv<string mnemonic, bit u, SDPatternOperator intri // SVE Integer Dot Product Group - Indexed Group //===----------------------------------------------------------------------===// -class sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, +class sve_intx_dot_by_indexed_elem<bit U, string asm, ZPRRegOp zprty1, ZPRRegOp zprty2, ZPRRegOp zprty3, Operand itype> : I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty3:$Zm, itype:$iop), @@ -3795,8 +3795,7 @@ class 
sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, "", []>, Sched<[]> { bits<5> Zda; bits<5> Zn; - let Inst{31-23} = 0b010001001; - let Inst{22} = sz; + let Inst{31-24} = 0b01000100; let Inst{21} = 0b1; let Inst{15-11} = 0; let Inst{10} = U; @@ -3810,16 +3809,18 @@ class sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, multiclass sve_intx_dot_by_indexed_elem<bit opc, string asm, SDPatternOperator op> { - def _BtoS : sve_intx_dot_by_indexed_elem<0b0, opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b_timm> { + def _BtoS : sve_intx_dot_by_indexed_elem<opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b_timm> { bits<2> iop; bits<3> Zm; + let Inst{23-22} = 0b10; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _HtoD : sve_intx_dot_by_indexed_elem<0b1, opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b_timm> { + def _HtoD : sve_intx_dot_by_indexed_elem<opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b_timm> { bits<1> iop; bits<4> Zm; - let Inst{20} = iop; + let Inst{23-22} = 0b11; + let Inst{20} = iop; let Inst{19-16} = Zm; } @@ -3827,6 +3828,16 @@ multiclass sve_intx_dot_by_indexed_elem<bit opc, string asm, def : SVE_4_Op_Imm_Pat<nxv2i64, op, nxv2i64, nxv8i16, nxv8i16, i32, VectorIndexD32b_timm, !cast<Instruction>(NAME # _HtoD)>; } +class sve_intx_dot_by_indexed_elem_x<bit opc, string asm> +: sve_intx_dot_by_indexed_elem<opc, asm, ZPR16, ZPR8, ZPR3b8, VectorIndexH32b_timm> { + bits<3> iop; + bits<3> Zm; + let Inst{23} = 0b0; + let Inst{22} = iop{2}; + let Inst{20-19} = iop{1-0}; + let Inst{18-16} = Zm; +} + //===----------------------------------------------------------------------===// // SVE2 Complex Integer Dot Product Group //===----------------------------------------------------------------------===// @@ -4085,7 +4096,7 @@ class sve2_int_arith_pred<bits<2> sz, bits<6> opc, string asm, bits<5> Zdn; let Inst{31-24} = 0b01000100; let Inst{23-22} = sz; - let Inst{21-20} = 0b01; + let Inst{21} = 0b0; let Inst{20-16} = opc{5-1}; let Inst{15-14} = 0b10; let Inst{13} = opc{0}; @@ -4590,15 +4601,15 @@ multiclass sve2_int_cadd<bit opc, string asm, SDPatternOperator op> { def : SVE_3_Op_Imm_Pat<nxv2i64, op, nxv2i64, nxv2i64, i32, complexrotateopodd, !cast<Instruction>(NAME # _D)>; } -class sve2_int_absdiff_accum<bits<2> sz, bits<4> opc, string asm, +class sve2_int_absdiff_accum<bits<3> sz, bits<4> opc, string asm, ZPRRegOp zprty1, ZPRRegOp zprty2> : I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty2:$Zm), asm, "\t$Zda, $Zn, $Zm", "", []>, Sched<[]> { bits<5> Zda; bits<5> Zn; bits<5> Zm; - let Inst{31-24} = 0b01000101; - let Inst{23-22} = sz; + let Inst{31-25} = 0b0100010; + let Inst{24-22} = sz; let Inst{21} = 0b0; let Inst{20-16} = Zm; let Inst{15-14} = 0b11; @@ -4613,10 +4624,10 @@ class sve2_int_absdiff_accum<bits<2> sz, bits<4> opc, string asm, } multiclass sve2_int_absdiff_accum<bit opc, string asm, SDPatternOperator op> { - def _B : sve2_int_absdiff_accum<0b00, { 0b111, opc }, asm, ZPR8, ZPR8>; - def _H : sve2_int_absdiff_accum<0b01, { 0b111, opc }, asm, ZPR16, ZPR16>; - def _S : sve2_int_absdiff_accum<0b10, { 0b111, opc }, asm, ZPR32, ZPR32>; - def _D : sve2_int_absdiff_accum<0b11, { 0b111, opc }, asm, ZPR64, ZPR64>; + def _B : sve2_int_absdiff_accum<0b100, { 0b111, opc }, asm, ZPR8, ZPR8>; + def _H : sve2_int_absdiff_accum<0b101, { 0b111, opc }, asm, ZPR16, ZPR16>; + def _S : sve2_int_absdiff_accum<0b110, { 0b111, opc }, asm, ZPR32, ZPR32>; + def _D : sve2_int_absdiff_accum<0b111, { 0b111, opc }, asm, ZPR64, ZPR64>; def : SVE_3_Op_Pat<nxv16i8, op, nxv16i8, nxv16i8, nxv16i8, 
!cast<Instruction>(NAME # _B)>; def : SVE_3_Op_Pat<nxv8i16, op, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>; @@ -4626,20 +4637,26 @@ multiclass sve2_int_absdiff_accum<bit opc, string asm, SDPatternOperator op> { multiclass sve2_int_absdiff_accum_long<bits<2> opc, string asm, SDPatternOperator op> { - def _H : sve2_int_absdiff_accum<0b01, { 0b00, opc }, asm, ZPR16, ZPR8>; - def _S : sve2_int_absdiff_accum<0b10, { 0b00, opc }, asm, ZPR32, ZPR16>; - def _D : sve2_int_absdiff_accum<0b11, { 0b00, opc }, asm, ZPR64, ZPR32>; + def _H : sve2_int_absdiff_accum<0b101, { 0b00, opc }, asm, ZPR16, ZPR8>; + def _S : sve2_int_absdiff_accum<0b110, { 0b00, opc }, asm, ZPR32, ZPR16>; + def _D : sve2_int_absdiff_accum<0b111, { 0b00, opc }, asm, ZPR64, ZPR32>; def : SVE_3_Op_Pat<nxv8i16, op, nxv8i16, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _H)>; def : SVE_3_Op_Pat<nxv4i32, op, nxv4i32, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _S)>; def : SVE_3_Op_Pat<nxv2i64, op, nxv2i64, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _D)>; } +multiclass sve2_int_two_way_absdiff_accum_long<bit U, string asm> { + def _BtoH : sve2_int_absdiff_accum<0b001, { 0b01, U, 0b1 }, asm, ZPR16, ZPR8>; + def _HtoS : sve2_int_absdiff_accum<0b010, { 0b01, U, 0b1 }, asm, ZPR32, ZPR16>; + def _StoD : sve2_int_absdiff_accum<0b011, { 0b01, U, 0b1 }, asm, ZPR64, ZPR32>; +} + multiclass sve2_int_addsub_long_carry<bits<2> opc, string asm, SDPatternOperator op> { - def _S : sve2_int_absdiff_accum<{ opc{1}, 0b0 }, { 0b010, opc{0} }, asm, + def _S : sve2_int_absdiff_accum<{ 0b1, opc{1}, 0b0 }, { 0b010, opc{0} }, asm, ZPR32, ZPR32>; - def _D : sve2_int_absdiff_accum<{ opc{1}, 0b1 }, { 0b010, opc{0} }, asm, + def _D : sve2_int_absdiff_accum<{ 0b1, opc{1}, 0b1 }, { 0b010, opc{0} }, asm, ZPR64, ZPR64>; def : SVE_3_Op_Pat<nxv4i32, op, nxv4i32, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>; @@ -9610,17 +9627,18 @@ multiclass sve_int_dot_mixed_indexed<bit U, string asm, SDPatternOperator op> { // SVE Floating Point Matrix Multiply Accumulate Group //===----------------------------------------------------------------------===// -class sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty> +class sve_fp_matrix_mla<bits<3> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty> : I<(outs zda_ty:$Zda), (ins zda_ty:$_Zda, reg_ty:$Zn, reg_ty:$Zm), asm, "\t$Zda, $Zn, $Zm", "", []>, Sched<[]> { bits<5> Zda; bits<5> Zn; bits<5> Zm; let Inst{31-24} = 0b01100100; - let Inst{23-22} = opc; + let Inst{23-22} = opc{2-1}; let Inst{21} = 1; let Inst{20-16} = Zm; - let Inst{15-10} = 0b111001; + let Inst{15-11} = 0b11100; + let Inst{10} = opc{0}; let Inst{9-5} = Zn; let Inst{4-0} = Zda; @@ -9630,10 +9648,12 @@ class sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_t let mayRaiseFPException = 1; } -multiclass sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty, SDPatternOperator op, ValueType zda_vt, ValueType reg_vt> { +multiclass sve_fp_matrix_mla<bits<3> opc, string asm, ZPRRegOp zda_ty, + ZPRRegOp reg_ty, SDPatternOperator op, + ValueType zda_vt, ValueType reg_vt> { def NAME : sve_fp_matrix_mla<opc, asm, zda_ty, reg_ty>; - def : SVE_3_Op_Pat<zda_vt, op , zda_vt, reg_vt, reg_vt, !cast<Instruction>(NAME)>; + def : SVE_3_Op_Pat<zda_vt, op, zda_vt, reg_vt, reg_vt, !cast<Instruction>(NAME)>; } //===----------------------------------------------------------------------===// @@ -10030,18 +10050,19 @@ multiclass sve2p1_multi_vec_extract_narrow<string mnemonic, bits<2> opc, SDPatte } // SVE2 
multi-vec shift narrow -class sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, bits<2> tsz> - : I<(outs ZPR16:$Zd), (ins ZZ_s_mul_r:$Zn, vecshiftR16:$imm4), - mnemonic, "\t$Zd, $Zn, $imm4", +class sve2p1_multi_vec_shift_narrow<string mnemonic, ZPRRegOp ZdRC, RegisterOperand ZSrcOp, + Operand immtype, bits<3> opc, bits<2> tsz> + : I<(outs ZdRC:$Zd), (ins ZSrcOp:$Zn, immtype:$imm), + mnemonic, "\t$Zd, $Zn, $imm", "", []>, Sched<[]> { bits<5> Zd; bits<4> Zn; - bits<4> imm4; + bits<4> imm; let Inst{31-23} = 0b010001011; let Inst{22} = tsz{1}; let Inst{21} = 0b1; let Inst{20} = tsz{0}; - let Inst{19-16} = imm4; + let Inst{18-16} = imm{2-0}; // imm3 let Inst{15-14} = 0b00; let Inst{13-11} = opc; let Inst{10} = 0b0; @@ -10052,12 +10073,19 @@ class sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, bits<2> tsz> let hasSideEffects = 0; } -multiclass sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, SDPatternOperator intrinsic> { - def NAME : sve2p1_multi_vec_shift_narrow<mnemonic, opc, 0b01>; +multiclass sve_multi_vec_shift_narrow<string mnemonic, bits<3> opc, SDPatternOperator intrinsic> { + def NAME : sve2p1_multi_vec_shift_narrow<mnemonic, ZPR16, ZZ_s_mul_r, vecshiftR16, opc, 0b01> { + let Inst{19} = imm{3}; // imm4 + } def : SVE2p1_Sat_Shift_VG2_Pat<NAME, intrinsic, nxv8i16, nxv4i32, vecshiftR16>; } +multiclass sve_multi_vec_round_shift_narrow<string mnemonic, bits<3> opc> { + def NAME : sve2p1_multi_vec_shift_narrow<mnemonic, ZPR8, ZZ_h_mul_r, vecshiftR8, opc, 0b00> { + let Inst{19} = 0b1; // always 1 for imm3 version + } +} // SME2 multi-vec contiguous load (scalar plus scalar, two registers) class sve2p1_mem_cld_ss_2z<string mnemonic, bits<2> msz, bit n, @@ -11164,7 +11192,7 @@ multiclass sve2_fp8_dot_indexed_s<string asm, SDPatternOperator op> { def : SVE_4_Op_Pat<nxv4f32, op, nxv4f32, nxv16i8, nxv16i8, i32, !cast<Instruction>(NAME)>; } -// FP8 Look up table +// Look up table class sve2_lut_vector_index<ZPRRegOp zd_ty, RegisterOperand zn_ty, Operand idx_ty, bits<4>opc, string mnemonic> : I<(outs zd_ty:$Zd), (ins zn_ty:$Zn, ZPRAny:$Zm, idx_ty:$idx), @@ -11183,7 +11211,7 @@ class sve2_lut_vector_index<ZPRRegOp zd_ty, RegisterOperand zn_ty, let Inst{4-0} = Zd; } -// FP8 Look up table read with 2-bit indices +// Look up table read with 2-bit indices multiclass sve2_luti2_vector_index<string mnemonic> { def _B : sve2_lut_vector_index<ZPR8, Z_b, VectorIndexS32b, {?, 0b100}, mnemonic> { bits<2> idx; @@ -11205,7 +11233,7 @@ multiclass sve2_luti2_vector_index<string mnemonic> { i32, timm32_0_7, !cast<Instruction>(NAME # _H)>; } -// FP8 Look up table read with 4-bit indices +// Look up table read with 4-bit indices multiclass sve2_luti4_vector_index<string mnemonic> { def _B : sve2_lut_vector_index<ZPR8, Z_b, VectorIndexD32b, 0b1001, mnemonic> { bit idx; @@ -11226,7 +11254,7 @@ multiclass sve2_luti4_vector_index<string mnemonic> { i32, timm32_0_3, !cast<Instruction>(NAME # _H)>; } -// FP8 Look up table read with 4-bit indices (two contiguous registers) +// Look up table read with 4-bit indices (two contiguous registers) multiclass sve2_luti4_vector_vg2_index<string mnemonic> { def NAME : sve2_lut_vector_index<ZPR16, ZZ_h, VectorIndexS32b, {?, 0b101}, mnemonic> { bits<2> idx; @@ -11250,6 +11278,29 @@ multiclass sve2_luti4_vector_vg2_index<string mnemonic> { nxv16i8:$Op3, timm32_0_3:$Op4))>; } +// Look up table read with 6-bit indices +multiclass sve2_luti6_vector_index<string mnemonic> { + def _H : sve2_lut_vector_index<ZPR16, ZZ_h, VectorIndexD32b, 0b1011, mnemonic> { + 
bit idx; + let Inst{23} = idx; + } +} + +// Look up table +class sve2_luti6_vector<string mnemonic> + : I<(outs ZPR8:$Zd), (ins ZZ_b:$Zn, ZPRAny:$Zm), + mnemonic, "\t$Zd, $Zn, $Zm", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-21} = 0b01000101001; + let Inst{20-16} = Zm; + let Inst{15-10} = 0b101011; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + //===----------------------------------------------------------------------===// // Checked Pointer Arithmetic (FEAT_CPA) //===----------------------------------------------------------------------===// @@ -11280,3 +11331,49 @@ class sve_int_mla_cpa<string asm> let ElementSize = ZPR64.ElementSize; } + +//===----------------------------------------------------------------------===// +// FP to Int down-converts +//===----------------------------------------------------------------------===// +class sve2_fp_to_int_downcvt<string asm, ZPRRegOp ZdRC, RegisterOperand ZSrcOp, bits<2> size, bit U> + : I<(outs ZdRC:$Zd), (ins ZSrcOp:$Zn), + asm, "\t$Zd, $Zn", "", []>, Sched<[]> { + bits<5> Zd; + bits<4> Zn; + let Inst{31-24} = 0b01100101; + let Inst{23-22} = size; + let Inst{21-11} = 0b00110100110; + let Inst{10} = U; + let Inst{9-6} = Zn; + let Inst{5} = 0b0; + let Inst{4-0} = Zd; +} + +multiclass sve2_fp_to_int_downcvt<string asm, bit U> { + def _HtoB : sve2_fp_to_int_downcvt<asm, ZPR8, ZZ_h_mul_r, 0b01, U>; + def _StoH : sve2_fp_to_int_downcvt<asm, ZPR16, ZZ_s_mul_r, 0b10, U>; + def _DtoS : sve2_fp_to_int_downcvt<asm, ZPR32, ZZ_d_mul_r, 0b11, U>; +} + +//===----------------------------------------------------------------------===// +// Int to FP up-converts +//===----------------------------------------------------------------------===// +class sve2_int_to_fp_upcvt<string asm, ZPRRegOp ZdRC, ZPRRegOp ZnRC, + bits<2> size, bits<2> U> + : I<(outs ZdRC:$Zd), (ins ZnRC:$Zn), + asm, "\t$Zd, $Zn", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + let Inst{31-24} = 0b01100101; + let Inst{23-22} = size; + let Inst{21-12} = 0b0011000011; + let Inst{11-10} = U; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_to_fp_upcvt<string asm, bits<2> U> { + def _BtoH : sve2_int_to_fp_upcvt<asm, ZPR16, ZPR8, 0b01, U>; + def _HtoS : sve2_int_to_fp_upcvt<asm, ZPR32, ZPR16, 0b10, U>; + def _StoD : sve2_int_to_fp_upcvt<asm, ZPR64, ZPR32, 0b11, U>; +} diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp index d6cb0e8..268a229 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp @@ -139,6 +139,13 @@ namespace llvm { } namespace llvm { +namespace AArch64CMHPriorityHint { +#define GET_CMHPRIORITYHINT_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64CMHPriorityHint +} // namespace llvm + +namespace llvm { namespace AArch64SysReg { #define GET_SysRegsList_IMPL #include "AArch64GenSystemOperands.inc" @@ -190,6 +197,32 @@ namespace AArch64TLBIP { #define GET_TLBIPTable_IMPL #include "AArch64GenSystemOperands.inc" } // namespace AArch64TLBIP + +namespace AArch64MLBI { +#define GET_MLBITable_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64MLBI +} // namespace llvm + +namespace llvm { +namespace AArch64GIC { +#define GET_GICTable_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GIC +} // namespace llvm + +namespace llvm { +namespace AArch64GICR { +#define GET_GICRTable_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GICR +} 
// namespace llvm + +namespace llvm { +namespace AArch64GSB { +#define GET_GSBTable_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GSB } // namespace llvm namespace llvm { diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h index fea33ef..27812e9 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -409,6 +409,16 @@ struct SysAliasReg : SysAlias { : SysAlias(N, E, F), NeedsReg(R) {} }; +struct SysAliasOptionalReg : SysAlias { + bool NeedsReg; + bool OptionalReg; + constexpr SysAliasOptionalReg(const char *N, uint16_t E, bool R, bool O) + : SysAlias(N, E), NeedsReg(R), OptionalReg(O) {} + constexpr SysAliasOptionalReg(const char *N, uint16_t E, bool R, bool O, + FeatureBitset F) + : SysAlias(N, E, F), NeedsReg(R), OptionalReg(O) {} +}; + struct SysAliasImm : SysAlias { uint16_t ImmValue; constexpr SysAliasImm(const char *N, uint16_t E, uint16_t I) @@ -677,6 +687,14 @@ namespace AArch64BTIHint { #include "AArch64GenSystemOperands.inc" } +namespace AArch64CMHPriorityHint { +struct CMHPriorityHint : SysAlias { + using SysAlias::SysAlias; +}; +#define GET_CMHPRIORITYHINT_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64CMHPriorityHint + namespace AArch64SME { enum ToggleCondition : unsigned { Always, @@ -788,21 +806,53 @@ namespace AArch64SysReg { } namespace AArch64TLBI { - struct TLBI : SysAliasReg { - using SysAliasReg::SysAliasReg; - }; - #define GET_TLBITable_DECL - #include "AArch64GenSystemOperands.inc" +struct TLBI : SysAliasOptionalReg { + using SysAliasOptionalReg::SysAliasOptionalReg; +}; +#define GET_TLBITable_DECL +#include "AArch64GenSystemOperands.inc" } namespace AArch64TLBIP { -struct TLBIP : SysAliasReg { - using SysAliasReg::SysAliasReg; +struct TLBIP : SysAliasOptionalReg { + using SysAliasOptionalReg::SysAliasOptionalReg; }; #define GET_TLBIPTable_DECL #include "AArch64GenSystemOperands.inc" } // namespace AArch64TLBIP +namespace AArch64MLBI { +struct MLBI : SysAliasReg { + using SysAliasReg::SysAliasReg; +}; +#define GET_MLBITable_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64MLBI + +namespace AArch64GIC { +struct GIC : SysAliasReg { + using SysAliasReg::SysAliasReg; +}; +#define GET_GICTable_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GIC + +namespace AArch64GICR { +struct GICR : SysAliasReg { + using SysAliasReg::SysAliasReg; +}; +#define GET_GICRTable_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GICR + +namespace AArch64GSB { +struct GSB : SysAlias { + using SysAlias::SysAlias; +}; +#define GET_GSBTable_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64GSB + namespace AArch64II { /// Target Operand Flag enum. 
enum TOF { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 8ed4062..1b559a6 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -514,8 +514,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, MVT::i64, Custom); setOperationAction(ISD::SELECT_CC, MVT::i64, Expand); - setOperationAction({ISD::ABS, ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX}, - MVT::i32, Legal); + setOperationAction({ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX}, MVT::i32, + Legal); setOperationAction( {ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp index 996b55f..02c5390 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ -2086,7 +2086,7 @@ void AMDGPUCodeGenPassBuilder::addIRPasses(AddIRPass &addPass) const { (AMDGPUAtomicOptimizerStrategy != ScanOptions::None)) addPass(AMDGPUAtomicOptimizerPass(TM, AMDGPUAtomicOptimizerStrategy)); - addPass(AtomicExpandPass(&TM)); + addPass(AtomicExpandPass(TM)); if (TM.getOptLevel() > CodeGenOptLevel::None) { addPass(AMDGPUPromoteAllocaPass(TM)); diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index 99ba043..5580e4c 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -1860,7 +1860,6 @@ private: bool validateTHAndScopeBits(const MCInst &Inst, const OperandVector &Operands, const unsigned CPol); bool validateTFE(const MCInst &Inst, const OperandVector &Operands); - bool validateSetVgprMSB(const MCInst &Inst, const OperandVector &Operands); bool validateLdsDirect(const MCInst &Inst, const OperandVector &Operands); bool validateWMMA(const MCInst &Inst, const OperandVector &Operands); unsigned getConstantBusLimit(unsigned Opcode) const; @@ -5506,22 +5505,6 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst, return true; } -bool AMDGPUAsmParser::validateSetVgprMSB(const MCInst &Inst, - const OperandVector &Operands) { - if (Inst.getOpcode() != AMDGPU::S_SET_VGPR_MSB_gfx12) - return true; - - int Simm16Pos = - AMDGPU::getNamedOperandIdx(Inst.getOpcode(), AMDGPU::OpName::simm16); - if ((unsigned)Inst.getOperand(Simm16Pos).getImm() > 255) { - SMLoc Loc = Operands[1]->getStartLoc(); - Error(Loc, "s_set_vgpr_msb accepts values in range [0..255]"); - return false; - } - - return true; -} - bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst, const OperandVector &Operands) { unsigned Opc = Inst.getOpcode(); @@ -5681,9 +5664,6 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst, SMLoc IDLoc, if (!validateTFE(Inst, Operands)) { return false; } - if (!validateSetVgprMSB(Inst, Operands)) { - return false; - } if (!validateWMMA(Inst, Operands)) { return false; } diff --git a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp index 09ef6ac..2aa54c9 100644 --- a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp @@ -45,9 +45,6 @@ R600TargetLowering::R600TargetLowering(const TargetMachine &TM, // Legalize loads and stores to the private address space. 
setOperationAction(ISD::LOAD, {MVT::i32, MVT::v2i32, MVT::v4i32}, Custom); - // 32-bit ABS is legal for AMDGPU except for R600 - setOperationAction(ISD::ABS, MVT::i32, Expand); - // EXTLOAD should be the same as ZEXTLOAD. It is legal for some address // spaces, so it is custom lowered to handle those where it isn't. for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index a757421..be42291 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -298,7 +298,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM, setOperationAction(ISD::BR_CC, {MVT::i1, MVT::i32, MVT::i64, MVT::f32, MVT::f64}, Expand); - setOperationAction({ISD::UADDO, ISD::USUBO}, MVT::i32, Legal); + setOperationAction({ISD::ABS, ISD::UADDO, ISD::USUBO}, MVT::i32, Legal); setOperationAction({ISD::UADDO_CARRY, ISD::USUBO_CARRY}, MVT::i32, Legal); diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index ee10190..05ba76a 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -976,10 +976,10 @@ def : GCNPat < } // End SubtargetPredicate = HasLshlAddU64Inst let SubtargetPredicate = HasAddMinMaxInsts in { -def : ThreeOp_i32_Pats<add, smax, V_ADD_MAX_I32_e64>; -def : ThreeOp_i32_Pats<add, umax, V_ADD_MAX_U32_e64>; -def : ThreeOp_i32_Pats<add, smin, V_ADD_MIN_I32_e64>; -def : ThreeOp_i32_Pats<add, umin, V_ADD_MIN_U32_e64>; +def : ThreeOp_i32_Pats<saddsat, smax, V_ADD_MAX_I32_e64>; +def : ThreeOp_i32_Pats<uaddsat, umax, V_ADD_MAX_U32_e64>; +def : ThreeOp_i32_Pats<saddsat, smin, V_ADD_MIN_I32_e64>; +def : ThreeOp_i32_Pats<uaddsat, umin, V_ADD_MIN_U32_e64>; } def : VOPBinOpClampPat<saddsat, V_ADD_I32_e64, i32>; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index c4692b7..4ae2c1e 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -464,10 +464,10 @@ class ThreeOp_OpSelClampPats <SDPatternOperator op1, SDPatternOperator op2, >; let SubtargetPredicate = HasPkAddMinMaxInsts in { -def : ThreeOp_OpSelClampPats<add, smax, V_PK_ADD_MAX_I16>; -def : ThreeOp_OpSelClampPats<add, umax, V_PK_ADD_MAX_U16>; -def : ThreeOp_OpSelClampPats<add, smin, V_PK_ADD_MIN_I16>; -def : ThreeOp_OpSelClampPats<add, umin, V_PK_ADD_MIN_U16>; +def : ThreeOp_OpSelClampPats<saddsat, smax, V_PK_ADD_MAX_I16>; +def : ThreeOp_OpSelClampPats<uaddsat, umax, V_PK_ADD_MAX_U16>; +def : ThreeOp_OpSelClampPats<saddsat, smin, V_PK_ADD_MIN_I16>; +def : ThreeOp_OpSelClampPats<uaddsat, umin, V_PK_ADD_MIN_U16>; } let SubtargetPredicate = HasPkMinMax3Insts in { diff --git a/llvm/lib/Target/ARM/ARMArchitectures.td b/llvm/lib/Target/ARM/ARMArchitectures.td index 301ed5b..bfcecfe 100644 --- a/llvm/lib/Target/ARM/ARMArchitectures.td +++ b/llvm/lib/Target/ARM/ARMArchitectures.td @@ -297,6 +297,18 @@ def ARMv96a : Architecture<"armv9.6-a", "ARMv96a", [HasV9_6aOps, FeatureCRC, FeatureRAS, FeatureDotProd]>; +def ARMv97a : Architecture<"armv9.7-a", "ARMv97a", [HasV9_7aOps, + FeatureAClass, + FeatureDB, + FeatureFPARMv8, + FeatureNEON, + FeatureDSP, + FeatureTrustZone, + FeatureMP, + FeatureVirtualization, + FeatureCRC, + FeatureRAS, + FeatureDotProd]>; def ARMv8r : Architecture<"armv8-r", "ARMv8r", [HasV8Ops, FeatureRClass, FeatureDB, diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp 
b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp index 3368a50..36b9908 100644 --- a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp +++ b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp @@ -1471,6 +1471,435 @@ void ARMAsmPrinter::EmitUnwindingInstruction(const MachineInstr *MI) { // instructions) auto-generated. #include "ARMGenMCPseudoLowering.inc" +// Helper function to check if a register is live (used as an implicit operand) +// in the given call instruction. +static bool isRegisterLiveInCall(const MachineInstr &Call, MCRegister Reg) { + for (const MachineOperand &MO : Call.implicit_operands()) { + if (MO.isReg() && MO.getReg() == Reg && MO.isUse()) { + return true; + } + } + return false; +} + +void ARMAsmPrinter::EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // Choose scratch register: r12 primary, r3 if target is r12. + unsigned ScratchReg = ARM::R12; + if (AddrReg == ARM::R12) { + ScratchReg = ARM::R3; + } + + // Calculate ESR for ARM mode (16-bit): 0x8000 | (scratch_reg << 5) | addr_reg + // Note: scratch_reg is always 0x1F since the EOR sequence clobbers it. + const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>( + MF->getSubtarget().getRegisterInfo()); + unsigned AddrIndex = TRI->getEncodingValue(AddrReg); + unsigned ESR = 0x8000 | (31 << 5) | (AddrIndex & 31); + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = + (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3); + + // If we need to spill r3, push it first. + if (NeedSpillR3) { + // push {r3} + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::STMDB_UPD) + .addReg(ARM::SP) + .addReg(ARM::SP) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(ARM::R3)); + } + + // Clear bit 0 of target address to handle Thumb function pointers. + // In 32-bit ARM, function pointers may have the low bit set to indicate + // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later). + // We need to clear it to avoid an alignment fault when loading. + // bic scratch, target, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::BICri) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(0)); + + // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDRi12) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(-(PrefixNops * 4 + 4)) + .addImm(ARMCC::AL) + .addReg(0)); + + // Each EOR instruction XORs one byte of the type, shifted to its position. + for (int i = 0; i < 4; i++) { + uint8_t byte = (Type >> (i * 8)) & 0xFF; + uint32_t imm = byte << (i * 8); + bool isLast = (i == 3); + + // Verify the immediate can be encoded as an ARM modified immediate; the + // operand carries the plain value, as in the Thumb2 variant below. + assert(ARM_AM::getSOImmVal(imm) != -1 && + "Cannot encode immediate as ARM modified immediate"); + + // eor[s] scratch, scratch, #imm (last one sets flags with CPSR) + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::EORri) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(imm) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(isLast ? ARM::CPSR : ARM::NoRegister)); + } + + // If we spilled r3, restore it immediately after the comparison. + // This must happen before the branch so r3 is valid on both paths.
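+ // (LDMIA does not alter the flags set by the final EORS, so the result of + // the type comparison survives the restore.)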
+ if (NeedSpillR3) { + // pop {r3} + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDMIA_UPD) + .addReg(ARM::SP) + .addReg(ARM::SP) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch is zero) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::Bcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // udf #ESR (trap with encoded diagnostic) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::UDF).addImm(ESR)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // Choose scratch register: r12 primary, r3 if target is r12. + unsigned ScratchReg = ARM::R12; + if (AddrReg == ARM::R12) { + ScratchReg = ARM::R3; + } + + // Calculate ESR for Thumb mode (8-bit): 0x80 | addr_reg + // Bit 7: KCFI trap indicator + // Bits 6-5: Reserved + // Bits 4-0: Address register encoding + const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>( + MF->getSubtarget().getRegisterInfo()); + unsigned AddrIndex = TRI->getEncodingValue(AddrReg); + unsigned ESR = 0x80 | (AddrIndex & 0x1F); + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = + (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3); + + // If we need to spill r3, push it first. + if (NeedSpillR3) { + // push {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // Clear bit 0 of target address to handle Thumb function pointers. + // In 32-bit ARM, function pointers may have the low bit set to indicate + // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later). + // We need to clear it to avoid an alignment fault when loading. + // bic scratch, target, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2BICri) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(0)); + + // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2LDRi8) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(-(PrefixNops * 4 + 4)) + .addImm(ARMCC::AL) + .addReg(0)); + + // Each EOR instruction XORs one byte of the type, shifted to its position. + for (int i = 0; i < 4; i++) { + uint8_t byte = (Type >> (i * 8)) & 0xFF; + uint32_t imm = byte << (i * 8); + bool isLast = (i == 3); + + // Verify the immediate can be encoded as Thumb2 modified immediate. + assert(ARM_AM::getT2SOImmVal(imm) != -1 && + "Cannot encode immediate as Thumb2 modified immediate"); + + // eor[s] scratch, scratch, #imm (last one sets flags with CPSR) + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::t2EORri) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(imm) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(isLast ? ARM::CPSR : ARM::NoRegister)); + } + + // If we spilled r3, restore it immediately after the comparison. + // This must happen before the branch so r3 is valid on both paths. 
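+ // (Likewise, POP does not alter the flags set by the final EORS.)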
+ if (NeedSpillR3) { + // pop {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch is zero) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::t2Bcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // udf #ESR (trap with encoded diagnostic) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tUDF).addImm(ESR)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // For Thumb1, use R2 unconditionally as scratch register (a low register + // required for tLDRi). R3 is used for building the type hash. + unsigned ScratchReg = ARM::R2; + unsigned TempReg = ARM::R3; + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = isRegisterLiveInCall(Call, ARM::R3); + + // Spill r3 if needed + if (NeedSpillR3) { + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // Check if r2 is live and needs to be spilled. + bool NeedSpillR2 = isRegisterLiveInCall(Call, ARM::R2); + + // Push R2 if it's live + if (NeedSpillR2) { + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2)); + } + + // Clear bit 0 from target address + // TempReg (R3) is used first as helper for BIC, then later for building type + // hash. + + // movs temp, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0)); + + // mov scratch, target + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVr) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(ARMCC::AL)); + + // bics scratch, temp (scratch = scratch & ~temp) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBIC) + .addReg(ScratchReg) + .addReg(ARM::CPSR) + .addReg(ScratchReg) + .addReg(TempReg) + .addImm(ARMCC::AL) + .addReg(0)); + + // Load type hash. Thumb1 doesn't support negative offsets, so subtract. 
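+ // For example, with no patchable-function-prefix NOPs the hash word sits at + // target-4 (offset = 4); each prefix NOP adds another 4 bytes to the offset.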
+ int offset = PrefixNops * 4 + 4; + + // subs scratch, #offset + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tSUBi8) + .addReg(ScratchReg) + .addReg(ARM::CPSR) + .addReg(ScratchReg) + .addImm(offset) + .addImm(ARMCC::AL) + .addReg(0)); + + // ldr scratch, [scratch, #0] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLDRi) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(0) + .addImm(ARMCC::AL) + .addReg(0)); + + // Load expected type inline (instead of EOR sequence) + // + // This creates the 32-bit value byte-by-byte in the temp register: + // movs temp, #byte3 (high byte) + // lsls temp, temp, #8 + // adds temp, #byte2 + // lsls temp, temp, #8 + // adds temp, #byte1 + // lsls temp, temp, #8 + // adds temp, #byte0 (low byte) + + uint8_t byte0 = (Type >> 0) & 0xFF; + uint8_t byte1 = (Type >> 8) & 0xFF; + uint8_t byte2 = (Type >> 16) & 0xFF; + uint8_t byte3 = (Type >> 24) & 0xFF; + + // movs temp, #byte3 (start with high byte) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addImm(byte3) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte2 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte2) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte1) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte0 (low byte) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte0) + .addImm(ARMCC::AL) + .addReg(0)); + + // cmp scratch, temp + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tCMPr) + .addReg(ScratchReg) + .addReg(TempReg) + .addImm(ARMCC::AL) + .addReg(0)); + + // Restore registers if spilled (pop in reverse order of push: R2, then R3) + if (NeedSpillR2) { + // pop {r2} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2)); + } + + // Restore r3 if spilled + if (NeedSpillR3) { + // pop {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch == temp) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::tBcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // bkpt #0 (trap; unlike the ARM/Thumb2 udf sequences, Thumb1 encodes no + // ESR diagnostic here) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBKPT).addImm(0)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::LowerKCFI_CHECK(const MachineInstr &MI) { + Register AddrReg = MI.getOperand(0).getReg(); + const int64_t Type = MI.getOperand(1).getImm(); + + // Get the call instruction that follows this KCFI_CHECK.
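+ // The KCFI pass inserts the check immediately before the call and bundles + // the two together, so the next instruction is expected to be that call.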
+ assert(std::next(MI.getIterator())->isCall() && + "KCFI_CHECK not followed by a call instruction"); + const MachineInstr &Call = *std::next(MI.getIterator()); + + // Adjust the offset for patchable-function-prefix. + int64_t PrefixNops = 0; + MI.getMF() + ->getFunction() + .getFnAttribute("patchable-function-prefix") + .getValueAsString() + .getAsInteger(10, PrefixNops); + + // Emit the appropriate instruction sequence based on the opcode variant. + switch (MI.getOpcode()) { + case ARM::KCFI_CHECK_ARM: + EmitKCFI_CHECK_ARM32(AddrReg, Type, Call, PrefixNops); + break; + case ARM::KCFI_CHECK_Thumb2: + EmitKCFI_CHECK_Thumb2(AddrReg, Type, Call, PrefixNops); + break; + case ARM::KCFI_CHECK_Thumb1: + EmitKCFI_CHECK_Thumb1(AddrReg, Type, Call, PrefixNops); + break; + default: + llvm_unreachable("Unexpected KCFI_CHECK opcode"); + } +} + void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) { ARM_MC::verifyInstructionPredicates(MI->getOpcode(), getSubtargetInfo().getFeatureBits()); @@ -1504,6 +1933,11 @@ void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) { switch (Opc) { case ARM::t2MOVi32imm: llvm_unreachable("Should be lowered by thumb2it pass"); case ARM::DBG_VALUE: llvm_unreachable("Should be handled by generic printing"); + case ARM::KCFI_CHECK_ARM: + case ARM::KCFI_CHECK_Thumb2: + case ARM::KCFI_CHECK_Thumb1: + LowerKCFI_CHECK(*MI); + return; case ARM::LEApcrel: case ARM::tLEApcrel: case ARM::t2LEApcrel: { diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.h b/llvm/lib/Target/ARM/ARMAsmPrinter.h index 2b067c7..9e92b5a 100644 --- a/llvm/lib/Target/ARM/ARMAsmPrinter.h +++ b/llvm/lib/Target/ARM/ARMAsmPrinter.h @@ -123,9 +123,20 @@ public: void LowerPATCHABLE_FUNCTION_EXIT(const MachineInstr &MI); void LowerPATCHABLE_TAIL_CALL(const MachineInstr &MI); + // KCFI check lowering + void LowerKCFI_CHECK(const MachineInstr &MI); + private: void EmitSled(const MachineInstr &MI, SledKind Kind); + // KCFI check emission helpers + void EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + void EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + void EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + // Helpers for emitStartOfAsmFile() and emitEndOfAsmFile() void emitAttributes(); diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp index 0d7b6d1..fffb6373 100644 --- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp +++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp @@ -2301,6 +2301,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB, for (unsigned i = 2, e = MBBI->getNumOperands(); i != e; ++i) NewMI->addOperand(MBBI->getOperand(i)); + NewMI->setCFIType(*MBB.getParent(), MI.getCFIType()); + // Update call info and delete the pseudo instruction TCRETURN. if (MI.isCandidateForAdditionalCallInfo()) MI.getMF()->moveAdditionalCallInfo(&MI, &*NewMI); diff --git a/llvm/lib/Target/ARM/ARMFeatures.td b/llvm/lib/Target/ARM/ARMFeatures.td index 9b1fa5d..e562b21 100644 --- a/llvm/lib/Target/ARM/ARMFeatures.td +++ b/llvm/lib/Target/ARM/ARMFeatures.td @@ -712,6 +712,11 @@ def HasV9_6aOps : SubtargetFeature<"v9.6a", "HasV9_6aOps", "true", "Support ARM v9.6a instructions", [HasV9_5aOps]>; +// Armv9.7-A is a v9-only architecture. 
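+// Like Armv9.5-A and Armv9.6-A, it has no Armv8.x counterpart.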
+def HasV9_7aOps : SubtargetFeature<"v9.7a", "HasV9_7aOps", "true", + "Support ARM v9.7a instructions", + [HasV9_6aOps]>; + def HasV8_1MMainlineOps : SubtargetFeature< "v8.1m.main", "HasV8_1MMainlineOps", "true", "Support ARM v8-1M Mainline instructions", diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index b1a668e..8122db2 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -2849,6 +2849,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (isTailCall) { MF.getFrameInfo().setHasTailCall(); SDValue Ret = DAG.getNode(ARMISD::TC_RETURN, dl, MVT::Other, Ops); + if (CLI.CFIType) + Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue()); DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge); DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo)); return Ret; @@ -2856,6 +2858,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Returns a chain and a flag for retval copy to use. Chain = DAG.getNode(CallOpc, dl, {MVT::Other, MVT::Glue}, Ops); + if (CLI.CFIType) + Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue()); DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge); InGlue = Chain.getValue(1); DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo)); @@ -12008,6 +12012,71 @@ static void genTPLoopBody(MachineBasicBlock *TpLoopBody, .add(predOps(ARMCC::AL)); } +bool ARMTargetLowering::supportKCFIBundles() const { + // KCFI is supported in all ARM/Thumb modes. + return true; +} + +MachineInstr * +ARMTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB, + MachineBasicBlock::instr_iterator &MBBI, + const TargetInstrInfo *TII) const { + assert(MBBI->isCall() && MBBI->getCFIType() && + "Invalid call instruction for a KCFI check"); + + MachineOperand *TargetOp = nullptr; + switch (MBBI->getOpcode()) { + // ARM mode opcodes + case ARM::BLX: + case ARM::BLX_pred: + case ARM::BLX_noip: + case ARM::BLX_pred_noip: + case ARM::BX_CALL: + TargetOp = &MBBI->getOperand(0); + break; + // ARM mode tail-call opcodes, target is operand 0 + case ARM::TCRETURNri: + case ARM::TCRETURNrinotr12: + case ARM::TAILJMPr: + case ARM::TAILJMPr4: + TargetOp = &MBBI->getOperand(0); + break; + // Thumb mode opcodes (Thumb1 and Thumb2) + // Note: Most Thumb call instructions have predicate operands before the + // target register. Format: tBLXr pred, predreg, target_register, ...
+ case ARM::tBLXr: // Thumb1/Thumb2: BLX register (requires V5T) + case ARM::tBLXr_noip: // Thumb1/Thumb2: BLX register, no IP clobber + case ARM::tBX_CALL: // Thumb1 only: BX call (push LR, BX) + TargetOp = &MBBI->getOperand(2); + break; + // Tail call instructions don't have predicates, target is operand 0 + case ARM::tTAILJMPr: // Thumb1/Thumb2: Tail call via register + TargetOp = &MBBI->getOperand(0); + break; + default: + llvm_unreachable("Unexpected CFI call opcode"); + } + + assert(TargetOp && TargetOp->isReg() && "Invalid target operand"); + TargetOp->setIsRenamable(false); + + // Select the appropriate KCFI_CHECK variant based on the instruction set. + unsigned KCFICheckOpcode; + if (Subtarget->isThumb()) { + if (Subtarget->isThumb2()) { + KCFICheckOpcode = ARM::KCFI_CHECK_Thumb2; + } else { + KCFICheckOpcode = ARM::KCFI_CHECK_Thumb1; + } + } else { + KCFICheckOpcode = ARM::KCFI_CHECK_ARM; + } + + return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(KCFICheckOpcode)) + .addReg(TargetOp->getReg()) + .addImm(MBBI->getCFIType()) + .getInstr(); +} + MachineBasicBlock * ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *BB) const { diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 70aa001..8c5e0cf 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -447,6 +447,12 @@ class VectorType; void AdjustInstrPostInstrSelection(MachineInstr &MI, SDNode *Node) const override; + bool supportKCFIBundles() const override; + + MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB, + MachineBasicBlock::instr_iterator &MBBI, + const TargetInstrInfo *TII) const override; + SDValue PerformCMOVCombine(SDNode *N, SelectionDAG &DAG) const; SDValue PerformBRCONDCombine(SDNode *N, SelectionDAG &DAG) const; SDValue PerformCMOVToBFICombine(SDNode *N, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td index 282ff53..53be167 100644 --- a/llvm/lib/Target/ARM/ARMInstrInfo.td +++ b/llvm/lib/Target/ARM/ARMInstrInfo.td @@ -6536,6 +6536,36 @@ def CMP_SWAP_64 : PseudoInst<(outs GPRPair:$Rd, GPRPair:$addr_temp_out), def : Pat<(atomic_fence (timm), 0), (MEMBARRIER)>; //===----------------------------------------------------------------------===// +// KCFI check pseudo-instruction. +//===----------------------------------------------------------------------===// +// KCFI_CHECK pseudo-instruction for Kernel Control-Flow Integrity. +// Expands to a sequence that verifies the function pointer's type hash. +// Different sizes for different architectures due to different expansions. + +def KCFI_CHECK_ARM + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsARM]> { + let Size = 40; // worst-case 10 instructions (push, bic, ldr, 4x eor, pop, + // beq, udf), 4 bytes each +} + +def KCFI_CHECK_Thumb2 + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsThumb2]> { + let Size = + 34; // worst-case 9 instructions (16-bit push, pop and udf plus 32-bit + // bic, ldr, 4x eor and beq.w) +} + +def KCFI_CHECK_Thumb1 + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsThumb1Only]> { + let Size = 50; // conservative upper bound; worst case is 19 16-bit + // instructions (pushes, bic helper, type building, cmp, pops) +} + +//===----------------------------------------------------------------------===// // Instructions used for emitting unwind opcodes on Windows.
//===----------------------------------------------------------------------===// let isPseudo = 1 in { diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp index 86740a9..590d4c7 100644 --- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp +++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp @@ -111,6 +111,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeARMTarget() { initializeMVELaneInterleavingPass(Registry); initializeARMFixCortexA57AES1742098Pass(Registry); initializeARMDAGToDAGISelLegacyPass(Registry); + initializeKCFIPass(Registry); } static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) { @@ -487,6 +488,9 @@ void ARMPassConfig::addPreSched2() { // proper scheduling. addPass(createARMExpandPseudoPass()); + // Emit KCFI checks for indirect calls. + addPass(createKCFIPass()); + if (getOptLevel() != CodeGenOptLevel::None) { // When optimising for size, always run the Thumb2SizeReduction pass before // IfConversion. Otherwise, check whether IT blocks are restricted @@ -517,9 +521,12 @@ void ARMPassConfig::addPreEmitPass() { addPass(createThumb2SizeReductionPass()); - // Constant island pass work on unbundled instructions. + // Unpack bundles for: + // - Thumb2: Constant island pass requires unbundled instructions + // - KCFI: KCFI_CHECK pseudo instructions need to be unbundled for AsmPrinter addPass(createUnpackMachineBundles([](const MachineFunction &MF) { - return MF.getSubtarget<ARMSubtarget>().isThumb2(); + return MF.getSubtarget<ARMSubtarget>().isThumb2() || + MF.getFunction().getParent()->getModuleFlag("kcfi"); })); // Don't optimize barriers or block placement at -O0. @@ -530,6 +537,7 @@ void ARMPassConfig::addPreEmitPass2() { + // Inserts fixup instructions before unsafe AES operations. Instructions may + // be inserted at the start of blocks and within blocks, so this pass has to + // come before those below. diff --git a/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp b/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp index 0796746..94b511a 100644 --- a/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp +++ b/llvm/lib/Target/ARM/MCTargetDesc/ARMELFStreamer.cpp @@ -895,6 +895,7 @@ void ARMTargetELFStreamer::emitArchDefaultAttributes() { case ARM::ArchKind::ARMV9_4A: case ARM::ArchKind::ARMV9_5A: case ARM::ArchKind::ARMV9_6A: + case ARM::ArchKind::ARMV9_7A: S.setAttributeItem(CPU_arch_profile, ApplicationProfile, false); S.setAttributeItem(ARM_ISA_use, Allowed, false); S.setAttributeItem(THUMB_ISA_use, AllowThumb32, false); diff --git a/llvm/lib/Target/AVR/AVRInstrInfo.td b/llvm/lib/Target/AVR/AVRInstrInfo.td index 02fb905..4a2f714 100644 --- a/llvm/lib/Target/AVR/AVRInstrInfo.td +++ b/llvm/lib/Target/AVR/AVRInstrInfo.td @@ -1504,14 +1504,26 @@ let Defs = [SREG], hasSideEffects = 0 in def FRMIDX : Pseudo<(outs DLDREGS:$dst), (ins DLDREGS:$src, i16imm:$src2), "frmidx\t$dst, $src, $src2", []>; +// The instructions STDSPQRr and STDWSPQRr are used to store to the stack +// frame. The most accurate implementation would be to load the SP into +// a temporary pointer register and then use STDPtrQRr. For efficiency, +// however, we assume that R29R28 contains the current call frame pointer. +// In the PEI pass we sometimes rewrite an ADJCALLSTACKDOWN pseudo, plus +// one or more STDSPQRr/STDWSPQRr pseudo instructions, to use Z first for +// a stack adjustment and then as a base pointer.
To avoid corruption, we thus +// specify special classes of registers, like GPR8 and DREGS, but with +// the Z register removed, as the source/input to these instructions. // This pseudo is either converted to a regular store or a push which clobbers // SP. -def STDSPQRr : StorePseudo<(outs), (ins memspi:$dst, GPR8:$src), +let Defs = [SP], Uses = [SP], hasSideEffects = 0 in +def STDSPQRr : StorePseudo<(outs), (ins memspi:$dst, GPR8NOZ:$src), "stdstk\t$dst, $src", [(store i8:$src, addr:$dst)]>; +// See the comment on STDSPQRr. // This pseudo is either converted to a regular store or a push which clobbers // SP. -def STDWSPQRr : StorePseudo<(outs), (ins memspi:$dt, DREGS:$src), +let Defs = [SP], Uses = [SP], hasSideEffects = 0 in +def STDWSPQRr : StorePseudo<(outs), (ins memspi:$dt, DREGSNOZ:$src), "stdwstk\t$dt, $src", [(store i16:$src, addr:$dt)]>; // SP read/write pseudos. diff --git a/llvm/lib/Target/AVR/AVRRegisterInfo.td b/llvm/lib/Target/AVR/AVRRegisterInfo.td index 182f92c..9b935b1 100644 --- a/llvm/lib/Target/AVR/AVRRegisterInfo.td +++ b/llvm/lib/Target/AVR/AVRRegisterInfo.td @@ -211,6 +211,31 @@ def PTRDISPREGS : RegisterClass<"AVR", [i16], 8, (add R31R30, R29R28), ptr>; // model this using a register class containing only the Z register. def ZREG : RegisterClass<"AVR", [i16], 8, (add R31R30)>; +// General registers excluding the Z register halves; these are the only +// registers that are always safe for STDSPQRr instructions. +def GPR8NOZ : RegisterClass<"AVR", [i8], 8, (// Return value and argument registers. add R24, R25, R18, R19, R20, R21, R22, R23, // Scratch registers. R26, R27, // Callee saved registers. R28, R29, R17, R16, R15, R14, R13, R12, R11, R10, R9, R8, R7, R6, R5, R4, R3, R2, R0, R1)>; + +// 16-bit pair register class excluding the Z register; these are the only +// registers that are always safe for STDWSPQRr instructions. +def DREGSNOZ : RegisterClass<"AVR", [i16], 8, (// Return value and arguments. add R25R24, R19R18, R21R20, R23R22, // Scratch registers. R27R26, // Callee saved registers. R29R28, R17R16, R15R14, R13R12, R11R10, R9R8, R7R6, R5R4, R3R2, R1R0, // Pseudo regs for unaligned 16-bit pairs R26R25, R24R23, R22R21, R20R19, R18R17, R16R15, R14R13, R12R11, R10R9)>; + // Register class used for the stack read pseudo instruction.
def GPRSP : RegisterClass<"AVR", [i16], 8, (add SP)>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 22cf3a7..598735f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4675,7 +4675,7 @@ class WMMA_INSTR<string _Intr, list<dag> _Args> // class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> - : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record, + : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record_name, [!con((ins ADDR:$src), !if(WithStride, (ins B32:$ldm), (ins)))]>, Requires<Frag.Predicates> { @@ -4714,7 +4714,7 @@ class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> // class WMMA_STORE_D<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> - : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record, + : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record_name, [!con((ins ADDR:$dst), Frag.Ins, !if(WithStride, (ins B32:$ldm), (ins)))]>, @@ -4778,7 +4778,7 @@ class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> { class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string ALayout, string BLayout, int Satfinite, string rnd, string b1op> - : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record, + : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. @@ -4837,7 +4837,7 @@ defset list<WMMA_INSTR> WMMAs = { class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string ALayout, string BLayout, int Satfinite, string b1op, string Kind> - : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record, + : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. @@ -4891,7 +4891,7 @@ class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string Metadata, string Kind, int Satfinite> : WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite, - FragA, FragB, FragC, FragD>.record, + FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins, (ins B32:$metadata, i32imm:$selector)]>, // Requires does not seem to have effect on Instruction w/o Patterns. @@ -4946,7 +4946,7 @@ defset list<WMMA_INSTR> MMA_SPs = { // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 // class LDMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> - : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record, [(ins ADDR:$src)]>, + : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record_name, [(ins ADDR:$src)]>, Requires<Frag.Predicates> { // Build PatFrag that only matches particular address space. 
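// (Note, illustrative: .record_name is the string name of the intrinsic
// record rather than the record itself; the Intrinsic is presumably recovered
// with !cast<Intrinsic>(record_name), as the Tcgen05 classes below do
// explicitly, so `Intr` used here is unchanged.)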
PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src), @@ -4981,7 +4981,7 @@ defset list<WMMA_INSTR> LDMATRIXs = { // stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 // class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> - : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record_name, [!con((ins ADDR:$dst), Frag.Ins)]>, Requires<Frag.Predicates> { // Build PatFrag that only matches particular address space. dag PFOperands = !con((ops node:$dst), @@ -5376,7 +5376,7 @@ class Tcgen05MMAInst<bit Sp, string KindStr, string ASpace, Requires<PTXPredicates> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record + NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record_name ); dag ScaleInpIns = !if(!eq(ScaleInputD, 1), (ins i64imm:$scale_input_d), (ins)); @@ -5618,7 +5618,7 @@ class Tcgen05MMABlockScaleInst<bit Sp, string ASpace, string KindStr, Requires<[hasTcgen05Instructions, PTXPredicate]> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record); + NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record_name); dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins)); dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin i32:$spmetadata), (Intrin)); @@ -5702,7 +5702,7 @@ class Tcgen05MMAWSInst<bit Sp, string ASpace, string KindStr, Requires<[hasTcgen05Instructions]> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record); + NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record_name); dag ZeroColMaskIns = !if(!eq(HasZeroColMask, 1), (ins B64:$zero_col_mask), (ins)); diff --git a/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp b/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp index 67b510d..f2b216b 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVPostLegalizerCombiner.cpp @@ -27,6 +27,7 @@ #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/Support/FormatVariadic.h" #define GET_GICOMBINER_DEPS #include "RISCVGenPostLegalizeGICombiner.inc" @@ -42,6 +43,56 @@ namespace { #include "RISCVGenPostLegalizeGICombiner.inc" #undef GET_GICOMBINER_TYPES +/// Match: G_STORE (G_FCONSTANT +0.0), addr +/// Return the source vreg in MatchInfo if matched. 
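+/// Illustrative gMIR sketch of the rewrite, assuming RV64 with F:
+///   %c:_(s32) = G_FCONSTANT float 0.000000e+00
+///   G_STORE %c(s32), %p(p0) :: (store (s32))
+/// becomes
+///   %z:_(s64) = G_CONSTANT i64 0
+///   G_STORE %z(s64), %p(p0) :: (store (s32))
+/// letting isel use the zero register (e.g. `sw zero, 0(a0)`) instead of
+/// materializing +0.0 in an FP register.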
+bool matchFoldFPZeroStore(MachineInstr &MI, MachineRegisterInfo &MRI, + const RISCVSubtarget &STI, Register &MatchInfo) { + if (MI.getOpcode() != TargetOpcode::G_STORE) + return false; + + Register SrcReg = MI.getOperand(0).getReg(); + if (!SrcReg.isVirtual()) + return false; + + MachineInstr *Def = MRI.getVRegDef(SrcReg); + if (!Def || Def->getOpcode() != TargetOpcode::G_FCONSTANT) + return false; + + auto *CFP = Def->getOperand(1).getFPImm(); + if (!CFP || !CFP->getValueAPF().isPosZero()) + return false; + + unsigned ValBits = MRI.getType(SrcReg).getSizeInBits(); + if ((ValBits == 16 && !STI.hasStdExtZfh()) || + (ValBits == 32 && !STI.hasStdExtF()) || + (ValBits == 64 && (!STI.hasStdExtD() || !STI.is64Bit()))) + return false; + + MatchInfo = SrcReg; + return true; +} + +/// Apply: rewrite to G_STORE (G_CONSTANT 0 [XLEN]), addr +void applyFoldFPZeroStore(MachineInstr &MI, MachineRegisterInfo &MRI, + MachineIRBuilder &B, const RISCVSubtarget &STI, + Register &MatchInfo) { + const unsigned XLen = STI.getXLen(); + + auto Zero = B.buildConstant(LLT::scalar(XLen), 0); + MI.getOperand(0).setReg(Zero.getReg(0)); + + MachineInstr *Def = MRI.getVRegDef(MatchInfo); + if (Def && MRI.use_nodbg_empty(MatchInfo)) + Def->eraseFromParent(); + +#ifndef NDEBUG + unsigned ValBits = MRI.getType(MatchInfo).getSizeInBits(); + LLVM_DEBUG(dbgs() << formatv("[{0}] Fold FP zero store -> int zero " + "(XLEN={1}, ValBits={2}):\n {3}\n", + DEBUG_TYPE, XLen, ValBits, MI)); +#endif +} + class RISCVPostLegalizerCombinerImpl : public Combiner { protected: const CombinerHelper Helper; diff --git a/llvm/lib/Target/RISCV/RISCVCombine.td b/llvm/lib/Target/RISCV/RISCVCombine.td index 995dd0c..a06b60d 100644 --- a/llvm/lib/Target/RISCV/RISCVCombine.td +++ b/llvm/lib/Target/RISCV/RISCVCombine.td @@ -19,11 +19,20 @@ def RISCVO0PreLegalizerCombiner: GICombiner< "RISCVO0PreLegalizerCombinerImpl", [optnone_combines]> { } +// Rule: fold store (fp +0.0) -> store (int zero [XLEN]) +def fp_zero_store_matchdata : GIDefMatchData<"Register">; +def fold_fp_zero_store : GICombineRule< + (defs root:$root, fp_zero_store_matchdata:$matchinfo), + (match (G_STORE $src, $addr):$root, + [{ return matchFoldFPZeroStore(*${root}, MRI, STI, ${matchinfo}); }]), + (apply [{ applyFoldFPZeroStore(*${root}, MRI, B, STI, ${matchinfo}); }])>; + // Post-legalization combines which are primarily optimizations. // TODO: Add more combines. def RISCVPostLegalizerCombiner : GICombiner<"RISCVPostLegalizerCombinerImpl", [sub_to_add, combines_for_extload, redundant_and, identity_combines, shift_immed_chain, - commute_constant_to_rhs, simplify_neg_minmax]> { + commute_constant_to_rhs, simplify_neg_minmax, + fold_fp_zero_store]> { } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 26fe9ed..219e3f2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -14797,7 +14797,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits. SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0)); - SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src); + SDValue Abs = DAG.getNode(RISCVISD::NEGW_MAX, DL, MVT::i64, Src); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs)); return; } @@ -21813,7 +21813,7 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode( // Output is either all zero or operand 0. We can propagate sign bit count // from operand 0. 
return DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1); - case RISCVISD::ABSW: { + case RISCVISD::NEGW_MAX: { // We expand this at isel to negw+max. The result will have 33 sign bits // if the input has at least 33 sign bits. unsigned Tmp = diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td index 4104abd..4c2f7f6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td @@ -482,7 +482,7 @@ let Predicates = [HasVendorXSfvfwmaccqqq] in { defm SF_VFWMACC_4x4x4 : VPseudoSiFiveVFWMACC; } -let Predicates = [HasVendorXSfvfnrclipxfqf] in { +let Predicates = [HasVendorXSfvfnrclipxfqf], AltFmtType = IS_NOT_ALTFMT in { defm SF_VFNRCLIP_XU_F_QF : VPseudoSiFiveVFNRCLIP; defm SF_VFNRCLIP_X_F_QF : VPseudoSiFiveVFNRCLIP; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td index 62b7bcd..6b9a75f 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -51,7 +51,7 @@ def riscv_zip : RVSDNode<"ZIP", SDTIntUnaryOp>; def riscv_unzip : RVSDNode<"UNZIP", SDTIntUnaryOp>; // RV64IZbb absolute value for i32. Expanded to (max (negw X), X) during isel. -def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>; +def riscv_negw_max : RVSDNode<"NEGW_MAX", SDTIntUnaryOp>; // Scalar cryptography def riscv_clmul : RVSDNode<"CLMUL", SDTIntBinOp>; @@ -610,7 +610,7 @@ def : PatGpr<riscv_clzw, CLZW>; def : PatGpr<riscv_ctzw, CTZW>; def : Pat<(i64 (ctpop (i64 (zexti32 (i64 GPR:$rs1))))), (CPOPW GPR:$rs1)>; -def : Pat<(i64 (riscv_absw GPR:$rs1)), +def : Pat<(i64 (riscv_negw_max GPR:$rs1)), (MAX GPR:$rs1, (XLenVT (SUBW (XLenVT X0), GPR:$rs1)))>; } // Predicates = [HasStdExtZbb, IsRV64] diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index d91923b..56a38bb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1499,18 +1499,25 @@ static bool generateKernelClockInst(const SPIRV::IncomingCall *Call, Register ResultReg = Call->ReturnRegister; - // Deduce the `Scope` operand from the builtin function name. - SPIRV::Scope::Scope ScopeArg = - StringSwitch<SPIRV::Scope::Scope>(Builtin->Name) - .EndsWith("device", SPIRV::Scope::Scope::Device) - .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup) - .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup); - Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR); - - MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) - .addDef(ResultReg) - .addUse(GR->getSPIRVTypeID(Call->ReturnType)) - .addUse(ScopeReg); + if (Builtin->Name == "__spirv_ReadClockKHR") { + MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) + .addDef(ResultReg) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(Call->Arguments[0]); + } else { + // Deduce the `Scope` operand from the builtin function name. 
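+ // For example: clock_read_device and clock_read_hilo_device both end in
+ // "device" and map to Scope::Device; likewise "work_group" maps to
+ // Workgroup and "sub_group" to Subgroup.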
+ SPIRV::Scope::Scope ScopeArg = + StringSwitch<SPIRV::Scope::Scope>(Builtin->Name) + .EndsWith("device", SPIRV::Scope::Scope::Device) + .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup) + .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup); + Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR); + + MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) + .addDef(ResultReg) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(ScopeReg); + } return true; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index 3b8764a..c259cce 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1174,6 +1174,7 @@ defm : DemangledNativeBuiltin<"clock_read_sub_group", OpenCL_std, KernelClock, 0 defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; +defm : DemangledNativeBuiltin<"__spirv_ReadClockKHR", OpenCL_std, KernelClock, 1, 1, OpReadClockKHR>; //===----------------------------------------------------------------------===// // Class defining an atomic instruction on floating-point numbers. diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index f0ac26b..14097d7 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1336,22 +1336,25 @@ def pmax : PatFrags<(ops node:$lhs, node:$rhs), [ ]>; defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>; +multiclass PMinMaxInt<Vec vec, NI baseMinInst, NI baseMaxInst> { + def : Pat<(vec.int_vt (vselect + (setolt (vec.vt (bitconvert V128:$rhs)), + (vec.vt (bitconvert V128:$lhs))), + V128:$rhs, V128:$lhs)), + (baseMinInst $lhs, $rhs)>; + def : Pat<(vec.int_vt (vselect + (setolt (vec.vt (bitconvert V128:$lhs)), + (vec.vt (bitconvert V128:$rhs))), + V128:$rhs, V128:$lhs)), + (baseMaxInst $lhs, $rhs)>; +} // Also match the pmin/pmax cases where the operands are int vectors (but the // comparison is still a floating point comparison). This can happen when using // the wasm_simd128.h intrinsics because v128_t is an integer vector. 
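// For instance (illustrative, using wasm_simd128.h):
//   v128_t a, b;                      // v128_t is an integer vector type
//   v128_t r = wasm_f32x4_pmin(a, b); // f32 compare on bitcast operands
// produces (vselect (setolt (f32x4 (bitconvert b)), (f32x4 (bitconvert a))),
// b, a) with i32x4-typed operands, which the first PMinMaxInt pattern above
// matches.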
foreach vec = [F32x4, F64x2, F16x8] in { -defvar pmin = !cast<NI>("PMIN_"#vec); -defvar pmax = !cast<NI>("PMAX_"#vec); -def : Pat<(vec.int_vt (vselect - (setolt (vec.vt (bitconvert V128:$rhs)), - (vec.vt (bitconvert V128:$lhs))), - V128:$rhs, V128:$lhs)), - (pmin $lhs, $rhs)>; -def : Pat<(vec.int_vt (vselect - (setolt (vec.vt (bitconvert V128:$lhs)), - (vec.vt (bitconvert V128:$rhs))), - V128:$rhs, V128:$lhs)), - (pmax $lhs, $rhs)>; + defvar pmin = !cast<NI>("PMIN_"#vec); + defvar pmax = !cast<NI>("PMAX_"#vec); + defm : PMinMaxInt<vec, pmin, pmax>; } // And match the pmin/pmax LLVM intrinsics as well @@ -1756,6 +1759,15 @@ let Predicates = [HasRelaxedSIMD] in { (relaxed_max V128:$lhs, V128:$rhs)>; def : Pat<(vec.vt (fmaximumnum (vec.vt V128:$lhs), (vec.vt V128:$rhs))), (relaxed_max V128:$lhs, V128:$rhs)>; + + // Rewrite patterns that would otherwise select pmin/pmax to relaxed min/max + let AddedComplexity = 1 in { + def : Pat<(vec.vt (pmin (vec.vt V128:$lhs), (vec.vt V128:$rhs))), + (relaxed_min $lhs, $rhs)>; + def : Pat<(vec.vt (pmax (vec.vt V128:$lhs), (vec.vt V128:$rhs))), + (relaxed_max $lhs, $rhs)>; + defm : PMinMaxInt<vec, relaxed_min, relaxed_max>; + } } } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index d49f25a..4dfc400 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -2632,6 +2632,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(Op, MVT::f32, Promote); } + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f16, MVT::i16); + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f32, MVT::i32); + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f64, MVT::i64); + // We have target-specific dag combine patterns for the following nodes: setTargetDAGCombine({ISD::VECTOR_SHUFFLE, ISD::SCALAR_TO_VECTOR, diff --git a/llvm/lib/TargetParser/ARMTargetParser.cpp b/llvm/lib/TargetParser/ARMTargetParser.cpp index 0fce5b9..709e5f0 100644 --- a/llvm/lib/TargetParser/ARMTargetParser.cpp +++ b/llvm/lib/TargetParser/ARMTargetParser.cpp @@ -88,6 +88,7 @@ unsigned ARM::parseArchVersion(StringRef Arch) { case ArchKind::ARMV9_4A: case ArchKind::ARMV9_5A: case ArchKind::ARMV9_6A: + case ArchKind::ARMV9_7A: return 9; case ArchKind::INVALID: return 0; @@ -127,6 +128,7 @@ static ARM::ProfileKind getProfileKind(ARM::ArchKind AK) { case ARM::ArchKind::ARMV9_4A: case ARM::ArchKind::ARMV9_5A: case ARM::ArchKind::ARMV9_6A: + case ARM::ArchKind::ARMV9_7A: return ARM::ProfileKind::A; case ARM::ArchKind::ARMV4: case ARM::ArchKind::ARMV4T: diff --git a/llvm/lib/TargetParser/ARMTargetParserCommon.cpp b/llvm/lib/TargetParser/ARMTargetParserCommon.cpp index f6cea85..15ba1eb 100644 --- a/llvm/lib/TargetParser/ARMTargetParserCommon.cpp +++ b/llvm/lib/TargetParser/ARMTargetParserCommon.cpp @@ -46,6 +46,7 @@ StringRef ARM::getArchSynonym(StringRef Arch) { .Case("v9.4a", "v9.4-a") .Case("v9.5a", "v9.5-a") .Case("v9.6a", "v9.6-a") + .Case("v9.7a", "v9.7-a") .Case("v8m.base", "v8-m.base") .Case("v8m.main", "v8-m.main") .Case("v8.1m.main", "v8.1-m.main") diff --git a/llvm/lib/TargetParser/Triple.cpp b/llvm/lib/TargetParser/Triple.cpp index 1068ce4..11ba9ee 100644 --- a/llvm/lib/TargetParser/Triple.cpp +++ b/llvm/lib/TargetParser/Triple.cpp @@ -937,6 +937,8 @@ static Triple::SubArchType parseSubArch(StringRef SubArchName) { return Triple::ARMSubArch_v9_5a; case ARM::ArchKind::ARMV9_6A: return Triple::ARMSubArch_v9_6a; + case ARM::ArchKind::ARMV9_7A: + return Triple::ARMSubArch_v9_7a; case
ARM::ArchKind::ARMV8R: return Triple::ARMSubArch_v8r; case ARM::ArchKind::ARMV8MBaseline: diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index a0f7ec6..2dd0fde 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -948,17 +948,17 @@ void llvm::updateVCallVisibilityInIndex( // linker, as we have no information on their eventual use. if (DynamicExportSymbols.count(P.first)) continue; + // With validation enabled, we want to exclude symbols visible to regular + // objects. Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. + if (VisibleToRegularObjSymbols.count(P.first)) + continue; for (auto &S : P.second.getSummaryList()) { auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; - // With validation enabled, we want to exclude symbols visible to regular - // objects. Local symbols will be in this group due to the current - // implementation but those with VCallVisibilityTranslationUnit will have - // already been marked in clang so are unaffected. - if (VisibleToRegularObjSymbols.count(P.first)) - continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } @@ -1161,14 +1161,10 @@ bool DevirtIndex::tryFindVirtualCallTargets( // and therefore the same GUID. This can happen if there isn't enough // distinguishing path when compiling the source file. In that case we // conservatively return false early. + if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1) + return false; const GlobalVarSummary *VS = nullptr; - bool LocalFound = false; for (const auto &S : P.VTableVI.getSummaryList()) { - if (GlobalValue::isLocalLinkage(S->linkage())) { - if (LocalFound) - return false; - LocalFound = true; - } auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject()); if (!CurVS->vTableFuncs().empty() || // Previously clang did not attach the necessary type metadata to @@ -1184,6 +1180,7 @@ bool DevirtIndex::tryFindVirtualCallTargets( // with public LTO visibility. if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) return false; + break; } } // There will be no VS if all copies are available_externally having no @@ -1411,9 +1408,8 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, // If the summary list contains multiple summaries where at least one is // a local, give up, as we won't know which (possibly promoted) name to use. - for (const auto &S : TheFn.getSummaryList()) - if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) - return false; + if (TheFn.hasLocal() && Size > 1) + return false; // Collect functions devirtualized at least for one call site for stats. if (PrintSummaryDevirt || AreStatisticsEnabled()) @@ -2591,6 +2587,11 @@ void DevirtIndex::run() { if (ExportSummary.typeIdCompatibleVtableMap().empty()) return; + // Assert that we haven't made any changes that would affect the hasLocal() + // flag on the GUID summary info. 
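+ // (Illustrative: hasLocal() is derived from the linkages recorded in the
+ // summary list, and internalization can, for example, turn an externally
+ // visible vtable symbol into a local one, which would invalidate the
+ // hasLocal()-based checks above.)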
+ assert(!ExportSummary.withInternalizeAndPromote() && + "Expect index-based WPD to run before internalization and promotion"); + DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) { NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(P.first)].push_back( diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 73ec451..9bee523 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2760,21 +2760,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // Optimize pointer differences into the same array into a size. Consider: // &A[10] - &A[0]: we should compile this to "10". Value *LHSOp, *RHSOp; - if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && - match(Op1, m_PtrToInt(m_Value(RHSOp)))) + if (match(Op0, m_PtrToIntOrAddr(m_Value(LHSOp))) && + match(Op1, m_PtrToIntOrAddr(m_Value(RHSOp)))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I.hasNoUnsignedWrap())) return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) - if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && - match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) + if (match(Op0, m_Trunc(m_PtrToIntOrAddr(m_Value(LHSOp)))) && + match(Op1, m_Trunc(m_PtrToIntOrAddr(m_Value(RHSOp))))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), /* IsNUW */ false)) return replaceInstUsesWith(I, Res); - if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && - match(Op1, m_ZExtOrSelf(m_PtrToInt(m_Value(RHSOp))))) { + auto MatchSubOfZExtOfPtrToIntOrAddr = [&]() { + if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && + match(Op1, m_ZExt(m_PtrToIntSameSize(DL, m_Value(RHSOp))))) + return true; + if (match(Op0, m_ZExt(m_PtrToAddr(m_Value(LHSOp)))) && + match(Op1, m_ZExt(m_PtrToAddr(m_Value(RHSOp))))) + return true; + // Special case for non-canonical ptrtoint in constant expression, + // where the zext has been folded into the ptrtoint. + if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && + match(Op1, m_PtrToInt(m_Value(RHSOp)))) + return true; + return false; + }; + if (MatchSubOfZExtOfPtrToIntOrAddr()) { if (auto *GEP = dyn_cast<GEPOperator>(LHSOp)) { if (GEP->getPointerOperand() == RHSOp) { if (GEP->hasNoUnsignedWrap() || GEP->hasNoUnsignedSignedWrap()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index dab200d..669d4f0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -4003,18 +4003,29 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Try to fold intrinsic into select/phi operands. This is legal if: // * The intrinsic is speculatable. - // * The select condition is not a vector, or the intrinsic does not - // perform cross-lane operations. - if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI) && - isNotCrossLaneOperation(II)) + // * The operand is one of the following: + // - a phi. + // - a select with a scalar condition. + // - a select with a vector condition and II is not a cross lane operation. 
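+ // Illustrative example of the scalar-condition case, with the speculatable
+ // llvm.umax intrinsic:
+ //   umax(select(%c, %a, %b), %x) -> select(%c, umax(%a, %x), umax(%b, %x))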
+ if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI)) { for (Value *Op : II->args()) { - if (auto *Sel = dyn_cast<SelectInst>(Op)) - if (Instruction *R = FoldOpIntoSelect(*II, Sel)) + if (auto *Sel = dyn_cast<SelectInst>(Op)) { + bool IsVectorCond = Sel->getCondition()->getType()->isVectorTy(); + if (IsVectorCond && !isNotCrossLaneOperation(II)) + continue; + // Don't replace a scalar select with a more expensive vector select if + // we can't simplify both arms of the select. + bool SimplifyBothArms = + !Op->getType()->isVectorTy() && II->getType()->isVectorTy(); + if (Instruction *R = FoldOpIntoSelect( + *II, Sel, /*FoldWithMultiUse=*/false, SimplifyBothArms)) return R; + } if (auto *Phi = dyn_cast<PHINode>(Op)) if (Instruction *R = foldOpIntoPhi(*II, Phi)) return R; } + } if (Instruction *Shuf = foldShuffledIntrinsicOperands(II)) return Shuf; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 943c223..ede73f8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -664,7 +664,8 @@ public: /// This also works for Cast instructions, which obviously do not have a /// second operand. Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, - bool FoldWithMultiUse = false); + bool FoldWithMultiUse = false, + bool SimplifyBothArms = false); /// This is a convenience wrapper function for the above two functions. Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 3f11cae..67e2aae 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1777,7 +1777,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, } Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, - bool FoldWithMultiUse) { + bool FoldWithMultiUse, + bool SimplifyBothArms) { // Don't modify shared select instructions unless FoldWithMultiUse is set if (!SI->hasOneUse() && !FoldWithMultiUse) return nullptr; @@ -1821,6 +1822,9 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, if (!NewTV && !NewFV) return nullptr; + if (SimplifyBothArms && !(NewTV && NewFV)) + return nullptr; + // Create an instruction for the arm that did not fold.
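// (Illustrative: for `udiv (select %c, 16, %x), 8` only the constant arm
// simplifies, to 2; an explicit `udiv %x, 8` is built here for the other
// arm. With SimplifyBothArms set, this situation already returned nullptr
// above.)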
if (!NewTV) NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this); diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 4acc3f2..d347ced 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -614,6 +614,16 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } + if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && + canUseSExt(CI)) { + Preconditions.emplace_back( + CmpInst::ICMP_UGE, Op0, + ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); + if (auto Decomp = MergeResults(Op0, CI, true)) + return *Decomp; + return {V, IsKnownNonNegative}; + } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { if (!isKnownNonNegative(Op0, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, @@ -627,16 +637,6 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } - if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && - canUseSExt(CI)) { - Preconditions.emplace_back( - CmpInst::ICMP_UGE, Op0, - ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); - if (auto Decomp = MergeResults(Op0, CI, true)) - return *Decomp; - return {V, IsKnownNonNegative}; - } - // Decompose or as an add if there are no common bits between the operands. if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) { if (auto Decomp = MergeResults(Op0, CI, IsSigned)) diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index a83cbd17a7..f273e9d 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -64,10 +64,10 @@ using namespace llvm; -namespace { - #define DEBUG_TYPE "mergeicmps" +namespace { + // A BCE atom "Binary Compare Expression Atom" represents an integer load // that is a constant offset from a base value, e.g. `a` or `o.c` in the example // at the top. @@ -128,11 +128,12 @@ private: unsigned Order = 1; DenseMap<const Value*, int> BaseToIndex; }; +} // namespace // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset. -BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { +static BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { auto *const LoadI = dyn_cast<LoadInst>(Val); if (!LoadI) return {}; @@ -175,6 +176,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset); } +namespace { // A comparison between two BCE atoms, e.g. `a == o.a` in the example at the // top. // Note: the terminology is misleading: the comparison is symmetric, so there @@ -239,6 +241,7 @@ class BCECmpBlock { private: BCECmp Cmp; }; +} // namespace bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, AliasAnalysis &AA) const { @@ -302,9 +305,9 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. 
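// (Illustrative: given `%va = load i32, ptr %a`, `%vb = load i32, ptr %b`
// and a single-use `icmp eq i32 %va, %vb`, both operands decompose into
// BCE atoms and the comparison becomes a 32-bit BCECmp.)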
-std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate, - BaseIdentifier &BaseId) { +static std::optional<BCECmp> +visitICmp(const ICmpInst *const CmpI, + const ICmpInst::Predicate ExpectedPredicate, BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -332,10 +335,9 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional<BCECmpBlock> visitCmpBlock(Value *const Val, - BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { +static std::optional<BCECmpBlock> +visitCmpBlock(Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { if (Block->empty()) return std::nullopt; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); @@ -397,6 +399,7 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, Comparisons.push_back(std::move(Comparison)); } +namespace { // A chain of comparisons. class BCECmpChain { public: @@ -420,6 +423,7 @@ private: // The original entry block (before sorting); BasicBlock *EntryBlock_; }; +} // namespace static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { return First.Lhs().BaseId == Second.Lhs().BaseId && @@ -742,9 +746,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, return true; } -std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, - BasicBlock *const LastBlock, - int NumBlocks) { +static std::vector<BasicBlock *> +getOrderedBlocks(PHINode &Phi, BasicBlock *const LastBlock, int NumBlocks) { // Walk up from the last block to find other blocks. std::vector<BasicBlock *> Blocks(NumBlocks); assert(LastBlock && "invalid last block"); @@ -777,8 +780,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, - DomTreeUpdater &DTU) { +static bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, + AliasAnalysis &AA, DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -874,6 +877,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, return MadeChange; } +namespace { class MergeICmpsLegacyPass : public FunctionPass { public: static char ID; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 06bea2f..a1ad2db 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2712,7 +2712,8 @@ public: static inline bool classof(const VPRecipeBase *R) { return R->getVPDefID() == VPRecipeBase::VPReductionSC || - R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; + R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || + R->getVPDefID() == VPRecipeBase::VPPartialReductionSC; } static inline bool classof(const VPUser *U) { @@ -2783,7 +2784,10 @@ public: Opcode(Opcode), VFScaleFactor(ScaleFactor) { [[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe(); - assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) || + // When cloning as part of a VPExpressionRecipe the chain op could have + // been replaced by a temporary VPValue, so it doesn't have a defining recipe.
+ assert((!AccumulatorRecipe || + isa<VPReductionPHIRecipe>(AccumulatorRecipe) || isa<VPPartialReductionRecipe>(AccumulatorRecipe)) && "Unexpected operand order for partial reduction recipe"); } @@ -3093,6 +3097,11 @@ public: /// removed before codegen. void decompose(); + unsigned getVFScaleFactor() const { + auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back()); + return PR ? PR->getVFScaleFactor() : 1; + } + /// Method for generating code, must not be called as this recipe is abstract. void execute(VPTransformState &State) override { llvm_unreachable("recipe must be removed before execute"); diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 1f1b42b..931a5b7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects(); case VPBlendSC: case VPReductionEVLSC: + case VPPartialReductionSC: case VPReductionSC: case VPScalarIVStepsSC: case VPVectorPointerSC: @@ -300,14 +301,23 @@ InstructionCost VPPartialReductionRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { std::optional<unsigned> Opcode; - VPValue *Op = getOperand(0); - VPRecipeBase *OpR = Op->getDefiningRecipe(); - - // If the partial reduction is predicated, a select will be operand 0 - if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { - OpR = Op->getDefiningRecipe(); + VPValue *Op = getVecOp(); + uint64_t MulConst; + // If the partial reduction is predicated, a select will be operand 1. + // If it isn't predicated and the mul isn't operating on a constant, then it + // should have been turned into a VPExpressionRecipe. + // FIXME: Replace the entire function with this once all partial reduction + // variants are bundled into VPExpressionRecipe. + if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) && + !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) { + auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); + auto *InputType = Ctx.Types.inferScalarType(getVecOp()); + return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType, + PhiType, VF, TTI::PR_None, + TTI::PR_None, {}, Ctx.CostKind); } + VPRecipeBase *OpR = Op->getDefiningRecipe(); Type *InputTypeA = nullptr, *InputTypeB = nullptr; TTI::PartialReductionExtendKind ExtAType = TTI::PR_None, ExtBType = TTI::PR_None; @@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind()); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { - return Ctx.TTI.getExtendedReductionCost( - Opcode, - cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == - Instruction::ZExt, - RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); + unsigned Opcode = RecurrenceDescriptor::getOpcode( + cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind()); + auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + return isa<VPPartialReductionRecipe>(ExpressionRecipes.back()) + ? 
Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, + RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + ExtR->getOpcode()), + TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind) + : Ctx.TTI.getExtendedReductionCost( + Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, + SrcVecTy, std::nullopt, Ctx.CostKind); } case ExpressionTypes::MulAccReduction: return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, @@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, Opcode = Instruction::Sub; [[fallthrough]]; case ExpressionTypes::ExtMulAccReduction: { + if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { + auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); + auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); + return Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), + Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + Ext0R->getOpcode()), + TargetTransformInfo::getPartialReductionExtendKind( + Ext1R->getOpcode()), + Mul->getOpcode(), Ctx.CostKind); + } return Ctx.TTI.getMulAccReductionCost( cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, @@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back()); unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { getOperand(1)->printAsOperand(O, SlotTracker); - O << " +"; - O << " reduce." << Instruction::getOpcodeName(Opcode) << " ("; + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName(Opcode) << " ("; getOperand(0)->printAsOperand(O, SlotTracker); Red->printFlags(O); @@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, } case ExpressionTypes::ExtNegatedMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " (sub (0, mul"; auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); @@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + "; - O << "reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " ("; O << "mul"; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index f5a3af4..c385c36 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3519,18 +3519,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VPValue *VecOp = Red->getVecOp(); // Clamp the range if using extended-reduction is profitable. 
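  // (Illustrative: for each VF in Range the lambda checks
  //   ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost,
  // i.e. the fused extended or partial reduction must be cheaper than the
  // separate extend plus reduction; VFs failing this are clamped out.)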
- auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt, - Type *SrcTy) -> bool { + auto IsExtendedRedValidAndClampRange = + [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost( - Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(), - CostKind); + + InstructionCost ExtRedCost; InstructionCost ExtCost = cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); + + if (isa<VPPartialReductionRecipe>(Red)) { + TargetTransformInfo::PartialReductionExtendKind ExtKind = + TargetTransformInfo::getPartialReductionExtendKind(ExtOpc); + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. + ExtRedCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, nullptr, RedTy, VF, ExtKind, + llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind); + } else { + ExtRedCost = Ctx.TTI.getExtendedReductionCost( + Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy, + Red->getFastMathFlags(), CostKind); + } return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost; }, Range); @@ -3541,8 +3554,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && IsExtendedRedValidAndClampRange( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()), - cast<VPWidenCastRecipe>(VecOp)->getOpcode() == - Instruction::CastOps::ZExt, + cast<VPWidenCastRecipe>(VecOp)->getOpcode(), Ctx.Types.inferScalarType(A))) return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red); @@ -3560,6 +3572,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, static VPExpressionRecipe * tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VFRange &Range) { + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); + unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); if (Opcode != Instruction::Add && Opcode != Instruction::Sub) return nullptr; @@ -3568,16 +3582,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Clamp the range if using multiply-accumulate-reduction is profitable. auto IsMulAccValidAndClampRange = - [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, - VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool { + [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenCastRecipe *OuterExt) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *SrcTy = Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy; - auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); - InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost( - isZExt, Opcode, RedTy, SrcVecTy, CostKind); + InstructionCost MulAccCost; + + if (IsPartialReduction) { + Type *SrcTy2 = + Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr; + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. + MulAccCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, SrcTy2, RedTy, VF, + Ext0 ? 
TargetTransformInfo::getPartialReductionExtendKind( + Ext0->getOpcode()) + : TargetTransformInfo::PR_None, + Ext1 ? TargetTransformInfo::getPartialReductionExtendKind( + Ext1->getOpcode()) + : TargetTransformInfo::PR_None, + Mul->getOpcode(), CostKind); + } else { + // Only partial reductions support mixed extends at the moment. + if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode()) + return false; + + bool IsZExt = + !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt; + auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); + MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy, + SrcVecTy, CostKind); + } + InstructionCost MulCost = Mul->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); InstructionCost ExtCost = 0; @@ -3611,14 +3650,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); - // Match reduce.add(mul(ext, ext)). - if (RecipeA && RecipeB && - (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) && - match(RecipeA, m_ZExtOrSExt(m_VPValue())) && + // Match reduce.add/sub(mul(ext, ext)). + if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) && match(RecipeB, m_ZExtOrSExt(m_VPValue())) && - IsMulAccValidAndClampRange(RecipeA->getOpcode() == - Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { + IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) { if (Sub) return new VPExpressionRecipe(RecipeA, RecipeB, Mul, cast<VPWidenRecipe>(Sub), Red); @@ -3626,8 +3661,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, } // Match reduce.add(mul). // TODO: Add an expression type for this variant with a negated mul - if (!Sub && - IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) + if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr)) return new VPExpressionRecipe(Mul, Red); } // TODO: Add an expression type for negated versions of other expression @@ -3647,9 +3681,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe()); if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) && Ext0->getOpcode() == Ext1->getOpcode() && - IsMulAccValidAndClampRange(Ext0->getOpcode() == - Instruction::CastOps::ZExt, - Mul, Ext0, Ext1, Ext)) { + IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) { auto *NewExt0 = new VPWidenCastRecipe( Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0, *Ext0, Ext0->getDebugLoc()); @@ -4087,7 +4119,7 @@ static bool isAlreadyNarrow(VPValue *VPV) { void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, unsigned VectorRegWidth) { VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion(); - if (!VectorLoop) + if (!VectorLoop || VectorLoop->getEntry()->getNumSuccessors() != 0) return; VPTypeAnalysis TypeInfo(Plan); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 32e4b88..06c3d75 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -151,6 +151,8 @@ unsigned vputils::getVFScaleFactor(VPRecipeBase *R) { return RR->getVFScaleFactor(); if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R)) return RR->getVFScaleFactor(); + if (auto *ER = dyn_cast<VPExpressionRecipe>(R)) + return ER->getVFScaleFactor(); assert( (!isa<VPInstruction>(R) || 
cast<VPInstruction>(R)->getOpcode() != VPInstruction::ReductionStartVector) &&
