diff options
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/AsmPrinter/DwarfExpression.cpp | 60 | ||||
-rw-r--r-- | llvm/lib/CodeGen/CodeGen.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/CodeGen/ExpandFp.cpp | 61 | ||||
-rw-r--r-- | llvm/lib/CodeGen/IfConversion.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/InterleavedAccessPass.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/CodeGen/MIR2Vec.cpp | 166 | ||||
-rw-r--r-- | llvm/lib/CodeGen/MIRFSDiscriminator.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/CodeGen/MIRSampleProfile.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/CodeGen/RegAllocFast.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 15 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 72 |
14 files changed, 277 insertions, 127 deletions
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index e2af0c5..a114406 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -1438,7 +1438,7 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges, BBFreqEnabled, BrProbEnabled, MF.hasBBSections() && NumMBBSectionRanges > 1, - static_cast<bool>(BBAddrMapSkipEmitBBEntries), + BBAddrMapSkipEmitBBEntries, HasCalls, false}; } diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfExpression.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfExpression.cpp index f0f0861..c7d45897 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfExpression.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfExpression.cpp @@ -566,32 +566,54 @@ bool DwarfExpression::addExpression( case dwarf::DW_OP_LLVM_extract_bits_zext: { unsigned SizeInBits = Op->getArg(1); unsigned BitOffset = Op->getArg(0); + unsigned DerefSize = 0; + // Operations are done in the DWARF "generic type" whose size + // is the size of a pointer. + unsigned PtrSizeInBytes = CU.getAsmPrinter()->MAI->getCodePointerSize(); // If we have a memory location then dereference to get the value, though // we have to make sure we don't dereference any bytes past the end of the // object. if (isMemoryLocation()) { - emitOp(dwarf::DW_OP_deref_size); - emitUnsigned(alignTo(BitOffset + SizeInBits, 8) / 8); + DerefSize = alignTo(BitOffset + SizeInBits, 8) / 8; + if (DerefSize == PtrSizeInBytes) { + emitOp(dwarf::DW_OP_deref); + } else { + emitOp(dwarf::DW_OP_deref_size); + emitUnsigned(DerefSize); + } } - // Extract the bits by a shift left (to shift out the bits after what we - // want to extract) followed by shift right (to shift the bits to position - // 0 and also sign/zero extend). These operations are done in the DWARF - // "generic type" whose size is the size of a pointer. - unsigned PtrSizeInBytes = CU.getAsmPrinter()->MAI->getCodePointerSize(); - unsigned LeftShift = PtrSizeInBytes * 8 - (SizeInBits + BitOffset); - unsigned RightShift = LeftShift + BitOffset; - if (LeftShift) { - emitOp(dwarf::DW_OP_constu); - emitUnsigned(LeftShift); - emitOp(dwarf::DW_OP_shl); - } - if (RightShift) { - emitOp(dwarf::DW_OP_constu); - emitUnsigned(RightShift); - emitOp(OpNum == dwarf::DW_OP_LLVM_extract_bits_sext ? dwarf::DW_OP_shra - : dwarf::DW_OP_shr); + // If a dereference was emitted for an unsigned value, and + // there's no bit offset, then a bit of optimization is + // possible. + if (OpNum == dwarf::DW_OP_LLVM_extract_bits_zext && BitOffset == 0) { + if (8 * DerefSize == SizeInBits) { + // The correct value is already on the stack. + } else { + // No need to shift, we can just mask off the desired bits. + emitOp(dwarf::DW_OP_constu); + emitUnsigned((1u << SizeInBits) - 1); + emitOp(dwarf::DW_OP_and); + } + } else { + // Extract the bits by a shift left (to shift out the bits after what we + // want to extract) followed by shift right (to shift the bits to + // position 0 and also sign/zero extend). + unsigned LeftShift = PtrSizeInBytes * 8 - (SizeInBits + BitOffset); + unsigned RightShift = LeftShift + BitOffset; + if (LeftShift) { + emitOp(dwarf::DW_OP_constu); + emitUnsigned(LeftShift); + emitOp(dwarf::DW_OP_shl); + } + if (RightShift) { + emitOp(dwarf::DW_OP_constu); + emitUnsigned(RightShift); + emitOp(OpNum == dwarf::DW_OP_LLVM_extract_bits_sext + ? dwarf::DW_OP_shra + : dwarf::DW_OP_shr); + } } // The value is now at the top of the stack, so set the location to diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp index c438eae..9795a0b 100644 --- a/llvm/lib/CodeGen/CodeGen.cpp +++ b/llvm/lib/CodeGen/CodeGen.cpp @@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) { initializeMachineUniformityAnalysisPassPass(Registry); initializeMIR2VecVocabLegacyAnalysisPass(Registry); initializeMIR2VecVocabPrinterLegacyPassPass(Registry); + initializeMIR2VecPrinterLegacyPassPass(Registry); initializeMachineUniformityInfoPrinterPassPass(Registry); initializeMachineVerifierLegacyPassPass(Registry); initializeObjCARCContractLegacyPassPass(Registry); diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp index 04c7008..2b5ced3 100644 --- a/llvm/lib/CodeGen/ExpandFp.cpp +++ b/llvm/lib/CodeGen/ExpandFp.cpp @@ -993,7 +993,6 @@ static void addToWorklist(Instruction &I, static bool runImpl(Function &F, const TargetLowering &TLI, AssumptionCache *AC) { SmallVector<Instruction *, 4> Worklist; - bool Modified = false; unsigned MaxLegalFpConvertBitWidth = TLI.getMaxLargeFPConvertBitWidthSupported(); @@ -1003,50 +1002,49 @@ static bool runImpl(Function &F, const TargetLowering &TLI, if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS) return false; - for (auto It = inst_begin(&F), End = inst_end(F); It != End;) { - Instruction &I = *It++; + auto ShouldHandleInst = [&](Instruction &I) { Type *Ty = I.getType(); // TODO: This pass doesn't handle scalable vectors. if (Ty->isScalableTy()) - continue; + return false; switch (I.getOpcode()) { case Instruction::FRem: - if (!targetSupportsFrem(TLI, Ty) && - FRemExpander::canExpandType(Ty->getScalarType())) { - addToWorklist(I, Worklist); - Modified = true; - } - break; + return !targetSupportsFrem(TLI, Ty) && + FRemExpander::canExpandType(Ty->getScalarType()); + case Instruction::FPToUI: case Instruction::FPToSI: { auto *IntTy = cast<IntegerType>(Ty->getScalarType()); - if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth) - continue; - - addToWorklist(I, Worklist); - Modified = true; - break; + return IntTy->getIntegerBitWidth() > MaxLegalFpConvertBitWidth; } + case Instruction::UIToFP: case Instruction::SIToFP: { auto *IntTy = cast<IntegerType>(I.getOperand(0)->getType()->getScalarType()); - if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth) - continue; - - addToWorklist(I, Worklist); - Modified = true; - break; + return IntTy->getIntegerBitWidth() > MaxLegalFpConvertBitWidth; } - default: - break; } + + return false; + }; + + bool Modified = false; + for (auto It = inst_begin(&F), End = inst_end(F); It != End;) { + Instruction &I = *It++; + if (!ShouldHandleInst(I)) + continue; + + addToWorklist(I, Worklist); + Modified = true; } while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (I->getOpcode() == Instruction::FRem) { + + switch (I->getOpcode()) { + case Instruction::FRem: { auto SQ = [&]() -> std::optional<SimplifyQuery> { if (AC) { auto Res = std::make_optional<SimplifyQuery>( @@ -1058,11 +1056,18 @@ static bool runImpl(Function &F, const TargetLowering &TLI, }(); expandFRem(cast<BinaryOperator>(*I), SQ); - } else if (I->getOpcode() == Instruction::FPToUI || - I->getOpcode() == Instruction::FPToSI) { + break; + } + + case Instruction::FPToUI: + case Instruction::FPToSI: expandFPToI(I); - } else { + break; + + case Instruction::UIToFP: + case Instruction::SIToFP: expandIToFP(I); + break; } } diff --git a/llvm/lib/CodeGen/IfConversion.cpp b/llvm/lib/CodeGen/IfConversion.cpp index f80e1e8..3ac6d2a 100644 --- a/llvm/lib/CodeGen/IfConversion.cpp +++ b/llvm/lib/CodeGen/IfConversion.cpp @@ -1498,7 +1498,7 @@ static void UpdatePredRedefs(MachineInstr &MI, LivePhysRegs &Redefs) { // Before stepping forward past MI, remember which regs were live // before MI. This is needed to set the Undef flag only when reg is // dead. - SparseSet<MCPhysReg, identity<MCPhysReg>> LiveBeforeMI; + SparseSet<MCPhysReg, MCPhysReg> LiveBeforeMI; LiveBeforeMI.setUniverse(TRI->getNumRegs()); for (unsigned Reg : Redefs) LiveBeforeMI.insert(Reg); diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp index a6a9b50..5c27a20 100644 --- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp +++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp @@ -258,13 +258,11 @@ static Value *getMaskOperand(IntrinsicInst *II) { default: llvm_unreachable("Unexpected intrinsic"); case Intrinsic::vp_load: - return II->getOperand(1); case Intrinsic::masked_load: - return II->getOperand(2); + return II->getOperand(1); case Intrinsic::vp_store: - return II->getOperand(2); case Intrinsic::masked_store: - return II->getOperand(3); + return II->getOperand(2); } } diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 5c78d98..99be1fc0 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/CodeGen/MIR2Vec.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/IR/Module.h" @@ -29,20 +30,30 @@ using namespace mir2vec; STATISTIC(MIRVocabMissCounter, "Number of lookups to MIR entities not present in the vocabulary"); -cl::OptionCategory llvm::mir2vec::MIR2VecCategory("MIR2Vec Options"); +namespace llvm { +namespace mir2vec { +cl::OptionCategory MIR2VecCategory("MIR2Vec Options"); // FIXME: Use a default vocab when not specified static cl::opt<std::string> VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory)); -cl::opt<float> - llvm::mir2vec::OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), - cl::desc("Weight for machine opcode embeddings"), - cl::cat(MIR2VecCategory)); +cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), + cl::desc("Weight for machine opcode embeddings"), + cl::cat(MIR2VecCategory)); +cl::opt<MIR2VecKind> MIR2VecEmbeddingKind( + "mir2vec-kind", cl::Optional, + cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", + "Generate symbolic embeddings for MIR")), + cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"), + cl::cat(MIR2VecCategory)); + +} // namespace mir2vec +} // namespace llvm //===----------------------------------------------------------------------===// -// Vocabulary Implementation +// Vocabulary //===----------------------------------------------------------------------===// MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, @@ -188,6 +199,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() { << " unique base opcodes\n"); } +Expected<MIRVocabulary> +MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII, + unsigned Dim) { + assert(Dim > 0 && "Dimension must be greater than zero"); + + float DummyVal = 0.1f; + + // Create dummy embeddings for all canonical opcode names + VocabMap DummyVocabMap; + for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) { + std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); + if (DummyVocabMap.count(BaseOpcode) == 0) { + // Only add if not already present + DummyVocabMap[BaseOpcode] = Embedding(Dim, DummyVal); + DummyVal += 0.1f; + } + } + + // Create and return vocabulary with dummy embeddings + return MIRVocabulary::create(std::move(DummyVocabMap), TII); +} + //===----------------------------------------------------------------------===// // MIR2VecVocabLegacyAnalysis Implementation //===----------------------------------------------------------------------===// @@ -258,7 +291,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { } //===----------------------------------------------------------------------===// -// Printer Passes Implementation +// MIREmbedder and its subclasses +//===----------------------------------------------------------------------===// + +std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode, + const MachineFunction &MF, + const MIRVocabulary &Vocab) { + switch (Mode) { + case MIR2VecKind::Symbolic: + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); + } + return nullptr; +} + +Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const { + Embedding MBBVector(Dimension, 0); + + // Get instruction info for opcode name resolution + const auto &Subtarget = MF.getSubtarget(); + const auto *TII = Subtarget.getInstrInfo(); + if (!TII) { + MF.getFunction().getContext().emitError( + "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings"); + return MBBVector; + } + + // Process each machine instruction in the basic block + for (const auto &MI : MBB) { + // Skip debug instructions and other metadata + if (MI.isDebugInstr()) + continue; + MBBVector += computeEmbeddings(MI); + } + + return MBBVector; +} + +Embedding MIREmbedder::computeEmbeddings() const { + Embedding MFuncVector(Dimension, 0); + + // Consider all reachable machine basic blocks in the function + for (const auto *MBB : depth_first(&MF)) + MFuncVector += computeEmbeddings(*MBB); + return MFuncVector; +} + +SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF, + const MIRVocabulary &Vocab) + : MIREmbedder(MF, Vocab) {} + +std::unique_ptr<SymbolicMIREmbedder> +SymbolicMIREmbedder::create(const MachineFunction &MF, + const MIRVocabulary &Vocab) { + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); +} + +Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const { + // Skip debug instructions and other metadata + if (MI.isDebugInstr()) + return Embedding(Dimension, 0); + + // Todo: Add operand/argument contributions + + return Vocab[MI.getOpcode()]; +} + +//===----------------------------------------------------------------------===// +// Printer Passes //===----------------------------------------------------------------------===// char MIR2VecVocabPrinterLegacyPass::ID = 0; @@ -297,3 +396,56 @@ MachineFunctionPass * llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) { return new MIR2VecVocabPrinterLegacyPass(OS); } + +char MIR2VecPrinterLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec", + "MIR2Vec Embedder Printer Pass", false, true) +INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec", + "MIR2Vec Embedder Printer Pass", false, true) + +bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { + auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); + auto VocabOrErr = + Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent()); + assert(VocabOrErr && "Failed to get MIR2Vec vocabulary"); + auto &MIRVocab = *VocabOrErr; + + auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab); + if (!Emb) { + OS << "Error creating MIR2Vec embeddings for function " << MF.getName() + << "\n"; + return false; + } + + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; + OS << "Machine Function vector: "; + Emb->getMFunctionVector().print(OS); + + OS << "Machine basic block vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + OS << "Machine basic block: " << MBB.getFullName() << ":\n"; + Emb->getMBBVector(MBB).print(OS); + } + + OS << "Machine instruction vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + // Skip debug instructions as they are not + // embedded + if (MI.isDebugInstr()) + continue; + + OS << "Machine instruction: "; + MI.print(OS); + Emb->getMInstVector(MI).print(OS); + } + } + + return false; +} + +MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) { + return new MIR2VecPrinterLegacyPass(OS); +} diff --git a/llvm/lib/CodeGen/MIRFSDiscriminator.cpp b/llvm/lib/CodeGen/MIRFSDiscriminator.cpp index d988a2a..e37f784 100644 --- a/llvm/lib/CodeGen/MIRFSDiscriminator.cpp +++ b/llvm/lib/CodeGen/MIRFSDiscriminator.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/CodeGen/MIRFSDiscriminatorOptions.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" @@ -35,13 +36,10 @@ using namespace sampleprofutil; // TODO(xur): Remove this option and related code once we make true as the // default. -namespace llvm { -cl::opt<bool> ImprovedFSDiscriminator( +cl::opt<bool> llvm::ImprovedFSDiscriminator( "improved-fs-discriminator", cl::Hidden, cl::init(false), cl::desc("New FS discriminators encoding (incompatible with the original " "encoding)")); -} // namespace llvm - char MIRAddFSDiscriminators::ID = 0; INITIALIZE_PASS(MIRAddFSDiscriminators, DEBUG_TYPE, diff --git a/llvm/lib/CodeGen/MIRSampleProfile.cpp b/llvm/lib/CodeGen/MIRSampleProfile.cpp index 9bba50e8..d44f577 100644 --- a/llvm/lib/CodeGen/MIRSampleProfile.cpp +++ b/llvm/lib/CodeGen/MIRSampleProfile.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/CodeGen/MIRFSDiscriminatorOptions.h" #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineDominators.h" @@ -62,9 +63,6 @@ static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, cl::init(false), cl::desc("View BFI after MIR loader")); -namespace llvm { -extern cl::opt<bool> ImprovedFSDiscriminator; -} char MIRProfileLoaderPass::ID = 0; INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, diff --git a/llvm/lib/CodeGen/RegAllocFast.cpp b/llvm/lib/CodeGen/RegAllocFast.cpp index 804480c..72b364c 100644 --- a/llvm/lib/CodeGen/RegAllocFast.cpp +++ b/llvm/lib/CodeGen/RegAllocFast.cpp @@ -211,7 +211,7 @@ private: unsigned getSparseSetIndex() const { return VirtReg.virtRegIndex(); } }; - using LiveRegMap = SparseSet<LiveReg, identity<unsigned>, uint16_t>; + using LiveRegMap = SparseSet<LiveReg, unsigned, identity_cxx20, uint16_t>; /// This map contains entries for each virtual register that is currently /// available in a physical register. LiveRegMap LiveVirtRegs; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c97300d..310d35d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -16433,7 +16433,8 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { case ISD::OR: case ISD::XOR: if (!LegalOperations && N0.hasOneUse() && - (isConstantOrConstantVector(N0.getOperand(0), true) || + (N0.getOperand(0) == N0.getOperand(1) || + isConstantOrConstantVector(N0.getOperand(0), true) || isConstantOrConstantVector(N0.getOperand(1), true))) { // TODO: We already restricted this to pre-legalization, but for vectors // we are extra cautious to not create an unsupported operation. @@ -26876,6 +26877,8 @@ static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN, // TODO: handle more extension/truncation cases as cases arise. if (EltSizeInBits != ExtSrcSizeInBits) return SDValue(); + if (VT.getSizeInBits() != N00.getValueSizeInBits()) + return SDValue(); // We can remove *extend_vector_inreg only if the truncation happens at // the same scale as the extension. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp index 437d0f4..bf1abfe 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp @@ -3765,6 +3765,8 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) { case ISD::FP_TO_UINT: case ISD::LRINT: case ISD::LLRINT: + case ISD::LROUND: + case ISD::LLROUND: Res = SoftPromoteHalfOp_Op0WithStrict(N); break; case ISD::FP_TO_SINT_SAT: diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp index 88a4a8b..b1776ea 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -429,7 +429,20 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) { } SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) { - SDValue Op2 = GetPromotedInteger(N->getOperand(2)); + SDValue Op2 = N->getOperand(2); + switch (TLI.getExtendForAtomicRMWArg(N->getOpcode())) { + case ISD::SIGN_EXTEND: + Op2 = SExtPromotedInteger(Op2); + break; + case ISD::ZERO_EXTEND: + Op2 = ZExtPromotedInteger(Op2); + break; + case ISD::ANY_EXTEND: + Op2 = GetPromotedInteger(Op2); + break; + default: + llvm_unreachable("Invalid atomic op extension"); + } SDValue Res = DAG.getAtomic(N->getOpcode(), SDLoc(N), N->getMemoryVT(), N->getChain(), N->getBasePtr(), diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index cb0038c..20a0efd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4837,29 +4837,10 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I, bool IsCompressing) { SDLoc sdl = getCurSDLoc(); - auto getMaskedStoreOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // llvm.masked.store.*(Src0, Ptr, alignment, Mask) - Src0 = I.getArgOperand(0); - Ptr = I.getArgOperand(1); - Alignment = cast<ConstantInt>(I.getArgOperand(2))->getAlignValue(); - Mask = I.getArgOperand(3); - }; - auto getCompressingStoreOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // llvm.masked.compressstore.*(Src0, Ptr, Mask) - Src0 = I.getArgOperand(0); - Ptr = I.getArgOperand(1); - Mask = I.getArgOperand(2); - Alignment = I.getParamAlign(1).valueOrOne(); - }; - - Value *PtrOperand, *MaskOperand, *Src0Operand; - Align Alignment; - if (IsCompressing) - getCompressingStoreOps(PtrOperand, MaskOperand, Src0Operand, Alignment); - else - getMaskedStoreOps(PtrOperand, MaskOperand, Src0Operand, Alignment); + Value *Src0Operand = I.getArgOperand(0); + Value *PtrOperand = I.getArgOperand(1); + Value *MaskOperand = I.getArgOperand(2); + Align Alignment = I.getParamAlign(1).valueOrOne(); SDValue Ptr = getValue(PtrOperand); SDValue Src0 = getValue(Src0Operand); @@ -4964,14 +4945,12 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index, void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) { SDLoc sdl = getCurSDLoc(); - // llvm.masked.scatter.*(Src0, Ptrs, alignment, Mask) + // llvm.masked.scatter.*(Src0, Ptrs, Mask) const Value *Ptr = I.getArgOperand(1); SDValue Src0 = getValue(I.getArgOperand(0)); - SDValue Mask = getValue(I.getArgOperand(3)); + SDValue Mask = getValue(I.getArgOperand(2)); EVT VT = Src0.getValueType(); - Align Alignment = cast<ConstantInt>(I.getArgOperand(2)) - ->getMaybeAlignValue() - .value_or(DAG.getEVTAlign(VT.getScalarType())); + Align Alignment = I.getParamAlign(1).valueOrOne(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDValue Base; @@ -5008,29 +4987,10 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) { void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) { SDLoc sdl = getCurSDLoc(); - auto getMaskedLoadOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // @llvm.masked.load.*(Ptr, alignment, Mask, Src0) - Ptr = I.getArgOperand(0); - Alignment = cast<ConstantInt>(I.getArgOperand(1))->getAlignValue(); - Mask = I.getArgOperand(2); - Src0 = I.getArgOperand(3); - }; - auto getExpandingLoadOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // @llvm.masked.expandload.*(Ptr, Mask, Src0) - Ptr = I.getArgOperand(0); - Alignment = I.getParamAlign(0).valueOrOne(); - Mask = I.getArgOperand(1); - Src0 = I.getArgOperand(2); - }; - - Value *PtrOperand, *MaskOperand, *Src0Operand; - Align Alignment; - if (IsExpanding) - getExpandingLoadOps(PtrOperand, MaskOperand, Src0Operand, Alignment); - else - getMaskedLoadOps(PtrOperand, MaskOperand, Src0Operand, Alignment); + Value *PtrOperand = I.getArgOperand(0); + Value *MaskOperand = I.getArgOperand(1); + Value *Src0Operand = I.getArgOperand(2); + Align Alignment = I.getParamAlign(0).valueOrOne(); SDValue Ptr = getValue(PtrOperand); SDValue Src0 = getValue(Src0Operand); @@ -5077,16 +5037,14 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) { void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) { SDLoc sdl = getCurSDLoc(); - // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) + // @llvm.masked.gather.*(Ptrs, Mask, Src0) const Value *Ptr = I.getArgOperand(0); - SDValue Src0 = getValue(I.getArgOperand(3)); - SDValue Mask = getValue(I.getArgOperand(2)); + SDValue Src0 = getValue(I.getArgOperand(2)); + SDValue Mask = getValue(I.getArgOperand(1)); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); - Align Alignment = cast<ConstantInt>(I.getArgOperand(1)) - ->getMaybeAlignValue() - .value_or(DAG.getEVTAlign(VT.getScalarType())); + Align Alignment = I.getParamAlign(0).valueOrOne(); const MDNode *Ranges = getRangeMetadata(I); |