aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rwxr-xr-xllvm/lib/Analysis/ConstantFolding.cpp9
-rw-r--r--llvm/lib/Analysis/IR2Vec.cpp2
-rw-r--r--llvm/lib/Analysis/InstructionSimplify.cpp182
-rw-r--r--llvm/lib/Analysis/Loads.cpp24
-rw-r--r--llvm/lib/Analysis/ValueTracking.cpp4
-rw-r--r--llvm/lib/BinaryFormat/DXContainer.cpp85
-rw-r--r--llvm/lib/Bitcode/Reader/BitcodeReader.cpp2
-rw-r--r--llvm/lib/Bitcode/Writer/BitcodeWriter.cpp2
-rw-r--r--llvm/lib/CAS/CMakeLists.txt1
-rw-r--r--llvm/lib/CAS/OnDiskDataAllocator.cpp234
-rw-r--r--llvm/lib/CAS/OnDiskTrieRawHashMap.cpp31
-rw-r--r--llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp2
-rw-r--r--llvm/lib/CodeGen/InlineSpiller.cpp21
-rw-r--r--llvm/lib/CodeGen/LiveRangeEdit.cpp108
-rw-r--r--llvm/lib/CodeGen/SplitKit.cpp4
-rw-r--r--llvm/lib/ExecutionEngine/Orc/CMakeLists.txt1
-rw-r--r--llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp7
-rw-r--r--llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp18
-rw-r--r--llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp20
-rw-r--r--llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp98
-rw-r--r--llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp7
-rw-r--r--llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp18
-rw-r--r--llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp4
-rw-r--r--llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt1
-rw-r--r--llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp47
-rw-r--r--llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp71
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp22
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp13
-rw-r--r--llvm/lib/IR/ConstantFold.cpp9
-rw-r--r--llvm/lib/IR/Core.cpp8
-rw-r--r--llvm/lib/IR/Globals.cpp1
-rw-r--r--llvm/lib/IR/Instructions.cpp41
-rw-r--r--llvm/lib/IR/Mangler.cpp17
-rw-r--r--llvm/lib/IR/Verifier.cpp10
-rw-r--r--llvm/lib/Object/OffloadBundle.cpp582
-rw-r--r--llvm/lib/Option/ArgList.cpp38
-rw-r--r--llvm/lib/Option/OptTable.cpp76
-rw-r--r--llvm/lib/Passes/PassBuilder.cpp1
-rw-r--r--llvm/lib/Passes/PassBuilderPipelines.cpp1
-rw-r--r--llvm/lib/Passes/PassRegistry.def2
-rw-r--r--llvm/lib/Support/GlobPattern.cpp11
-rw-r--r--llvm/lib/Support/SpecialCaseList.cpp4
-rw-r--r--llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp11
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPU.td8
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp14
-rw-r--r--llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp63
-rw-r--r--llvm/lib/Target/AMDGPU/GCNSubtarget.cpp2
-rw-r--r--llvm/lib/Target/AMDGPU/GCNSubtarget.h4
-rw-r--r--llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp11
-rw-r--r--llvm/lib/Target/AMDGPU/SIISelLowering.cpp76
-rw-r--r--llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp22
-rw-r--r--llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h10
-rw-r--r--llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp7
-rw-r--r--llvm/lib/Target/AMDGPU/VOP3Instructions.td31
-rw-r--r--llvm/lib/Target/AMDGPU/VOP3PInstructions.td70
-rw-r--r--llvm/lib/Target/AMDGPU/VOPInstructions.td7
-rw-r--r--llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp54
-rw-r--r--llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp39
-rw-r--r--llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h17
-rw-r--r--llvm/lib/Target/PowerPC/PPCInstr64Bit.td24
-rw-r--r--llvm/lib/Target/PowerPC/PPCInstrAltivec.td19
-rw-r--r--llvm/lib/Target/PowerPC/PPCRegisterInfo.td67
-rw-r--r--llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp12
-rw-r--r--llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp8
-rw-r--r--llvm/lib/Target/RISCV/RISCVFeatures.td2
-rw-r--r--llvm/lib/Target/RISCV/RISCVGISel.td26
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp147
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfo.cpp29
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfo.h19
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td2
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td9
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoZb.td4
-rw-r--r--llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp2
-rw-r--r--llvm/lib/Target/RISCV/RISCVSubtarget.cpp12
-rw-r--r--llvm/lib/Target/RISCV/RISCVSubtarget.h4
-rw-r--r--llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp191
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp4
-rw-r--r--llvm/lib/TargetParser/TargetParser.cpp2
-rw-r--r--llvm/lib/Transforms/IPO/FunctionAttrs.cpp119
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp5
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp10
-rw-r--r--llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp10
-rw-r--r--llvm/lib/Transforms/Instrumentation/AllocToken.cpp494
-rw-r--r--llvm/lib/Transforms/Instrumentation/CMakeLists.txt1
-rw-r--r--llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp14
-rw-r--r--llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp93
-rw-r--r--llvm/lib/Transforms/Utils/CodeExtractor.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/Local.cpp7
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp92
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp40
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorize.cpp29
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.cpp9
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.h3
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanHelpers.h16
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp126
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp6
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanUtils.cpp4
97 files changed, 2878 insertions, 1069 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index d52b073..b744537 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1482,6 +1482,15 @@ Constant *llvm::ConstantFoldFPInstOperands(unsigned Opcode, Constant *LHS,
Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
Type *DestTy, const DataLayout &DL) {
assert(Instruction::isCast(Opcode));
+
+ if (auto *CE = dyn_cast<ConstantExpr>(C))
+ if (CE->isCast())
+ if (unsigned NewOp = CastInst::isEliminableCastPair(
+ Instruction::CastOps(CE->getOpcode()),
+ Instruction::CastOps(Opcode), CE->getOperand(0)->getType(),
+ C->getType(), DestTy, &DL))
+ return ConstantFoldCastOperand(NewOp, CE->getOperand(0), DestTy, DL);
+
switch (Opcode) {
default:
llvm_unreachable("Missing case");
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 295b6d3..6885351 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -200,6 +200,8 @@ void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
return;
+ FuncVector = Embedding(Dimension, 0.0);
+
// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 0d978d4..d1977f0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5425,15 +5425,8 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
if (Src->getType() == Ty) {
auto FirstOp = CI->getOpcode();
auto SecondOp = static_cast<Instruction::CastOps>(CastOpc);
- Type *SrcIntPtrTy =
- SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr;
- Type *MidIntPtrTy =
- MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr;
- Type *DstIntPtrTy =
- DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr;
if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy,
- SrcIntPtrTy, MidIntPtrTy,
- DstIntPtrTy) == Instruction::BitCast)
+ &Q.DL) == Instruction::BitCast)
return Src;
}
}
@@ -6473,7 +6466,8 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) {
static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
Value *Op1) {
assert((IID == Intrinsic::maxnum || IID == Intrinsic::minnum ||
- IID == Intrinsic::maximum || IID == Intrinsic::minimum) &&
+ IID == Intrinsic::maximum || IID == Intrinsic::minimum ||
+ IID == Intrinsic::maximumnum || IID == Intrinsic::minimumnum) &&
"Unsupported intrinsic");
auto *M0 = dyn_cast<IntrinsicInst>(Op0);
@@ -6512,6 +6506,82 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
return nullptr;
}
+enum class MinMaxOptResult {
+ CannotOptimize = 0,
+ UseNewConstVal = 1,
+ UseOtherVal = 2,
+ // For undef/poison, we can choose to either propgate undef/poison or
+ // use the LHS value depending on what will allow more optimization.
+ UseEither = 3
+};
+// Get the optimized value for a min/max instruction with a single constant
+// input (either undef or scalar constantFP). The result may indicate to
+// use the non-const LHS value, use a new constant value instead (with NaNs
+// quieted), or to choose either option in the case of undef/poison.
+static MinMaxOptResult OptimizeConstMinMax(const Constant *RHSConst,
+ const Intrinsic::ID IID,
+ const CallBase *Call,
+ Constant **OutNewConstVal) {
+ assert(OutNewConstVal != nullptr);
+
+ bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum;
+ bool PropagateSNaN = IID == Intrinsic::minnum || IID == Intrinsic::maxnum;
+ bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum ||
+ IID == Intrinsic::minimumnum;
+
+ // min/max(x, poison) -> either x or poison
+ if (isa<UndefValue>(RHSConst)) {
+ *OutNewConstVal = const_cast<Constant *>(RHSConst);
+ return MinMaxOptResult::UseEither;
+ }
+
+ const ConstantFP *CFP = dyn_cast<ConstantFP>(RHSConst);
+ if (!CFP)
+ return MinMaxOptResult::CannotOptimize;
+ APFloat CAPF = CFP->getValueAPF();
+
+ // minnum(x, qnan) -> x
+ // maxnum(x, qnan) -> x
+ // minnum(x, snan) -> qnan
+ // maxnum(x, snan) -> qnan
+ // minimum(X, nan) -> qnan
+ // maximum(X, nan) -> qnan
+ // minimumnum(X, nan) -> x
+ // maximumnum(X, nan) -> x
+ if (CAPF.isNaN()) {
+ if (PropagateNaN || (PropagateSNaN && CAPF.isSignaling())) {
+ *OutNewConstVal = ConstantFP::get(CFP->getType(), CAPF.makeQuiet());
+ return MinMaxOptResult::UseNewConstVal;
+ }
+ return MinMaxOptResult::UseOtherVal;
+ }
+
+ if (CAPF.isInfinity() || (Call && Call->hasNoInfs() && CAPF.isLargest())) {
+ // minnum(X, -inf) -> -inf (ignoring sNaN -> qNaN propagation)
+ // maxnum(X, +inf) -> +inf (ignoring sNaN -> qNaN propagation)
+ // minimum(X, -inf) -> -inf if nnan
+ // maximum(X, +inf) -> +inf if nnan
+ // minimumnum(X, -inf) -> -inf
+ // maximumnum(X, +inf) -> +inf
+ if (CAPF.isNegative() == IsMin &&
+ (!PropagateNaN || (Call && Call->hasNoNaNs()))) {
+ *OutNewConstVal = const_cast<Constant *>(RHSConst);
+ return MinMaxOptResult::UseNewConstVal;
+ }
+
+ // minnum(X, +inf) -> X if nnan
+ // maxnum(X, -inf) -> X if nnan
+ // minimum(X, +inf) -> X (ignoring quieting of sNaNs)
+ // maximum(X, -inf) -> X (ignoring quieting of sNaNs)
+ // minimumnum(X, +inf) -> X if nnan
+ // maximumnum(X, -inf) -> X if nnan
+ if (CAPF.isNegative() != IsMin &&
+ (PropagateNaN || (Call && Call->hasNoNaNs())))
+ return MinMaxOptResult::UseOtherVal;
+ }
+ return MinMaxOptResult::CannotOptimize;
+}
+
Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
Value *Op0, Value *Op1,
const SimplifyQuery &Q,
@@ -6780,8 +6850,17 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
case Intrinsic::maxnum:
case Intrinsic::minnum:
case Intrinsic::maximum:
- case Intrinsic::minimum: {
- // If the arguments are the same, this is a no-op.
+ case Intrinsic::minimum:
+ case Intrinsic::maximumnum:
+ case Intrinsic::minimumnum: {
+ // In several cases here, we deviate from exact IEEE 754 semantics
+ // to enable optimizations (as allowed by the LLVM IR spec).
+ //
+ // For instance, we may return one of the arguments unmodified instead of
+ // inserting an llvm.canonicalize to transform input sNaNs into qNaNs,
+ // or may assume all NaN inputs are qNaNs.
+
+ // If the arguments are the same, this is a no-op (ignoring NaN quieting)
if (Op0 == Op1)
return Op0;
@@ -6789,40 +6868,55 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
if (isa<Constant>(Op0))
std::swap(Op0, Op1);
- // If an argument is undef, return the other argument.
- if (Q.isUndefValue(Op1))
- return Op0;
+ if (Constant *C = dyn_cast<Constant>(Op1)) {
+ MinMaxOptResult OptResult = MinMaxOptResult::CannotOptimize;
+ Constant *NewConst = nullptr;
+
+ if (VectorType *VTy = dyn_cast<VectorType>(C->getType())) {
+ ElementCount ElemCount = VTy->getElementCount();
+
+ if (Constant *SplatVal = C->getSplatValue()) {
+ // Handle splat vectors (including scalable vectors)
+ OptResult = OptimizeConstMinMax(SplatVal, IID, Call, &NewConst);
+ if (OptResult == MinMaxOptResult::UseNewConstVal)
+ NewConst = ConstantVector::getSplat(ElemCount, NewConst);
+
+ } else if (ElemCount.isFixed()) {
+ // Storage to build up new const return value (with NaNs quieted)
+ SmallVector<Constant *, 16> NewC(ElemCount.getFixedValue());
+
+ // Check elementwise whether we can optimize to either a constant
+ // value or return the LHS value. We cannot mix and match LHS +
+ // constant elements, as this would require inserting a new
+ // VectorShuffle instruction, which is not allowed in simplifyBinOp.
+ OptResult = MinMaxOptResult::UseEither;
+ for (unsigned i = 0; i != ElemCount.getFixedValue(); ++i) {
+ auto ElemResult = OptimizeConstMinMax(C->getAggregateElement(i),
+ IID, Call, &NewConst);
+ if (ElemResult == MinMaxOptResult::CannotOptimize ||
+ (ElemResult != OptResult &&
+ OptResult != MinMaxOptResult::UseEither &&
+ ElemResult != MinMaxOptResult::UseEither)) {
+ OptResult = MinMaxOptResult::CannotOptimize;
+ break;
+ }
+ NewC[i] = NewConst;
+ if (ElemResult != MinMaxOptResult::UseEither)
+ OptResult = ElemResult;
+ }
+ if (OptResult == MinMaxOptResult::UseNewConstVal)
+ NewConst = ConstantVector::get(NewC);
+ }
+ } else {
+ // Handle scalar inputs
+ OptResult = OptimizeConstMinMax(C, IID, Call, &NewConst);
+ }
- bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum;
- bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum;
-
- // minnum(X, nan) -> X
- // maxnum(X, nan) -> X
- // minimum(X, nan) -> nan
- // maximum(X, nan) -> nan
- if (match(Op1, m_NaN()))
- return PropagateNaN ? propagateNaN(cast<Constant>(Op1)) : Op0;
-
- // In the following folds, inf can be replaced with the largest finite
- // float, if the ninf flag is set.
- const APFloat *C;
- if (match(Op1, m_APFloat(C)) &&
- (C->isInfinity() || (Call && Call->hasNoInfs() && C->isLargest()))) {
- // minnum(X, -inf) -> -inf
- // maxnum(X, +inf) -> +inf
- // minimum(X, -inf) -> -inf if nnan
- // maximum(X, +inf) -> +inf if nnan
- if (C->isNegative() == IsMin &&
- (!PropagateNaN || (Call && Call->hasNoNaNs())))
- return ConstantFP::get(ReturnType, *C);
-
- // minnum(X, +inf) -> X if nnan
- // maxnum(X, -inf) -> X if nnan
- // minimum(X, +inf) -> X
- // maximum(X, -inf) -> X
- if (C->isNegative() != IsMin &&
- (PropagateNaN || (Call && Call->hasNoNaNs())))
- return Op0;
+ if (OptResult == MinMaxOptResult::UseOtherVal ||
+ OptResult == MinMaxOptResult::UseEither)
+ return Op0; // Return the other arg (ignoring NaN quieting)
+ else if (OptResult == MinMaxOptResult::UseNewConstVal)
+ return NewConst;
}
// Min/max of the same operation with common operand:
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 0c4e3a2..4c2e1fe 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -37,17 +37,13 @@ static bool isDereferenceableAndAlignedPointerViaAssumption(
function_ref<bool(const RetainedKnowledge &RK)> CheckSize,
const DataLayout &DL, const Instruction *CtxI, AssumptionCache *AC,
const DominatorTree *DT) {
- // Dereferenceable information from assumptions is only valid if the value
- // cannot be freed between the assumption and use. For now just use the
- // information for values that cannot be freed in the function.
- // TODO: More precisely check if the pointer can be freed between assumption
- // and use.
- if (!CtxI || Ptr->canBeFreed())
+ if (!CtxI)
return false;
/// Look through assumes to see if both dereferencability and alignment can
/// be proven by an assume if needed.
RetainedKnowledge AlignRK;
RetainedKnowledge DerefRK;
+ bool PtrCanBeFreed = Ptr->canBeFreed();
bool IsAligned = Ptr->getPointerAlignment(DL) >= Alignment;
return getKnowledgeForValue(
Ptr, {Attribute::Dereferenceable, Attribute::Alignment}, *AC,
@@ -56,7 +52,11 @@ static bool isDereferenceableAndAlignedPointerViaAssumption(
return false;
if (RK.AttrKind == Attribute::Alignment)
AlignRK = std::max(AlignRK, RK);
- if (RK.AttrKind == Attribute::Dereferenceable)
+
+ // Dereferenceable information from assumptions is only valid if the
+ // value cannot be freed between the assumption and use.
+ if ((!PtrCanBeFreed || willNotFreeBetween(Assume, CtxI)) &&
+ RK.AttrKind == Attribute::Dereferenceable)
DerefRK = std::max(DerefRK, RK);
IsAligned |= AlignRK && AlignRK.ArgValue >= Alignment.value();
if (IsAligned && DerefRK && CheckSize(DerefRK))
@@ -390,7 +390,11 @@ bool llvm::isDereferenceableAndAlignedInLoop(
} else
return false;
- Instruction *HeaderFirstNonPHI = &*L->getHeader()->getFirstNonPHIIt();
+ Instruction *CtxI = &*L->getHeader()->getFirstNonPHIIt();
+ if (BasicBlock *LoopPred = L->getLoopPredecessor()) {
+ if (isa<BranchInst>(LoopPred->getTerminator()))
+ CtxI = LoopPred->getTerminator();
+ }
return isDereferenceableAndAlignedPointerViaAssumption(
Base, Alignment,
[&SE, AccessSizeSCEV, &LoopGuards](const RetainedKnowledge &RK) {
@@ -399,9 +403,9 @@ bool llvm::isDereferenceableAndAlignedInLoop(
SE.applyLoopGuards(AccessSizeSCEV, *LoopGuards),
SE.applyLoopGuards(SE.getSCEV(RK.IRArgValue), *LoopGuards));
},
- DL, HeaderFirstNonPHI, AC, &DT) ||
+ DL, CtxI, AC, &DT) ||
isDereferenceableAndAlignedPointer(Base, Alignment, AccessSize, DL,
- HeaderFirstNonPHI, AC, &DT);
+ CtxI, AC, &DT);
}
static bool suppressSpeculativeLoadForSanitizers(const Instruction &CtxI) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a42c061..9655c88 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9095,6 +9095,10 @@ Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
case Intrinsic::minimum: return Intrinsic::maximum;
case Intrinsic::maxnum: return Intrinsic::minnum;
case Intrinsic::minnum: return Intrinsic::maxnum;
+ case Intrinsic::maximumnum:
+ return Intrinsic::minimumnum;
+ case Intrinsic::minimumnum:
+ return Intrinsic::maximumnum;
default: llvm_unreachable("Unexpected intrinsic");
}
}
diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp
index c06a3e3..22f5180 100644
--- a/llvm/lib/BinaryFormat/DXContainer.cpp
+++ b/llvm/lib/BinaryFormat/DXContainer.cpp
@@ -18,6 +18,91 @@
using namespace llvm;
using namespace llvm::dxbc;
+#define ROOT_PARAMETER(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidParameterType(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+bool llvm::dxbc::isValidRangeType(uint32_t V) {
+ return V <= llvm::to_underlying(dxil::ResourceClass::LastEntry);
+}
+
+#define SHADER_VISIBILITY(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidShaderVisibility(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define FILTER(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidSamplerFilter(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define TEXTURE_ADDRESS_MODE(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidAddress(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define COMPARISON_FUNC(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidComparisonFunc(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define STATIC_BORDER_COLOR(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidBorderColor(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) {
+ using FlagT = dxbc::RootDescriptorFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
+bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) {
+ using FlagT = dxbc::DescriptorRangeFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
+bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) {
+ using FlagT = dxbc::StaticSamplerFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
dxbc::PartType dxbc::parsePartType(StringRef S) {
#define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
return StringSwitch<dxbc::PartType>(S)
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 832aa9f..aaee1f0 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2203,6 +2203,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) {
return Attribute::SanitizeRealtime;
case bitc::ATTR_KIND_SANITIZE_REALTIME_BLOCKING:
return Attribute::SanitizeRealtimeBlocking;
+ case bitc::ATTR_KIND_SANITIZE_ALLOC_TOKEN:
+ return Attribute::SanitizeAllocToken;
case bitc::ATTR_KIND_SPECULATIVE_LOAD_HARDENING:
return Attribute::SpeculativeLoadHardening;
case bitc::ATTR_KIND_SWIFT_ERROR:
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index c4070e1..6d86809 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -883,6 +883,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
return bitc::ATTR_KIND_STRUCT_RET;
case Attribute::SanitizeAddress:
return bitc::ATTR_KIND_SANITIZE_ADDRESS;
+ case Attribute::SanitizeAllocToken:
+ return bitc::ATTR_KIND_SANITIZE_ALLOC_TOKEN;
case Attribute::SanitizeHWAddress:
return bitc::ATTR_KIND_SANITIZE_HWADDRESS;
case Attribute::SanitizeThread:
diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt
index 7ae5f7e..bca39b6 100644
--- a/llvm/lib/CAS/CMakeLists.txt
+++ b/llvm/lib/CAS/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMCAS
MappedFileRegionArena.cpp
ObjectStore.cpp
OnDiskCommon.cpp
+ OnDiskDataAllocator.cpp
OnDiskTrieRawHashMap.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/llvm/lib/CAS/OnDiskDataAllocator.cpp b/llvm/lib/CAS/OnDiskDataAllocator.cpp
new file mode 100644
index 0000000..13bbd66
--- /dev/null
+++ b/llvm/lib/CAS/OnDiskDataAllocator.cpp
@@ -0,0 +1,234 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file Implements OnDiskDataAllocator.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CAS/OnDiskDataAllocator.h"
+#include "DatabaseFile.h"
+#include "llvm/Config/llvm-config.h"
+
+using namespace llvm;
+using namespace llvm::cas;
+using namespace llvm::cas::ondisk;
+
+#if LLVM_ENABLE_ONDISK_CAS
+
+//===----------------------------------------------------------------------===//
+// DataAllocator data structures.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// DataAllocator table layout:
+/// - [8-bytes: Generic table header]
+/// - 8-bytes: AllocatorOffset (reserved for implementing free lists)
+/// - 8-bytes: Size for user data header
+/// - <user data buffer>
+///
+/// Record layout:
+/// - <data>
+class DataAllocatorHandle {
+public:
+ static constexpr TableHandle::TableKind Kind =
+ TableHandle::TableKind::DataAllocator;
+
+ struct Header {
+ TableHandle::Header GenericHeader;
+ std::atomic<int64_t> AllocatorOffset;
+ const uint64_t UserHeaderSize;
+ };
+
+ operator TableHandle() const {
+ if (!H)
+ return TableHandle();
+ return TableHandle(*Region, H->GenericHeader);
+ }
+
+ Expected<MutableArrayRef<char>> allocate(MappedFileRegionArena &Alloc,
+ size_t DataSize) {
+ assert(&Alloc.getRegion() == Region);
+ auto Ptr = Alloc.allocate(DataSize);
+ if (LLVM_UNLIKELY(!Ptr))
+ return Ptr.takeError();
+ return MutableArrayRef(*Ptr, DataSize);
+ }
+
+ explicit operator bool() const { return H; }
+ const Header &getHeader() const { return *H; }
+ MappedFileRegion &getRegion() const { return *Region; }
+
+ MutableArrayRef<uint8_t> getUserHeader() {
+ return MutableArrayRef(reinterpret_cast<uint8_t *>(H + 1),
+ H->UserHeaderSize);
+ }
+
+ static Expected<DataAllocatorHandle>
+ create(MappedFileRegionArena &Alloc, StringRef Name, uint32_t UserHeaderSize);
+
+ DataAllocatorHandle() = default;
+ DataAllocatorHandle(MappedFileRegion &Region, Header &H)
+ : Region(&Region), H(&H) {}
+ DataAllocatorHandle(MappedFileRegion &Region, intptr_t HeaderOffset)
+ : DataAllocatorHandle(
+ Region, *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) {
+ }
+
+private:
+ MappedFileRegion *Region = nullptr;
+ Header *H = nullptr;
+};
+
+} // end anonymous namespace
+
+struct OnDiskDataAllocator::ImplType {
+ DatabaseFile File;
+ DataAllocatorHandle Store;
+};
+
+Expected<DataAllocatorHandle>
+DataAllocatorHandle::create(MappedFileRegionArena &Alloc, StringRef Name,
+ uint32_t UserHeaderSize) {
+ // Allocate.
+ auto Offset =
+ Alloc.allocateOffset(sizeof(Header) + UserHeaderSize + Name.size() + 1);
+ if (LLVM_UNLIKELY(!Offset))
+ return Offset.takeError();
+
+ // Construct the header and the name.
+ assert(Name.size() <= UINT16_MAX && "Expected smaller table name");
+ auto *H = new (Alloc.getRegion().data() + *Offset)
+ Header{{TableHandle::TableKind::DataAllocator,
+ static_cast<uint16_t>(Name.size()),
+ static_cast<int32_t>(sizeof(Header) + UserHeaderSize)},
+ /*AllocatorOffset=*/{0},
+ /*UserHeaderSize=*/UserHeaderSize};
+ // Memset UserHeader.
+ char *UserHeader = reinterpret_cast<char *>(H + 1);
+ memset(UserHeader, 0, UserHeaderSize);
+ // Write database file name (null-terminated).
+ char *NameStorage = UserHeader + UserHeaderSize;
+ llvm::copy(Name, NameStorage);
+ NameStorage[Name.size()] = 0;
+ return DataAllocatorHandle(Alloc.getRegion(), *H);
+}
+
+Expected<OnDiskDataAllocator> OnDiskDataAllocator::create(
+ const Twine &PathTwine, const Twine &TableNameTwine, uint64_t MaxFileSize,
+ std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize,
+ function_ref<void(void *)> UserHeaderInit) {
+ assert(!UserHeaderSize || UserHeaderInit);
+ SmallString<128> PathStorage;
+ StringRef Path = PathTwine.toStringRef(PathStorage);
+ SmallString<128> TableNameStorage;
+ StringRef TableName = TableNameTwine.toStringRef(TableNameStorage);
+
+ // Constructor for if the file doesn't exist.
+ auto NewDBConstructor = [&](DatabaseFile &DB) -> Error {
+ auto Store =
+ DataAllocatorHandle::create(DB.getAlloc(), TableName, UserHeaderSize);
+ if (LLVM_UNLIKELY(!Store))
+ return Store.takeError();
+
+ if (auto E = DB.addTable(*Store))
+ return E;
+
+ if (UserHeaderSize)
+ UserHeaderInit(Store->getUserHeader().data());
+ return Error::success();
+ };
+
+ // Get or create the file.
+ Expected<DatabaseFile> File =
+ DatabaseFile::create(Path, MaxFileSize, NewDBConstructor);
+ if (!File)
+ return File.takeError();
+
+ // Find the table and validate it.
+ std::optional<TableHandle> Table = File->findTable(TableName);
+ if (!Table)
+ return createTableConfigError(std::errc::argument_out_of_domain, Path,
+ TableName, "table not found");
+ if (Error E = checkTable("table kind", (size_t)DataAllocatorHandle::Kind,
+ (size_t)Table->getHeader().Kind, Path, TableName))
+ return std::move(E);
+ auto Store = Table->cast<DataAllocatorHandle>();
+ assert(Store && "Already checked the kind");
+
+ // Success.
+ OnDiskDataAllocator::ImplType Impl{DatabaseFile(std::move(*File)), Store};
+ return OnDiskDataAllocator(std::make_unique<ImplType>(std::move(Impl)));
+}
+
+Expected<OnDiskDataAllocator::OnDiskPtr>
+OnDiskDataAllocator::allocate(size_t Size) {
+ auto Data = Impl->Store.allocate(Impl->File.getAlloc(), Size);
+ if (LLVM_UNLIKELY(!Data))
+ return Data.takeError();
+
+ return OnDiskPtr(FileOffset(Data->data() - Impl->Store.getRegion().data()),
+ *Data);
+}
+
+Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset,
+ size_t Size) const {
+ assert(Offset);
+ assert(Impl);
+ if (Offset.get() + Size >= Impl->File.getAlloc().size())
+ return createStringError(make_error_code(std::errc::protocol_error),
+ "requested size too large in allocator");
+ return ArrayRef<char>{Impl->File.getRegion().data() + Offset.get(), Size};
+}
+
+MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() {
+ return Impl->Store.getUserHeader();
+}
+
+size_t OnDiskDataAllocator::size() const { return Impl->File.size(); }
+size_t OnDiskDataAllocator::capacity() const {
+ return Impl->File.getRegion().size();
+}
+
+OnDiskDataAllocator::OnDiskDataAllocator(std::unique_ptr<ImplType> Impl)
+ : Impl(std::move(Impl)) {}
+
+#else // !LLVM_ENABLE_ONDISK_CAS
+
+struct OnDiskDataAllocator::ImplType {};
+
+Expected<OnDiskDataAllocator> OnDiskDataAllocator::create(
+ const Twine &Path, const Twine &TableName, uint64_t MaxFileSize,
+ std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize,
+ function_ref<void(void *)> UserHeaderInit) {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+Expected<OnDiskDataAllocator::OnDiskPtr>
+OnDiskDataAllocator::allocate(size_t Size) {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset,
+ size_t Size) const {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { return {}; }
+
+size_t OnDiskDataAllocator::size() const { return 0; }
+size_t OnDiskDataAllocator::capacity() const { return 0; }
+
+#endif // LLVM_ENABLE_ONDISK_CAS
+
+OnDiskDataAllocator::OnDiskDataAllocator(OnDiskDataAllocator &&RHS) = default;
+OnDiskDataAllocator &
+OnDiskDataAllocator::operator=(OnDiskDataAllocator &&RHS) = default;
+OnDiskDataAllocator::~OnDiskDataAllocator() = default;
diff --git a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
index 9403893..323b21e 100644
--- a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
+++ b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
@@ -427,7 +427,7 @@ TrieRawHashMapHandle::createRecord(MappedFileRegionArena &Alloc,
return Record;
}
-Expected<OnDiskTrieRawHashMap::const_pointer>
+Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr>
OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
// Check alignment.
if (!isAligned(MappedFileRegionArena::getAlign(), Offset.get()))
@@ -448,17 +448,17 @@ OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
// Looks okay...
TrieRawHashMapHandle::RecordData D =
Impl->Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset));
- return const_pointer(D.Proxy, D.getFileOffset());
+ return ConstOnDiskPtr(D.Proxy, D.getFileOffset());
}
-OnDiskTrieRawHashMap::const_pointer
+OnDiskTrieRawHashMap::ConstOnDiskPtr
OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
TrieRawHashMapHandle Trie = Impl->Trie;
assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash");
SubtrieHandle S = Trie.getRoot();
if (!S)
- return const_pointer();
+ return ConstOnDiskPtr();
TrieHashIndexGenerator IndexGen = Trie.getIndexGen(S, Hash);
size_t Index = IndexGen.next();
@@ -466,13 +466,13 @@ OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
// Try to set the content.
SubtrieSlotValue V = S.load(Index);
if (!V)
- return const_pointer();
+ return ConstOnDiskPtr();
// Check for an exact match.
if (V.isData()) {
TrieRawHashMapHandle::RecordData D = Trie.getRecord(V);
- return D.Proxy.Hash == Hash ? const_pointer(D.Proxy, D.getFileOffset())
- : const_pointer();
+ return D.Proxy.Hash == Hash ? ConstOnDiskPtr(D.Proxy, D.getFileOffset())
+ : ConstOnDiskPtr();
}
Index = IndexGen.next();
@@ -490,7 +490,7 @@ void SubtrieHandle::reinitialize(uint32_t StartBit, uint32_t NumBits) {
H->NumBits = NumBits;
}
-Expected<OnDiskTrieRawHashMap::pointer>
+Expected<OnDiskTrieRawHashMap::OnDiskPtr>
OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
LazyInsertOnConstructCB OnConstruct,
LazyInsertOnLeakCB OnLeak) {
@@ -523,7 +523,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
}
if (S->compare_exchange_strong(Index, Existing, NewRecord->Offset))
- return pointer(NewRecord->Proxy, NewRecord->Offset.asDataFileOffset());
+ return OnDiskPtr(NewRecord->Proxy,
+ NewRecord->Offset.asDataFileOffset());
// Race means that Existing is no longer empty; fall through...
}
@@ -540,8 +541,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
if (NewRecord && OnLeak)
OnLeak(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy,
ExistingRecord.Offset.asDataFileOffset(), ExistingRecord.Proxy);
- return pointer(ExistingRecord.Proxy,
- ExistingRecord.Offset.asDataFileOffset());
+ return OnDiskPtr(ExistingRecord.Proxy,
+ ExistingRecord.Offset.asDataFileOffset());
}
// Sink the existing content as long as the indexes match.
@@ -1135,7 +1136,7 @@ OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine,
"OnDiskTrieRawHashMap is not supported");
}
-Expected<OnDiskTrieRawHashMap::pointer>
+Expected<OnDiskTrieRawHashMap::OnDiskPtr>
OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
LazyInsertOnConstructCB OnConstruct,
LazyInsertOnLeakCB OnLeak) {
@@ -1143,15 +1144,15 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
"OnDiskTrieRawHashMap is not supported");
}
-Expected<OnDiskTrieRawHashMap::const_pointer>
+Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr>
OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
return createStringError(make_error_code(std::errc::not_supported),
"OnDiskTrieRawHashMap is not supported");
}
-OnDiskTrieRawHashMap::const_pointer
+OnDiskTrieRawHashMap::ConstOnDiskPtr
OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
- return const_pointer();
+ return ConstOnDiskPtr();
}
void OnDiskTrieRawHashMap::print(
diff --git a/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp b/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp
index d98d180..dc38f5a 100644
--- a/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp
@@ -240,6 +240,8 @@ bool DebugHandlerBase::isUnsignedDIType(const DIType *Ty) {
Encoding == dwarf::DW_ATE_complex_float ||
Encoding == dwarf::DW_ATE_signed_fixed ||
Encoding == dwarf::DW_ATE_unsigned_fixed ||
+ (Encoding >= dwarf::DW_ATE_lo_user &&
+ Encoding <= dwarf::DW_ATE_hi_user) ||
(Ty->getTag() == dwarf::DW_TAG_unspecified_type &&
Ty->getName() == "decltype(nullptr)")) &&
"Unsupported encoding");
diff --git a/llvm/lib/CodeGen/InlineSpiller.cpp b/llvm/lib/CodeGen/InlineSpiller.cpp
index 0c2b74c..d6e8505 100644
--- a/llvm/lib/CodeGen/InlineSpiller.cpp
+++ b/llvm/lib/CodeGen/InlineSpiller.cpp
@@ -671,10 +671,22 @@ bool InlineSpiller::reMaterializeFor(LiveInterval &VirtReg, MachineInstr &MI) {
LiveInterval &OrigLI = LIS.getInterval(Original);
VNInfo *OrigVNI = OrigLI.getVNInfoAt(UseIdx);
- LiveRangeEdit::Remat RM(ParentVNI);
- RM.OrigMI = LIS.getInstructionFromIndex(OrigVNI->def);
+ assert(OrigVNI && "corrupted sub-interval");
+ MachineInstr *DefMI = LIS.getInstructionFromIndex(OrigVNI->def);
+ // This can happen if for two reasons: 1) This could be a phi valno,
+ // or 2) the remat def has already been removed from the original
+ // live interval; this happens if we rematted to all uses, and
+ // then further split one of those live ranges.
+ if (!DefMI) {
+ markValueUsed(&VirtReg, ParentVNI);
+ LLVM_DEBUG(dbgs() << "\tcannot remat missing def for " << UseIdx << '\t'
+ << MI);
+ return false;
+ }
- if (!Edit->canRematerializeAt(RM, OrigVNI, UseIdx)) {
+ LiveRangeEdit::Remat RM(ParentVNI);
+ RM.OrigMI = DefMI;
+ if (!Edit->canRematerializeAt(RM, UseIdx)) {
markValueUsed(&VirtReg, ParentVNI);
LLVM_DEBUG(dbgs() << "\tcannot remat for " << UseIdx << '\t' << MI);
return false;
@@ -739,9 +751,6 @@ bool InlineSpiller::reMaterializeFor(LiveInterval &VirtReg, MachineInstr &MI) {
/// reMaterializeAll - Try to rematerialize as many uses as possible,
/// and trim the live ranges after.
void InlineSpiller::reMaterializeAll() {
- if (!Edit->anyRematerializable())
- return;
-
UsedValues.clear();
// Try to remat before all uses of snippets.
diff --git a/llvm/lib/CodeGen/LiveRangeEdit.cpp b/llvm/lib/CodeGen/LiveRangeEdit.cpp
index 59bc82d..5b0365d 100644
--- a/llvm/lib/CodeGen/LiveRangeEdit.cpp
+++ b/llvm/lib/CodeGen/LiveRangeEdit.cpp
@@ -68,41 +68,12 @@ Register LiveRangeEdit::createFrom(Register OldReg) {
return VReg;
}
-void LiveRangeEdit::scanRemattable() {
- for (VNInfo *VNI : getParent().valnos) {
- if (VNI->isUnused())
- continue;
- Register Original = VRM->getOriginal(getReg());
- LiveInterval &OrigLI = LIS.getInterval(Original);
- VNInfo *OrigVNI = OrigLI.getVNInfoAt(VNI->def);
- if (!OrigVNI)
- continue;
- MachineInstr *DefMI = LIS.getInstructionFromIndex(OrigVNI->def);
- if (!DefMI)
- continue;
- if (TII.isReMaterializable(*DefMI))
- Remattable.insert(OrigVNI);
- }
- ScannedRemattable = true;
-}
-
-bool LiveRangeEdit::anyRematerializable() {
- if (!ScannedRemattable)
- scanRemattable();
- return !Remattable.empty();
-}
-
-bool LiveRangeEdit::canRematerializeAt(Remat &RM, VNInfo *OrigVNI,
- SlotIndex UseIdx) {
- assert(ScannedRemattable && "Call anyRematerializable first");
+bool LiveRangeEdit::canRematerializeAt(Remat &RM, SlotIndex UseIdx) {
+ assert(RM.OrigMI && "No defining instruction for remattable value");
- // Use scanRemattable info.
- if (!Remattable.count(OrigVNI))
+ if (!TII.isReMaterializable(*RM.OrigMI))
return false;
- // No defining instruction provided.
- assert(RM.OrigMI && "No defining instruction for remattable value");
-
// Verify that all used registers are available with the same values.
if (!VirtRegAuxInfo::allUsesAvailableAt(RM.OrigMI, UseIdx, LIS, MRI, TII))
return false;
@@ -303,6 +274,37 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) {
}
}
+ // If the dest of MI is an original reg and MI is reMaterializable,
+ // don't delete the inst. Replace the dest with a new reg, and keep
+ // the inst for remat of other siblings. The inst is saved in
+ // LiveRangeEdit::DeadRemats and will be deleted after all the
+ // allocations of the func are done. Note that if we keep the
+ // instruction with the original operands, that handles the physreg
+ // operand case (described just below) as well.
+ // However, immediately delete instructions which have unshrunk virtual
+ // register uses. That may provoke RA to split an interval at the KILL
+ // and later result in an invalid live segment end.
+ if (isOrigDef && DeadRemats && !HasLiveVRegUses &&
+ TII.isReMaterializable(*MI)) {
+ LiveInterval &NewLI = createEmptyIntervalFrom(Dest, false);
+ VNInfo::Allocator &Alloc = LIS.getVNInfoAllocator();
+ VNInfo *VNI = NewLI.getNextValue(Idx, Alloc);
+ NewLI.addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), VNI));
+
+ if (DestSubReg) {
+ const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
+ auto *SR =
+ NewLI.createSubRange(Alloc, TRI->getSubRegIndexLaneMask(DestSubReg));
+ SR->addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(),
+ SR->getNextValue(Idx, Alloc)));
+ }
+
+ pop_back();
+ DeadRemats->insert(MI);
+ const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
+ MI->substituteRegister(Dest, NewLI.reg(), 0, TRI);
+ assert(MI->registerDefIsDead(NewLI.reg(), &TRI));
+ }
// Currently, we don't support DCE of physreg live ranges. If MI reads
// any unreserved physregs, don't erase the instruction, but turn it into
// a KILL instead. This way, the physreg live ranges don't end up
@@ -310,7 +312,7 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) {
// FIXME: It would be better to have something like shrinkToUses() for
// physregs. That could potentially enable more DCE and it would free up
// the physreg. It would not happen often, though.
- if (ReadsPhysRegs) {
+ else if (ReadsPhysRegs) {
MI->setDesc(TII.get(TargetOpcode::KILL));
// Remove all operands that aren't physregs.
for (unsigned i = MI->getNumOperands(); i; --i) {
@@ -322,41 +324,11 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) {
MI->dropMemRefs(*MI->getMF());
LLVM_DEBUG(dbgs() << "Converted physregs to:\t" << *MI);
} else {
- // If the dest of MI is an original reg and MI is reMaterializable,
- // don't delete the inst. Replace the dest with a new reg, and keep
- // the inst for remat of other siblings. The inst is saved in
- // LiveRangeEdit::DeadRemats and will be deleted after all the
- // allocations of the func are done.
- // However, immediately delete instructions which have unshrunk virtual
- // register uses. That may provoke RA to split an interval at the KILL
- // and later result in an invalid live segment end.
- if (isOrigDef && DeadRemats && !HasLiveVRegUses &&
- TII.isReMaterializable(*MI)) {
- LiveInterval &NewLI = createEmptyIntervalFrom(Dest, false);
- VNInfo::Allocator &Alloc = LIS.getVNInfoAllocator();
- VNInfo *VNI = NewLI.getNextValue(Idx, Alloc);
- NewLI.addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), VNI));
-
- if (DestSubReg) {
- const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
- auto *SR = NewLI.createSubRange(
- Alloc, TRI->getSubRegIndexLaneMask(DestSubReg));
- SR->addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(),
- SR->getNextValue(Idx, Alloc)));
- }
-
- pop_back();
- DeadRemats->insert(MI);
- const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
- MI->substituteRegister(Dest, NewLI.reg(), 0, TRI);
- assert(MI->registerDefIsDead(NewLI.reg(), &TRI));
- } else {
- if (TheDelegate)
- TheDelegate->LRE_WillEraseInstruction(MI);
- LIS.RemoveMachineInstrFromMaps(*MI);
- MI->eraseFromParent();
- ++NumDCEDeleted;
- }
+ if (TheDelegate)
+ TheDelegate->LRE_WillEraseInstruction(MI);
+ LIS.RemoveMachineInstrFromMaps(*MI);
+ MI->eraseFromParent();
+ ++NumDCEDeleted;
}
// Erase any virtregs that are now empty and unused. There may be <undef>
diff --git a/llvm/lib/CodeGen/SplitKit.cpp b/llvm/lib/CodeGen/SplitKit.cpp
index f118ee5..f9ecb2c 100644
--- a/llvm/lib/CodeGen/SplitKit.cpp
+++ b/llvm/lib/CodeGen/SplitKit.cpp
@@ -376,8 +376,6 @@ void SplitEditor::reset(LiveRangeEdit &LRE, ComplementSpillMode SM) {
if (SpillMode)
LICalc[1].reset(&VRM.getMachineFunction(), LIS.getSlotIndexes(), &MDT,
&LIS.getVNInfoAllocator());
-
- Edit->anyRematerializable();
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -638,7 +636,7 @@ VNInfo *SplitEditor::defFromParent(unsigned RegIdx, const VNInfo *ParentVNI,
LiveRangeEdit::Remat RM(ParentVNI);
RM.OrigMI = LIS.getInstructionFromIndex(OrigVNI->def);
if (RM.OrigMI && TII.isAsCheapAsAMove(*RM.OrigMI) &&
- Edit->canRematerializeAt(RM, OrigVNI, UseIdx)) {
+ Edit->canRematerializeAt(RM, UseIdx)) {
if (!rematWillIncreaseRestriction(RM.OrigMI, MBB, UseIdx)) {
SlotIndex Def = Edit->rematerializeAt(MBB, I, Reg, RM, TRI, Late);
++NumRemats;
diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
index f159d59..0ffe3ae 100644
--- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
+++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_component_library(LLVMOrcJIT
EPCGenericRTDyldMemoryManager.cpp
EPCIndirectionUtils.cpp
ExecutionUtils.cpp
+ ExecutorResolutionGenerator.cpp
ObjectFileInterface.cpp
GetDylibInterface.cpp
IndirectionUtils.cpp
diff --git a/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp b/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp
index 9f7d517..08bef37 100644
--- a/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp
@@ -42,7 +42,12 @@ Expected<std::unique_ptr<EPCDebugObjectRegistrar>> createJITLoaderGDBRegistrar(
assert((*Result)[0].size() == 1 &&
"Unexpected number of addresses in result");
- ExecutorAddr RegisterAddr = (*Result)[0][0].getAddress();
+ if (!(*Result)[0][0].has_value())
+ return make_error<StringError>(
+ "Expected a valid address in the lookup result",
+ inconvertibleErrorCode());
+
+ ExecutorAddr RegisterAddr = (*Result)[0][0]->getAddress();
return std::make_unique<EPCDebugObjectRegistrar>(ES, RegisterAddr);
}
diff --git a/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp b/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp
index 59d66b2..1e83c07 100644
--- a/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp
@@ -79,12 +79,16 @@ Error EPCDynamicLibrarySearchGenerator::tryToGenerate(
assert(Result->front().size() == LookupSymbols.size() &&
"Result has incorrect number of elements");
+ auto SymsIt = Result->front().begin();
+ SymbolNameSet MissingSymbols;
SymbolMap NewSymbols;
- auto ResultI = Result->front().begin();
- for (auto &KV : LookupSymbols) {
- if (ResultI->getAddress())
- NewSymbols[KV.first] = *ResultI;
- ++ResultI;
+ for (auto &[Name, Flags] : LookupSymbols) {
+ const auto &Sym = *SymsIt++;
+ if (Sym && Sym->getAddress())
+ NewSymbols[Name] = *Sym;
+ else if (LLVM_UNLIKELY(!Sym &&
+ Flags == SymbolLookupFlags::RequiredSymbol))
+ MissingSymbols.insert(Name);
}
LLVM_DEBUG({
@@ -96,6 +100,10 @@ Error EPCDynamicLibrarySearchGenerator::tryToGenerate(
if (NewSymbols.empty())
return LS.continueLookup(Error::success());
+ if (LLVM_UNLIKELY(!MissingSymbols.empty()))
+ return LS.continueLookup(make_error<SymbolsNotFound>(
+ this->EPC.getSymbolStringPool(), std::move(MissingSymbols)));
+
// Define resolved symbols.
Error Err = addAbsolutes(JD, std::move(NewSymbols));
diff --git a/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp b/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp
index f98b18c..1f19d17 100644
--- a/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp
@@ -66,7 +66,7 @@ EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(
if (auto Err = EPC.getBootstrapSymbols(
{{SAs.Instance, rt::SimpleExecutorDylibManagerInstanceName},
{SAs.Open, rt::SimpleExecutorDylibManagerOpenWrapperName},
- {SAs.Lookup, rt::SimpleExecutorDylibManagerLookupWrapperName}}))
+ {SAs.Resolve, rt::SimpleExecutorDylibManagerResolveWrapperName}}))
return std::move(Err);
return EPCGenericDylibManager(EPC, std::move(SAs));
}
@@ -84,11 +84,12 @@ Expected<tpctypes::DylibHandle> EPCGenericDylibManager::open(StringRef Path,
void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H,
const SymbolLookupSet &Lookup,
SymbolLookupCompleteFn Complete) {
- EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerLookupSignature>(
- SAs.Lookup,
+ EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerResolveSignature>(
+ SAs.Resolve,
[Complete = std::move(Complete)](
Error SerializationErr,
- Expected<std::vector<ExecutorSymbolDef>> Result) mutable {
+ Expected<std::vector<std::optional<ExecutorSymbolDef>>>
+ Result) mutable {
if (SerializationErr) {
cantFail(Result.takeError());
Complete(std::move(SerializationErr));
@@ -96,17 +97,18 @@ void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H,
}
Complete(std::move(Result));
},
- SAs.Instance, H, Lookup);
+ H, Lookup);
}
void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H,
const RemoteSymbolLookupSet &Lookup,
SymbolLookupCompleteFn Complete) {
- EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerLookupSignature>(
- SAs.Lookup,
+ EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerResolveSignature>(
+ SAs.Resolve,
[Complete = std::move(Complete)](
Error SerializationErr,
- Expected<std::vector<ExecutorSymbolDef>> Result) mutable {
+ Expected<std::vector<std::optional<ExecutorSymbolDef>>>
+ Result) mutable {
if (SerializationErr) {
cantFail(Result.takeError());
Complete(std::move(SerializationErr));
@@ -114,7 +116,7 @@ void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H,
}
Complete(std::move(Result));
},
- SAs.Instance, H, Lookup);
+ H, Lookup);
}
} // end namespace orc
diff --git a/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp
new file mode 100644
index 0000000..e5b0bd3
--- /dev/null
+++ b/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp
@@ -0,0 +1,98 @@
+//===---- ExecutorProcessControl.cpp -- Executor process control APIs -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/ExecutorResolutionGenerator.h"
+
+#include "llvm/ExecutionEngine/Orc/DebugUtils.h"
+#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
+#include "llvm/Support/Error.h"
+
+#define DEBUG_TYPE "orc"
+
+namespace llvm {
+namespace orc {
+
+Expected<std::unique_ptr<ExecutorResolutionGenerator>>
+ExecutorResolutionGenerator::Load(ExecutionSession &ES, const char *LibraryPath,
+ SymbolPredicate Allow,
+ AbsoluteSymbolsFn AbsoluteSymbols) {
+ auto H = ES.getExecutorProcessControl().getDylibMgr().loadDylib(LibraryPath);
+ if (H)
+ return H.takeError();
+ return std::make_unique<ExecutorResolutionGenerator>(
+ ES, *H, std::move(Allow), std::move(AbsoluteSymbols));
+}
+
+Error ExecutorResolutionGenerator::tryToGenerate(
+ LookupState &LS, LookupKind K, JITDylib &JD,
+ JITDylibLookupFlags JDLookupFlags, const SymbolLookupSet &LookupSet) {
+
+ if (LookupSet.empty())
+ return Error::success();
+
+ LLVM_DEBUG({
+ dbgs() << "ExecutorResolutionGenerator trying to generate " << LookupSet
+ << "\n";
+ });
+
+ SymbolLookupSet LookupSymbols;
+ for (auto &[Name, LookupFlag] : LookupSet) {
+ if (Allow && !Allow(Name))
+ continue;
+ LookupSymbols.add(Name, LookupFlag);
+ }
+
+ DylibManager::LookupRequest LR(H, LookupSymbols);
+ EPC.getDylibMgr().lookupSymbolsAsync(
+ LR, [this, LS = std::move(LS), JD = JITDylibSP(&JD),
+ LookupSymbols](auto Result) mutable {
+ if (Result) {
+ LLVM_DEBUG({
+ dbgs() << "ExecutorResolutionGenerator lookup failed due to error";
+ });
+ return LS.continueLookup(Result.takeError());
+ }
+ assert(Result->size() == 1 &&
+ "Results for more than one library returned");
+ assert(Result->front().size() == LookupSymbols.size() &&
+ "Result has incorrect number of elements");
+
+ // const tpctypes::LookupResult &Syms = Result->front();
+ // size_t SymIdx = 0;
+ auto Syms = Result->front().begin();
+ SymbolNameSet MissingSymbols;
+ SymbolMap NewSyms;
+ for (auto &[Name, Flags] : LookupSymbols) {
+ const auto &Sym = *Syms++;
+ if (Sym && Sym->getAddress())
+ NewSyms[Name] = *Sym;
+ else if (LLVM_UNLIKELY(!Sym &&
+ Flags == SymbolLookupFlags::RequiredSymbol))
+ MissingSymbols.insert(Name);
+ }
+
+ LLVM_DEBUG({
+ dbgs() << "ExecutorResolutionGenerator lookup returned " << NewSyms
+ << "\n";
+ });
+
+ if (NewSyms.empty())
+ return LS.continueLookup(Error::success());
+
+ if (LLVM_UNLIKELY(!MissingSymbols.empty()))
+ return LS.continueLookup(make_error<SymbolsNotFound>(
+ this->EPC.getSymbolStringPool(), std::move(MissingSymbols)));
+
+ LS.continueLookup(JD->define(AbsoluteSymbols(std::move(NewSyms))));
+ });
+
+ return Error::success();
+}
+
+} // end namespace orc
+} // end namespace llvm
diff --git a/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp b/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp
index 78169a2..42d630d 100644
--- a/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp
@@ -72,9 +72,10 @@ Error lookupAndRecordAddrs(
return make_error<StringError>("Error in lookup result elements",
inconvertibleErrorCode());
- for (unsigned I = 0; I != Pairs.size(); ++I)
- *Pairs[I].second = Result->front()[I].getAddress();
-
+ for (unsigned I = 0; I != Pairs.size(); ++I) {
+ if (Result->front()[I])
+ *Pairs[I].second = Result->front()[I]->getAddress();
+ }
return Error::success();
}
diff --git a/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp
index 78045f1..f8a2bd3 100644
--- a/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp
@@ -91,22 +91,18 @@ void SelfExecutorProcessControl::lookupSymbolsAsync(
for (auto &Elem : Request) {
sys::DynamicLibrary Dylib(Elem.Handle.toPtr<void *>());
- R.push_back(std::vector<ExecutorSymbolDef>());
+ R.push_back(tpctypes::LookupResult());
for (auto &KV : Elem.Symbols) {
auto &Sym = KV.first;
std::string Tmp((*Sym).data() + !!GlobalManglingPrefix,
(*Sym).size() - !!GlobalManglingPrefix);
void *Addr = Dylib.getAddressOfSymbol(Tmp.c_str());
- if (!Addr && KV.second == SymbolLookupFlags::RequiredSymbol) {
- // FIXME: Collect all failing symbols before erroring out.
- SymbolNameVector MissingSymbols;
- MissingSymbols.push_back(Sym);
- return Complete(
- make_error<SymbolsNotFound>(SSP, std::move(MissingSymbols)));
- }
- // FIXME: determine accurate JITSymbolFlags.
- R.back().push_back(
- {ExecutorAddr::fromPtr(Addr), JITSymbolFlags::Exported});
+ if (!Addr && KV.second == SymbolLookupFlags::RequiredSymbol)
+ R.back().emplace_back();
+ else
+ // FIXME: determine accurate JITSymbolFlags.
+ R.back().emplace_back(ExecutorSymbolDef(ExecutorAddr::fromPtr(Addr),
+ JITSymbolFlags::Exported));
}
}
diff --git a/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp b/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp
index 123651f..26e8f53 100644
--- a/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp
@@ -16,8 +16,8 @@ const char *SimpleExecutorDylibManagerInstanceName =
"__llvm_orc_SimpleExecutorDylibManager_Instance";
const char *SimpleExecutorDylibManagerOpenWrapperName =
"__llvm_orc_SimpleExecutorDylibManager_open_wrapper";
-const char *SimpleExecutorDylibManagerLookupWrapperName =
- "__llvm_orc_SimpleExecutorDylibManager_lookup_wrapper";
+const char *SimpleExecutorDylibManagerResolveWrapperName =
+ "__llvm_orc_SimpleExecutorDylibManager_resolve_wrapper";
const char *SimpleExecutorMemoryManagerInstanceName =
"__llvm_orc_SimpleExecutorMemoryManager_Instance";
diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt
index 9f3abac..9275586 100644
--- a/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt
+++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt
@@ -15,6 +15,7 @@ endif()
add_llvm_component_library(LLVMOrcTargetProcess
ExecutorSharedMemoryMapperService.cpp
DefaultHostBootstrapValues.cpp
+ ExecutorResolver.cpp
JITLoaderGDB.cpp
JITLoaderPerf.cpp
JITLoaderVTune.cpp
diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp b/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp
new file mode 100644
index 0000000..6054d86
--- /dev/null
+++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp
@@ -0,0 +1,47 @@
+
+#include "llvm/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+
+namespace llvm::orc {
+
+void DylibSymbolResolver::resolveAsync(
+ const RemoteSymbolLookupSet &L,
+ ExecutorResolver::YieldResolveResultFn &&OnResolve) {
+ std::vector<std::optional<ExecutorSymbolDef>> Result;
+ auto DL = sys::DynamicLibrary(Handle.toPtr<void *>());
+
+ for (const auto &E : L) {
+ if (E.Name.empty()) {
+ if (E.Required)
+ OnResolve(
+ make_error<StringError>("Required address for empty symbol \"\"",
+ inconvertibleErrorCode()));
+ else
+ Result.emplace_back();
+ } else {
+
+ const char *DemangledSymName = E.Name.c_str();
+#ifdef __APPLE__
+ if (E.Name.front() != '_')
+ OnResolve(make_error<StringError>(Twine("MachO symbol \"") + E.Name +
+ "\" missing leading '_'",
+ inconvertibleErrorCode()));
+ ++DemangledSymName;
+#endif
+
+ void *Addr = DL.getAddressOfSymbol(DemangledSymName);
+ if (!Addr && E.Required)
+ Result.emplace_back();
+ else
+ // FIXME: determine accurate JITSymbolFlags.
+ Result.emplace_back(ExecutorSymbolDef(ExecutorAddr::fromPtr(Addr),
+ JITSymbolFlags::Exported));
+ }
+ }
+
+ OnResolve(std::move(Result));
+}
+
+} // end namespace llvm::orc \ No newline at end of file
diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp b/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp
index db6f201..52bb55d 100644
--- a/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp
@@ -10,6 +10,10 @@
#include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
+#include "llvm/Support/MSVCErrorWorkarounds.h"
+
+#include <future>
+
#define DEBUG_TYPE "orc"
namespace llvm {
@@ -35,46 +39,9 @@ SimpleExecutorDylibManager::open(const std::string &Path, uint64_t Mode) {
std::lock_guard<std::mutex> Lock(M);
auto H = ExecutorAddr::fromPtr(DL.getOSSpecificHandle());
+ Resolvers.push_back(std::make_unique<DylibSymbolResolver>(H));
Dylibs.insert(DL.getOSSpecificHandle());
- return H;
-}
-
-Expected<std::vector<ExecutorSymbolDef>>
-SimpleExecutorDylibManager::lookup(tpctypes::DylibHandle H,
- const RemoteSymbolLookupSet &L) {
- std::vector<ExecutorSymbolDef> Result;
- auto DL = sys::DynamicLibrary(H.toPtr<void *>());
-
- for (const auto &E : L) {
- if (E.Name.empty()) {
- if (E.Required)
- return make_error<StringError>("Required address for empty symbol \"\"",
- inconvertibleErrorCode());
- else
- Result.push_back(ExecutorSymbolDef());
- } else {
-
- const char *DemangledSymName = E.Name.c_str();
-#ifdef __APPLE__
- if (E.Name.front() != '_')
- return make_error<StringError>(Twine("MachO symbol \"") + E.Name +
- "\" missing leading '_'",
- inconvertibleErrorCode());
- ++DemangledSymName;
-#endif
-
- void *Addr = DL.getAddressOfSymbol(DemangledSymName);
- if (!Addr && E.Required)
- return make_error<StringError>(Twine("Missing definition for ") +
- DemangledSymName,
- inconvertibleErrorCode());
-
- // FIXME: determine accurate JITSymbolFlags.
- Result.push_back({ExecutorAddr::fromPtr(Addr), JITSymbolFlags::Exported});
- }
- }
-
- return Result;
+ return ExecutorAddr::fromPtr(Resolvers.back().get());
}
Error SimpleExecutorDylibManager::shutdown() {
@@ -94,8 +61,8 @@ void SimpleExecutorDylibManager::addBootstrapSymbols(
M[rt::SimpleExecutorDylibManagerInstanceName] = ExecutorAddr::fromPtr(this);
M[rt::SimpleExecutorDylibManagerOpenWrapperName] =
ExecutorAddr::fromPtr(&openWrapper);
- M[rt::SimpleExecutorDylibManagerLookupWrapperName] =
- ExecutorAddr::fromPtr(&lookupWrapper);
+ M[rt::SimpleExecutorDylibManagerResolveWrapperName] =
+ ExecutorAddr::fromPtr(&resolveWrapper);
}
llvm::orc::shared::CWrapperFunctionResult
@@ -109,12 +76,22 @@ SimpleExecutorDylibManager::openWrapper(const char *ArgData, size_t ArgSize) {
}
llvm::orc::shared::CWrapperFunctionResult
-SimpleExecutorDylibManager::lookupWrapper(const char *ArgData, size_t ArgSize) {
- return shared::
- WrapperFunction<rt::SPSSimpleExecutorDylibManagerLookupSignature>::handle(
- ArgData, ArgSize,
- shared::makeMethodWrapperHandler(
- &SimpleExecutorDylibManager::lookup))
+SimpleExecutorDylibManager::resolveWrapper(const char *ArgData,
+ size_t ArgSize) {
+ using ResolveResult = ExecutorResolver::ResolveResult;
+ return shared::WrapperFunction<
+ rt::SPSSimpleExecutorDylibManagerResolveSignature>::
+ handle(ArgData, ArgSize,
+ [](ExecutorAddr Obj, RemoteSymbolLookupSet L) -> ResolveResult {
+ using TmpResult =
+ MSVCPExpected<std::vector<std::optional<ExecutorSymbolDef>>>;
+ std::promise<TmpResult> P;
+ auto F = P.get_future();
+ Obj.toPtr<ExecutorResolver *>()->resolveAsync(
+ std::move(L),
+ [&](TmpResult R) { P.set_value(std::move(R)); });
+ return F.get();
+ })
.release();
}
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 7a0cf40..707f0c3 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature(
"RegisterSpace", Descriptor.RegisterSpace));
if (RSD.Version > 1) {
- if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
+ bool IsValidFlag =
+ dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
+ hlsl::rootsig::verifyRootDescriptorFlag(
+ RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
@@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature(
make_error<RootSignatureValidationError<uint32_t>>(
"NumDescriptors", Range.NumDescriptors));
- if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType,
- dxbc::DescriptorRangeFlags(Range.Flags)))
+ bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
+ hlsl::rootsig::verifyDescriptorRangeFlag(
+ RSD.Version, Range.RangeType,
+ dxbc::DescriptorRangeFlags(Range.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
@@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"RegisterSpace", Sampler.RegisterSpace));
-
- if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
+ bool IsValidFlag =
+ dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
+ hlsl::rootsig::verifyStaticSamplerFlags(
+ RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 8a2b03d..30408df 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
return !(RegisterSpace >= 0xFFFFFFF0);
}
-bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
+bool verifyRootDescriptorFlag(uint32_t Version,
+ dxbc::RootDescriptorFlags FlagsVal) {
using FlagT = dxbc::RootDescriptorFlags;
FlagT Flags = FlagT(FlagsVal);
if (Version == 1)
@@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
dxbc::DescriptorRangeFlags Flags) {
using FlagT = dxbc::DescriptorRangeFlags;
-
const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
if (Version == 1) {
@@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
return (Flags & ~Mask) == FlagT::None;
}
-bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
- uint32_t LargestValue = llvm::to_underlying(
- dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
- if (FlagsNumber >= NextPowerOf2(LargestValue))
- return false;
-
- dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
+bool verifyStaticSamplerFlags(uint32_t Version,
+ dxbc::StaticSamplerFlags Flags) {
if (Version <= 2)
return Flags == dxbc::StaticSamplerFlags::None;
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 6b202ba..3842b1a 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -55,15 +55,8 @@ foldConstantCastPair(
Type *MidTy = Op->getType();
Instruction::CastOps firstOp = Instruction::CastOps(Op->getOpcode());
Instruction::CastOps secondOp = Instruction::CastOps(opc);
-
- // Assume that pointers are never more than 64 bits wide, and only use this
- // for the middle type. Otherwise we could end up folding away illegal
- // bitcasts between address spaces with different sizes.
- IntegerType *FakeIntPtrTy = Type::getInt64Ty(DstTy->getContext());
-
- // Let CastInst::isEliminableCastPair do the heavy lifting.
return CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, DstTy,
- nullptr, FakeIntPtrTy, nullptr);
+ /*DL=*/nullptr);
}
static Constant *FoldBitCast(Constant *V, Type *DestTy) {
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index df0c85b..3f1cc1e 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -2403,6 +2403,14 @@ LLVMValueRef LLVMAddFunction(LLVMModuleRef M, const char *Name,
GlobalValue::ExternalLinkage, Name, unwrap(M)));
}
+LLVMValueRef LLVMGetOrInsertFunction(LLVMModuleRef M, const char *Name,
+ size_t NameLen, LLVMTypeRef FunctionTy) {
+ return wrap(unwrap(M)
+ ->getOrInsertFunction(StringRef(Name, NameLen),
+ unwrap<FunctionType>(FunctionTy))
+ .getCallee());
+}
+
LLVMValueRef LLVMGetNamedFunction(LLVMModuleRef M, const char *Name) {
return wrap(unwrap(M)->getFunction(Name));
}
diff --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp
index 1a7a5c5..c3a472b 100644
--- a/llvm/lib/IR/Globals.cpp
+++ b/llvm/lib/IR/Globals.cpp
@@ -419,6 +419,7 @@ findBaseObject(const Constant *C, DenseSet<const GlobalAlias *> &Aliases,
case Instruction::PtrToAddr:
case Instruction::PtrToInt:
case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
case Instruction::GetElementPtr:
return findBaseObject(CE->getOperand(0), Aliases, Op);
default:
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 941e41f..88e7c44 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2824,10 +2824,10 @@ bool CastInst::isNoopCast(const DataLayout &DL) const {
/// The function returns a resultOpcode so these two casts can be replaced with:
/// * %Replacement = resultOpcode %SrcTy %x to DstTy
/// If no such cast is permitted, the function returns 0.
-unsigned CastInst::isEliminableCastPair(
- Instruction::CastOps firstOp, Instruction::CastOps secondOp,
- Type *SrcTy, Type *MidTy, Type *DstTy, Type *SrcIntPtrTy, Type *MidIntPtrTy,
- Type *DstIntPtrTy) {
+unsigned CastInst::isEliminableCastPair(Instruction::CastOps firstOp,
+ Instruction::CastOps secondOp,
+ Type *SrcTy, Type *MidTy, Type *DstTy,
+ const DataLayout *DL) {
// Define the 144 possibilities for these two cast instructions. The values
// in this matrix determine what to do in a given situation and select the
// case in the switch below. The rows correspond to firstOp, the columns
@@ -2936,24 +2936,16 @@ unsigned CastInst::isEliminableCastPair(
return 0;
// Cannot simplify if address spaces are different!
- if (SrcTy->getPointerAddressSpace() != DstTy->getPointerAddressSpace())
+ if (SrcTy != DstTy)
return 0;
- unsigned MidSize = MidTy->getScalarSizeInBits();
- // We can still fold this without knowing the actual sizes as long we
- // know that the intermediate pointer is the largest possible
+ // Cannot simplify if the intermediate integer size is smaller than the
// pointer size.
- // FIXME: Is this always true?
- if (MidSize == 64)
- return Instruction::BitCast;
-
- // ptrtoint, inttoptr -> bitcast (ptr -> ptr) if int size is >= ptr size.
- if (!SrcIntPtrTy || DstIntPtrTy != SrcIntPtrTy)
+ unsigned MidSize = MidTy->getScalarSizeInBits();
+ if (!DL || MidSize < DL->getPointerTypeSizeInBits(SrcTy))
return 0;
- unsigned PtrSize = SrcIntPtrTy->getScalarSizeInBits();
- if (MidSize >= PtrSize)
- return Instruction::BitCast;
- return 0;
+
+ return Instruction::BitCast;
}
case 8: {
// ext, trunc -> bitcast, if the SrcTy and DstTy are the same
@@ -2973,14 +2965,17 @@ unsigned CastInst::isEliminableCastPair(
// zext, sext -> zext, because sext can't sign extend after zext
return Instruction::ZExt;
case 11: {
- // inttoptr, ptrtoint/ptrtoaddr -> bitcast if SrcSize<=PtrSize and
- // SrcSize==DstSize
- if (!MidIntPtrTy)
+ // inttoptr, ptrtoint/ptrtoaddr -> bitcast if SrcSize<=PtrSize/AddrSize
+ // and SrcSize==DstSize
+ if (!DL)
return 0;
- unsigned PtrSize = MidIntPtrTy->getScalarSizeInBits();
+ unsigned MidSize = secondOp == Instruction::PtrToAddr
+ ? DL->getAddressSizeInBits(MidTy)
+ : DL->getPointerTypeSizeInBits(MidTy);
unsigned SrcSize = SrcTy->getScalarSizeInBits();
unsigned DstSize = DstTy->getScalarSizeInBits();
- if (SrcSize <= PtrSize && SrcSize == DstSize)
+ // TODO: Could also produce zext or trunc here.
+ if (SrcSize <= MidSize && SrcSize == DstSize)
return Instruction::BitCast;
return 0;
}
diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp
index ca6a480..55c825d 100644
--- a/llvm/lib/IR/Mangler.cpp
+++ b/llvm/lib/IR/Mangler.cpp
@@ -307,6 +307,19 @@ std::optional<std::string> llvm::getArm64ECMangledFunctionName(StringRef Name) {
if (Name.contains("$$h"))
return std::nullopt;
+ // Handle MD5 mangled names, which use a slightly different rule from
+ // other C++ manglings.
+ //
+ // A non-Arm64EC function:
+ //
+ // ??@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@
+ //
+ // An Arm64EC function:
+ //
+ // ??@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@$$h@
+ if (Name.starts_with("??@") && Name.ends_with("@"))
+ return (Name + "$$h@").str();
+
// Ask the demangler where we should insert "$$h".
auto InsertIdx = getArm64ECInsertionPointInMangledName(Name);
if (!InsertIdx)
@@ -324,6 +337,10 @@ llvm::getArm64ECDemangledFunctionName(StringRef Name) {
if (Name[0] != '?')
return std::nullopt;
+ // MD5 mangled name; see comment in getArm64ECMangledFunctionName.
+ if (Name.starts_with("??@") && Name.ends_with("@$$h@"))
+ return Name.drop_back(4).str();
+
// Drop the ARM64EC "$$h" tag.
std::pair<StringRef, StringRef> Pair = Name.split("$$h");
if (Pair.second.empty())
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 6b3cd27..71a8a38 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -543,6 +543,7 @@ private:
void visitAliasScopeListMetadata(const MDNode *MD);
void visitAccessGroupMetadata(const MDNode *MD);
void visitCapturesMetadata(Instruction &I, const MDNode *Captures);
+ void visitAllocTokenMetadata(Instruction &I, MDNode *MD);
template <class Ty> bool isValidMetadataArray(const MDTuple &N);
#define HANDLE_SPECIALIZED_MDNODE_LEAF(CLASS) void visit##CLASS(const CLASS &N);
@@ -5395,6 +5396,12 @@ void Verifier::visitCapturesMetadata(Instruction &I, const MDNode *Captures) {
}
}
+void Verifier::visitAllocTokenMetadata(Instruction &I, MDNode *MD) {
+ Check(isa<CallBase>(I), "!alloc_token should only exist on calls", &I);
+ Check(MD->getNumOperands() == 1, "!alloc_token must have 1 operand", MD);
+ Check(isa<MDString>(MD->getOperand(0)), "expected string", MD);
+}
+
/// verifyInstruction - Verify that an instruction is well formed.
///
void Verifier::visitInstruction(Instruction &I) {
@@ -5625,6 +5632,9 @@ void Verifier::visitInstruction(Instruction &I) {
if (MDNode *Captures = I.getMetadata(LLVMContext::MD_captures))
visitCapturesMetadata(I, Captures);
+ if (MDNode *MD = I.getMetadata(LLVMContext::MD_alloc_token))
+ visitAllocTokenMetadata(I, MD);
+
if (MDNode *N = I.getDebugLoc().getAsMDNode()) {
CheckDI(isa<DILocation>(N), "invalid !dbg metadata attachment", &I, N);
visitMDNode(*N, AreDebugLocsAllowed::Yes);
diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp
index 329dcbf..046cde8 100644
--- a/llvm/lib/Object/OffloadBundle.cpp
+++ b/llvm/lib/Object/OffloadBundle.cpp
@@ -25,38 +25,71 @@
using namespace llvm;
using namespace llvm::object;
-static llvm::TimerGroup
- OffloadBundlerTimerGroup("Offload Bundler Timer Group",
- "Timer group for offload bundler");
+static TimerGroup OffloadBundlerTimerGroup("Offload Bundler Timer Group",
+ "Timer group for offload bundler");
// Extract an Offload bundle (usually a Offload Bundle) from a fat_bin
-// section
+// section.
Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
StringRef FileName,
SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
size_t Offset = 0;
size_t NextbundleStart = 0;
+ StringRef Magic;
+ std::unique_ptr<MemoryBuffer> Buffer;
// There could be multiple offloading bundles stored at this section.
- while (NextbundleStart != StringRef::npos) {
- std::unique_ptr<MemoryBuffer> Buffer =
+ while ((NextbundleStart != StringRef::npos) &&
+ (Offset < Contents.getBuffer().size())) {
+ Buffer =
MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
/*RequiresNullTerminator=*/false);
- // Create the FatBinBindle object. This will also create the Bundle Entry
- // list info.
- auto FatBundleOrErr =
- OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName);
- if (!FatBundleOrErr)
- return FatBundleOrErr.takeError();
-
- // Add current Bundle to list.
- Bundles.emplace_back(std::move(**FatBundleOrErr));
-
- // Find the next bundle by searching for the magic string
- StringRef Str = Buffer->getBuffer();
- NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24);
+ if (identify_magic((*Buffer).getBuffer()) ==
+ file_magic::offload_bundle_compressed) {
+ Magic = "CCOB";
+ // Decompress this bundle first.
+ NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size());
+ if (NextbundleStart == StringRef::npos)
+ NextbundleStart = (*Buffer).getBuffer().size();
+
+ ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
+ MemoryBuffer::getMemBuffer(
+ (*Buffer).getBuffer().take_front(NextbundleStart), FileName,
+ false);
+ if (std::error_code EC = CodeOrErr.getError())
+ return createFileError(FileName, EC);
+
+ Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
+ CompressedOffloadBundle::decompress(**CodeOrErr, nullptr);
+ if (!DecompressedBufferOrErr)
+ return createStringError("failed to decompress input: " +
+ toString(DecompressedBufferOrErr.takeError()));
+
+ auto FatBundleOrErr = OffloadBundleFatBin::create(
+ **DecompressedBufferOrErr, Offset, FileName, true);
+ if (!FatBundleOrErr)
+ return FatBundleOrErr.takeError();
+
+ // Add current Bundle to list.
+ Bundles.emplace_back(std::move(**FatBundleOrErr));
+
+ } else if (identify_magic((*Buffer).getBuffer()) ==
+ file_magic::offload_bundle) {
+ // Create the OffloadBundleFatBin object. This will also create the Bundle
+ // Entry list info.
+ auto FatBundleOrErr = OffloadBundleFatBin::create(
+ *Buffer, SectionOffset + Offset, FileName);
+ if (!FatBundleOrErr)
+ return FatBundleOrErr.takeError();
+
+ // Add current Bundle to list.
+ Bundles.emplace_back(std::move(**FatBundleOrErr));
+
+ Magic = "__CLANG_OFFLOAD_BUNDLE__";
+ NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size());
+ }
if (NextbundleStart != StringRef::npos)
Offset += NextbundleStart;
@@ -82,7 +115,7 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer,
NumberOfEntries = NumOfEntries;
- // For each Bundle Entry (code object)
+ // For each Bundle Entry (code object).
for (uint64_t I = 0; I < NumOfEntries; I++) {
uint64_t EntrySize;
uint64_t EntryOffset;
@@ -112,19 +145,22 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer,
Expected<std::unique_ptr<OffloadBundleFatBin>>
OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
- StringRef FileName) {
+ StringRef FileName, bool Decompress) {
if (Buf.getBufferSize() < 24)
return errorCodeToError(object_error::parse_failed);
// Check for magic bytes.
- if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)
+ if ((identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) &&
+ (identify_magic(Buf.getBuffer()) !=
+ file_magic::offload_bundle_compressed))
return errorCodeToError(object_error::parse_failed);
std::unique_ptr<OffloadBundleFatBin> TheBundle(
new OffloadBundleFatBin(Buf, FileName));
- // Read the Bundle Entries
- Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset);
+ // Read the Bundle Entries.
+ Error Err =
+ TheBundle->readEntries(Buf.getBuffer(), Decompress ? 0 : SectionOffset);
if (Err)
return Err;
@@ -132,7 +168,7 @@ OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
}
Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) {
- // This will extract all entries in the Bundle
+ // This will extract all entries in the Bundle.
for (OffloadBundleEntry &Entry : Entries) {
if (Entry.Size == 0)
@@ -161,40 +197,21 @@ Error object::extractOffloadBundleFatBinary(
return Buffer.takeError();
// If it does not start with the reserved suffix, just skip this section.
- if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) ||
+ if ((llvm::identify_magic(*Buffer) == file_magic::offload_bundle) ||
(llvm::identify_magic(*Buffer) ==
- llvm::file_magic::offload_bundle_compressed)) {
+ file_magic::offload_bundle_compressed)) {
uint64_t SectionOffset = 0;
if (Obj.isELF()) {
SectionOffset = ELFSectionRef(Sec).getOffset();
- } else if (Obj.isCOFF()) // TODO: add COFF Support
+ } else if (Obj.isCOFF()) // TODO: add COFF Support.
return createStringError(object_error::parse_failed,
- "COFF object files not supported.\n");
+ "COFF object files not supported");
MemoryBufferRef Contents(*Buffer, Obj.getFileName());
-
- if (llvm::identify_magic(*Buffer) ==
- llvm::file_magic::offload_bundle_compressed) {
- // Decompress the input if necessary.
- Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
- CompressedOffloadBundle::decompress(Contents, false);
-
- if (!DecompressedBufferOrErr)
- return createStringError(
- inconvertibleErrorCode(),
- "Failed to decompress input: " +
- llvm::toString(DecompressedBufferOrErr.takeError()));
-
- MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
- if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,
- Obj.getFileName(), Bundles))
- return Err;
- } else {
- if (Error Err = extractOffloadBundle(Contents, SectionOffset,
- Obj.getFileName(), Bundles))
- return Err;
- }
+ if (Error Err = extractOffloadBundle(Contents, SectionOffset,
+ Obj.getFileName(), Bundles))
+ return Err;
}
}
return Error::success();
@@ -222,8 +239,22 @@ Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
return Error::success();
}
+Error object::extractCodeObject(const MemoryBufferRef Buffer, int64_t Offset,
+ int64_t Size, StringRef OutputFileName) {
+ Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =
+ FileOutputBuffer::create(OutputFileName, Size);
+ if (!BufferOrErr)
+ return BufferOrErr.takeError();
+
+ std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);
+ std::copy(Buffer.getBufferStart() + Offset,
+ Buffer.getBufferStart() + Offset + Size, Buf->getBufferStart());
+
+ return Buf->commit();
+}
+
// given a file name, offset, and size, extract data into a code object file,
-// into file <SourceFile>-offset<Offset>-size<Size>.co
+// into file "<SourceFile>-offset<Offset>-size<Size>.co".
Error object::extractOffloadBundleByURI(StringRef URIstr) {
// create a URI object
Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr(
@@ -236,7 +267,7 @@ Error object::extractOffloadBundleByURI(StringRef URIstr) {
OutputFile +=
"-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co";
- // Create an ObjectFile object from uri.file_uri
+ // Create an ObjectFile object from uri.file_uri.
auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName);
if (!ObjOrErr)
return ObjOrErr.takeError();
@@ -249,7 +280,7 @@ Error object::extractOffloadBundleByURI(StringRef URIstr) {
return Error::success();
}
-// Utility function to format numbers with commas
+// Utility function to format numbers with commas.
static std::string formatWithCommas(unsigned long long Value) {
std::string Num = std::to_string(Value);
int InsertPosition = Num.length() - 3;
@@ -260,87 +291,278 @@ static std::string formatWithCommas(unsigned long long Value) {
return Num;
}
-llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
- bool Verbose) {
- StringRef Blob = Input.getBuffer();
+Expected<std::unique_ptr<MemoryBuffer>>
+CompressedOffloadBundle::compress(compression::Params P,
+ const MemoryBuffer &Input, uint16_t Version,
+ raw_ostream *VerboseStream) {
+ if (!compression::zstd::isAvailable() && !compression::zlib::isAvailable())
+ return createStringError("compression not supported.");
+ Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+ OffloadBundlerTimerGroup);
+ if (VerboseStream)
+ HashTimer.startTimer();
+ MD5 Hash;
+ MD5::MD5Result Result;
+ Hash.update(Input.getBuffer());
+ Hash.final(Result);
+ uint64_t TruncatedHash = Result.low();
+ if (VerboseStream)
+ HashTimer.stopTimer();
+
+ SmallVector<uint8_t, 0> CompressedBuffer;
+ auto BufferUint8 = ArrayRef<uint8_t>(
+ reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+ Input.getBuffer().size());
+ Timer CompressTimer("Compression Timer", "Compression time",
+ OffloadBundlerTimerGroup);
+ if (VerboseStream)
+ CompressTimer.startTimer();
+ compression::compress(P, BufferUint8, CompressedBuffer);
+ if (VerboseStream)
+ CompressTimer.stopTimer();
+
+ uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
+
+ // Store sizes in 64-bit variables first.
+ uint64_t UncompressedSize64 = Input.getBuffer().size();
+ uint64_t TotalFileSize64;
+
+ // Calculate total file size based on version.
+ if (Version == 2) {
+ // For V2, ensure the sizes don't exceed 32-bit limit.
+ if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
+ return createStringError("uncompressed size (%llu) exceeds version 2 "
+ "unsigned 32-bit integer limit",
+ UncompressedSize64);
+ TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+ sizeof(CompressionMethod) + sizeof(uint32_t) +
+ sizeof(TruncatedHash) + CompressedBuffer.size();
+ if (TotalFileSize64 > std::numeric_limits<uint32_t>::max())
+ return createStringError("total file size (%llu) exceeds version 2 "
+ "unsigned 32-bit integer limit",
+ TotalFileSize64);
+
+ } else { // Version 3.
+ TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) +
+ sizeof(CompressionMethod) + sizeof(uint64_t) +
+ sizeof(TruncatedHash) + CompressedBuffer.size();
+ }
+
+ SmallVector<char, 0> FinalBuffer;
+ raw_svector_ostream OS(FinalBuffer);
+ OS << MagicNumber;
+ OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
+ OS.write(reinterpret_cast<const char *>(&CompressionMethod),
+ sizeof(CompressionMethod));
+
+ // Write size fields according to version.
+ if (Version == 2) {
+ uint32_t TotalFileSize32 = static_cast<uint32_t>(TotalFileSize64);
+ uint32_t UncompressedSize32 = static_cast<uint32_t>(UncompressedSize64);
+ OS.write(reinterpret_cast<const char *>(&TotalFileSize32),
+ sizeof(TotalFileSize32));
+ OS.write(reinterpret_cast<const char *>(&UncompressedSize32),
+ sizeof(UncompressedSize32));
+ } else { // Version 3.
+ OS.write(reinterpret_cast<const char *>(&TotalFileSize64),
+ sizeof(TotalFileSize64));
+ OS.write(reinterpret_cast<const char *>(&UncompressedSize64),
+ sizeof(UncompressedSize64));
+ }
+
+ OS.write(reinterpret_cast<const char *>(&TruncatedHash),
+ sizeof(TruncatedHash));
+ OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
+ CompressedBuffer.size());
+
+ if (VerboseStream) {
+ auto MethodUsed = P.format == compression::Format::Zstd ? "zstd" : "zlib";
+ double CompressionRate =
+ static_cast<double>(UncompressedSize64) / CompressedBuffer.size();
+ double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
+ double CompressionSpeedMBs =
+ (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds;
+ *VerboseStream << "Compressed bundle format version: " << Version << "\n"
+ << "Total file size (including headers): "
+ << formatWithCommas(TotalFileSize64) << " bytes\n"
+ << "Compression method used: " << MethodUsed << "\n"
+ << "Compression level: " << P.level << "\n"
+ << "Binary size before compression: "
+ << formatWithCommas(UncompressedSize64) << " bytes\n"
+ << "Binary size after compression: "
+ << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
+ << "Compression rate: " << format("%.2lf", CompressionRate)
+ << "\n"
+ << "Compression ratio: "
+ << format("%.2lf%%", 100.0 / CompressionRate) << "\n"
+ << "Compression speed: "
+ << format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
+ << "Truncated MD5 hash: " << format_hex(TruncatedHash, 16)
+ << "\n";
+ }
+
+ return MemoryBuffer::getMemBufferCopy(
+ StringRef(FinalBuffer.data(), FinalBuffer.size()));
+}
+
+// Use packed structs to avoid padding, such that the structs map the serialized
+// format.
+LLVM_PACKED_START
+union RawCompressedBundleHeader {
+ struct CommonFields {
+ uint32_t Magic;
+ uint16_t Version;
+ uint16_t Method;
+ };
+
+ struct V1Header {
+ CommonFields Common;
+ uint32_t UncompressedFileSize;
+ uint64_t Hash;
+ };
+
+ struct V2Header {
+ CommonFields Common;
+ uint32_t FileSize;
+ uint32_t UncompressedFileSize;
+ uint64_t Hash;
+ };
+
+ struct V3Header {
+ CommonFields Common;
+ uint64_t FileSize;
+ uint64_t UncompressedFileSize;
+ uint64_t Hash;
+ };
+
+ CommonFields Common;
+ V1Header V1;
+ V2Header V2;
+ V3Header V3;
+};
+LLVM_PACKED_END
+
+// Helper method to get header size based on version.
+static size_t getHeaderSize(uint16_t Version) {
+ switch (Version) {
+ case 1:
+ return sizeof(RawCompressedBundleHeader::V1Header);
+ case 2:
+ return sizeof(RawCompressedBundleHeader::V2Header);
+ case 3:
+ return sizeof(RawCompressedBundleHeader::V3Header);
+ default:
+ llvm_unreachable("Unsupported version");
+ }
+}
- if (Blob.size() < V1HeaderSize)
- return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+Expected<CompressedOffloadBundle::CompressedBundleHeader>
+CompressedOffloadBundle::CompressedBundleHeader::tryParse(StringRef Blob) {
+ assert(Blob.size() >= sizeof(RawCompressedBundleHeader::CommonFields));
+ assert(identify_magic(Blob) == file_magic::offload_bundle_compressed);
+
+ RawCompressedBundleHeader Header;
+ std::memcpy(&Header, Blob.data(), std::min(Blob.size(), sizeof(Header)));
+
+ CompressedBundleHeader Normalized;
+ Normalized.Version = Header.Common.Version;
+
+ size_t RequiredSize = getHeaderSize(Normalized.Version);
+
+ if (Blob.size() < RequiredSize)
+ return createStringError("compressed bundle header size too small");
+
+ switch (Normalized.Version) {
+ case 1:
+ Normalized.UncompressedFileSize = Header.V1.UncompressedFileSize;
+ Normalized.Hash = Header.V1.Hash;
+ break;
+ case 2:
+ Normalized.FileSize = Header.V2.FileSize;
+ Normalized.UncompressedFileSize = Header.V2.UncompressedFileSize;
+ Normalized.Hash = Header.V2.Hash;
+ break;
+ case 3:
+ Normalized.FileSize = Header.V3.FileSize;
+ Normalized.UncompressedFileSize = Header.V3.UncompressedFileSize;
+ Normalized.Hash = Header.V3.Hash;
+ break;
+ default:
+ return createStringError("unknown compressed bundle version");
+ }
- if (llvm::identify_magic(Blob) !=
- llvm::file_magic::offload_bundle_compressed) {
- if (Verbose)
- llvm::errs() << "Uncompressed bundle.\n";
- return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+ // Determine compression format.
+ switch (Header.Common.Method) {
+ case static_cast<uint16_t>(compression::Format::Zlib):
+ case static_cast<uint16_t>(compression::Format::Zstd):
+ Normalized.CompressionFormat =
+ static_cast<compression::Format>(Header.Common.Method);
+ break;
+ default:
+ return createStringError("unknown compressing method");
}
- size_t CurrentOffset = MagicSize;
+ return Normalized;
+}
- uint16_t ThisVersion;
- memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
- CurrentOffset += VersionFieldSize;
+Expected<std::unique_ptr<MemoryBuffer>>
+CompressedOffloadBundle::decompress(const MemoryBuffer &Input,
+ raw_ostream *VerboseStream) {
+ StringRef Blob = Input.getBuffer();
- uint16_t CompressionMethod;
- memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));
- CurrentOffset += MethodFieldSize;
+ // Check minimum header size (using V1 as it's the smallest).
+ if (Blob.size() < sizeof(RawCompressedBundleHeader::CommonFields))
+ return MemoryBuffer::getMemBufferCopy(Blob);
- uint32_t TotalFileSize;
- if (ThisVersion >= 2) {
- if (Blob.size() < V2HeaderSize)
- return createStringError(inconvertibleErrorCode(),
- "Compressed bundle header size too small");
- memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
- CurrentOffset += FileSizeFieldSize;
+ if (identify_magic(Blob) != file_magic::offload_bundle_compressed) {
+ if (VerboseStream)
+ *VerboseStream << "Uncompressed bundle\n";
+ return MemoryBuffer::getMemBufferCopy(Blob);
}
- uint32_t UncompressedSize;
- memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
- CurrentOffset += UncompressedSizeFieldSize;
-
- uint64_t StoredHash;
- memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));
- CurrentOffset += HashFieldSize;
-
- llvm::compression::Format CompressionFormat;
- if (CompressionMethod ==
- static_cast<uint16_t>(llvm::compression::Format::Zlib))
- CompressionFormat = llvm::compression::Format::Zlib;
- else if (CompressionMethod ==
- static_cast<uint16_t>(llvm::compression::Format::Zstd))
- CompressionFormat = llvm::compression::Format::Zstd;
- else
- return createStringError(inconvertibleErrorCode(),
- "Unknown compressing method");
-
- llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
- OffloadBundlerTimerGroup);
- if (Verbose)
+ Expected<CompressedBundleHeader> HeaderOrErr =
+ CompressedBundleHeader::tryParse(Blob);
+ if (!HeaderOrErr)
+ return HeaderOrErr.takeError();
+
+ const CompressedBundleHeader &Normalized = *HeaderOrErr;
+ unsigned ThisVersion = Normalized.Version;
+ size_t HeaderSize = getHeaderSize(ThisVersion);
+
+ compression::Format CompressionFormat = Normalized.CompressionFormat;
+
+ size_t TotalFileSize = Normalized.FileSize.value_or(0);
+ size_t UncompressedSize = Normalized.UncompressedFileSize;
+ auto StoredHash = Normalized.Hash;
+
+ Timer DecompressTimer("Decompression Timer", "Decompression time",
+ OffloadBundlerTimerGroup);
+ if (VerboseStream)
DecompressTimer.startTimer();
SmallVector<uint8_t, 0> DecompressedData;
- StringRef CompressedData = Blob.substr(CurrentOffset);
- if (llvm::Error DecompressionError = llvm::compression::decompress(
- CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
+ StringRef CompressedData =
+ Blob.substr(HeaderSize, TotalFileSize - HeaderSize);
+
+ if (Error DecompressionError = compression::decompress(
+ CompressionFormat, arrayRefFromStringRef(CompressedData),
DecompressedData, UncompressedSize))
- return createStringError(inconvertibleErrorCode(),
- "Could not decompress embedded file contents: " +
- llvm::toString(std::move(DecompressionError)));
+ return createStringError("could not decompress embedded file contents: " +
+ toString(std::move(DecompressionError)));
- if (Verbose) {
+ if (VerboseStream) {
DecompressTimer.stopTimer();
double DecompressionTimeSeconds =
DecompressTimer.getTotalTime().getWallTime();
// Recalculate MD5 hash for integrity check.
- llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
- "Hash recalculation time",
- OffloadBundlerTimerGroup);
+ Timer HashRecalcTimer("Hash Recalculation Timer", "Hash recalculation time",
+ OffloadBundlerTimerGroup);
HashRecalcTimer.startTimer();
- llvm::MD5 Hash;
- llvm::MD5::MD5Result Result;
- Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData));
+ MD5 Hash;
+ MD5::MD5Result Result;
+ Hash.update(ArrayRef<uint8_t>(DecompressedData));
Hash.final(Result);
uint64_t RecalculatedHash = Result.low();
HashRecalcTimer.stopTimer();
@@ -351,118 +573,28 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
double DecompressionSpeedMBs =
(UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
- llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
+ *VerboseStream << "Compressed bundle format version: " << ThisVersion
+ << "\n";
if (ThisVersion >= 2)
- llvm::errs() << "Total file size (from header): "
- << formatWithCommas(TotalFileSize) << " bytes\n";
- llvm::errs() << "Decompression method: "
- << (CompressionFormat == llvm::compression::Format::Zlib
- ? "zlib"
- : "zstd")
- << "\n"
- << "Size before decompression: "
- << formatWithCommas(CompressedData.size()) << " bytes\n"
- << "Size after decompression: "
- << formatWithCommas(UncompressedSize) << " bytes\n"
- << "Compression rate: "
- << llvm::format("%.2lf", CompressionRate) << "\n"
- << "Compression ratio: "
- << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
- << "Decompression speed: "
- << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
- << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
- << "Recalculated hash: "
- << llvm::format_hex(RecalculatedHash, 16) << "\n"
- << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
+ *VerboseStream << "Total file size (from header): "
+ << formatWithCommas(TotalFileSize) << " bytes\n";
+ *VerboseStream
+ << "Decompression method: "
+ << (CompressionFormat == compression::Format::Zlib ? "zlib" : "zstd")
+ << "\n"
+ << "Size before decompression: "
+ << formatWithCommas(CompressedData.size()) << " bytes\n"
+ << "Size after decompression: " << formatWithCommas(UncompressedSize)
+ << " bytes\n"
+ << "Compression rate: " << format("%.2lf", CompressionRate) << "\n"
+ << "Compression ratio: " << format("%.2lf%%", 100.0 / CompressionRate)
+ << "\n"
+ << "Decompression speed: "
+ << format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
+ << "Stored hash: " << format_hex(StoredHash, 16) << "\n"
+ << "Recalculated hash: " << format_hex(RecalculatedHash, 16) << "\n"
+ << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
}
- return llvm::MemoryBuffer::getMemBufferCopy(
- llvm::toStringRef(DecompressedData));
-}
-
-llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::compress(llvm::compression::Params P,
- const llvm::MemoryBuffer &Input,
- bool Verbose) {
- if (!llvm::compression::zstd::isAvailable() &&
- !llvm::compression::zlib::isAvailable())
- return createStringError(llvm::inconvertibleErrorCode(),
- "Compression not supported");
-
- llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
- OffloadBundlerTimerGroup);
- if (Verbose)
- HashTimer.startTimer();
- llvm::MD5 Hash;
- llvm::MD5::MD5Result Result;
- Hash.update(Input.getBuffer());
- Hash.final(Result);
- uint64_t TruncatedHash = Result.low();
- if (Verbose)
- HashTimer.stopTimer();
-
- SmallVector<uint8_t, 0> CompressedBuffer;
- auto BufferUint8 = llvm::ArrayRef<uint8_t>(
- reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
- Input.getBuffer().size());
-
- llvm::Timer CompressTimer("Compression Timer", "Compression time",
- OffloadBundlerTimerGroup);
- if (Verbose)
- CompressTimer.startTimer();
- llvm::compression::compress(P, BufferUint8, CompressedBuffer);
- if (Verbose)
- CompressTimer.stopTimer();
-
- uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
- uint32_t UncompressedSize = Input.getBuffer().size();
- uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +
- sizeof(Version) + sizeof(CompressionMethod) +
- sizeof(UncompressedSize) + sizeof(TruncatedHash) +
- CompressedBuffer.size();
-
- SmallVector<char, 0> FinalBuffer;
- llvm::raw_svector_ostream OS(FinalBuffer);
- OS << MagicNumber;
- OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
- OS.write(reinterpret_cast<const char *>(&CompressionMethod),
- sizeof(CompressionMethod));
- OS.write(reinterpret_cast<const char *>(&TotalFileSize),
- sizeof(TotalFileSize));
- OS.write(reinterpret_cast<const char *>(&UncompressedSize),
- sizeof(UncompressedSize));
- OS.write(reinterpret_cast<const char *>(&TruncatedHash),
- sizeof(TruncatedHash));
- OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
- CompressedBuffer.size());
-
- if (Verbose) {
- auto MethodUsed =
- P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
- double CompressionRate =
- static_cast<double>(UncompressedSize) / CompressedBuffer.size();
- double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
- double CompressionSpeedMBs =
- (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
-
- llvm::errs() << "Compressed bundle format version: " << Version << "\n"
- << "Total file size (including headers): "
- << formatWithCommas(TotalFileSize) << " bytes\n"
- << "Compression method used: " << MethodUsed << "\n"
- << "Compression level: " << P.level << "\n"
- << "Binary size before compression: "
- << formatWithCommas(UncompressedSize) << " bytes\n"
- << "Binary size after compression: "
- << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
- << "Compression rate: "
- << llvm::format("%.2lf", CompressionRate) << "\n"
- << "Compression ratio: "
- << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
- << "Compression speed: "
- << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
- << "Truncated MD5 hash: "
- << llvm::format_hex(TruncatedHash, 16) << "\n";
- }
- return llvm::MemoryBuffer::getMemBufferCopy(
- llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
+ return MemoryBuffer::getMemBufferCopy(toStringRef(DecompressedData));
}
diff --git a/llvm/lib/Option/ArgList.cpp b/llvm/lib/Option/ArgList.cpp
index c4188b3b..2f4e212 100644
--- a/llvm/lib/Option/ArgList.cpp
+++ b/llvm/lib/Option/ArgList.cpp
@@ -14,12 +14,14 @@
#include "llvm/Config/llvm-config.h"
#include "llvm/Option/Arg.h"
#include "llvm/Option/OptSpecifier.h"
+#include "llvm/Option/OptTable.h"
#include "llvm/Option/Option.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
+#include <cstddef>
#include <memory>
#include <string>
#include <utility>
@@ -202,6 +204,42 @@ void ArgList::print(raw_ostream &O) const {
LLVM_DUMP_METHOD void ArgList::dump() const { print(dbgs()); }
#endif
+StringRef ArgList::getSubCommand(
+ ArrayRef<OptTable::SubCommand> AllSubCommands,
+ std::function<void(ArrayRef<StringRef>)> HandleMultipleSubcommands,
+ std::function<void(ArrayRef<StringRef>)> HandleOtherPositionals) const {
+
+ SmallVector<StringRef, 4> SubCommands;
+ SmallVector<StringRef, 4> OtherPositionals;
+ for (const Arg *A : *this) {
+ if (A->getOption().getKind() != Option::InputClass)
+ continue;
+
+ size_t OldSize = SubCommands.size();
+ for (const OptTable::SubCommand &CMD : AllSubCommands) {
+ if (StringRef(CMD.Name) == A->getValue())
+ SubCommands.push_back(A->getValue());
+ }
+
+ if (SubCommands.size() == OldSize)
+ OtherPositionals.push_back(A->getValue());
+ }
+
+ // Invoke callbacks if necessary.
+ if (SubCommands.size() > 1) {
+ HandleMultipleSubcommands(SubCommands);
+ return {};
+ }
+ if (!OtherPositionals.empty()) {
+ HandleOtherPositionals(OtherPositionals);
+ return {};
+ }
+
+ if (SubCommands.size() == 1)
+ return SubCommands.front();
+ return {}; // No valid usage of subcommand found.
+}
+
void InputArgList::releaseMemory() {
// An InputArgList always owns its arguments.
for (Arg *A : *this)
diff --git a/llvm/lib/Option/OptTable.cpp b/llvm/lib/Option/OptTable.cpp
index 6d10e61..14e3b0d 100644
--- a/llvm/lib/Option/OptTable.cpp
+++ b/llvm/lib/Option/OptTable.cpp
@@ -79,9 +79,12 @@ OptSpecifier::OptSpecifier(const Option *Opt) : ID(Opt->getID()) {}
OptTable::OptTable(const StringTable &StrTable,
ArrayRef<StringTable::Offset> PrefixesTable,
- ArrayRef<Info> OptionInfos, bool IgnoreCase)
+ ArrayRef<Info> OptionInfos, bool IgnoreCase,
+ ArrayRef<SubCommand> SubCommands,
+ ArrayRef<unsigned> SubCommandIDsTable)
: StrTable(&StrTable), PrefixesTable(PrefixesTable),
- OptionInfos(OptionInfos), IgnoreCase(IgnoreCase) {
+ OptionInfos(OptionInfos), IgnoreCase(IgnoreCase),
+ SubCommands(SubCommands), SubCommandIDsTable(SubCommandIDsTable) {
// Explicitly zero initialize the error to work around a bug in array
// value-initialization on MinGW with gcc 4.3.5.
@@ -715,9 +718,10 @@ static const char *getOptionHelpGroup(const OptTable &Opts, OptSpecifier Id) {
void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
bool ShowHidden, bool ShowAllAliases,
- Visibility VisibilityMask) const {
+ Visibility VisibilityMask,
+ StringRef SubCommand) const {
return internalPrintHelp(
- OS, Usage, Title, ShowHidden, ShowAllAliases,
+ OS, Usage, Title, SubCommand, ShowHidden, ShowAllAliases,
[VisibilityMask](const Info &CandidateInfo) -> bool {
return (CandidateInfo.Visibility & VisibilityMask) == 0;
},
@@ -730,7 +734,7 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
bool ShowHidden = !(FlagsToExclude & HelpHidden);
FlagsToExclude &= ~HelpHidden;
return internalPrintHelp(
- OS, Usage, Title, ShowHidden, ShowAllAliases,
+ OS, Usage, Title, /*SubCommand=*/{}, ShowHidden, ShowAllAliases,
[FlagsToInclude, FlagsToExclude](const Info &CandidateInfo) {
if (FlagsToInclude && !(CandidateInfo.Flags & FlagsToInclude))
return true;
@@ -742,16 +746,62 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
}
void OptTable::internalPrintHelp(
- raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden,
- bool ShowAllAliases, std::function<bool(const Info &)> ExcludeOption,
+ raw_ostream &OS, const char *Usage, const char *Title, StringRef SubCommand,
+ bool ShowHidden, bool ShowAllAliases,
+ std::function<bool(const Info &)> ExcludeOption,
Visibility VisibilityMask) const {
OS << "OVERVIEW: " << Title << "\n\n";
- OS << "USAGE: " << Usage << "\n\n";
// Render help text into a map of group-name to a list of (option, help)
// pairs.
std::map<std::string, std::vector<OptionInfo>> GroupedOptionHelp;
+ auto ActiveSubCommand =
+ std::find_if(SubCommands.begin(), SubCommands.end(),
+ [&](const auto &C) { return SubCommand == C.Name; });
+ if (!SubCommand.empty()) {
+ assert(ActiveSubCommand != SubCommands.end() &&
+ "Not a valid registered subcommand.");
+ OS << ActiveSubCommand->HelpText << "\n\n";
+ if (!StringRef(ActiveSubCommand->Usage).empty())
+ OS << "USAGE: " << ActiveSubCommand->Usage << "\n\n";
+ } else {
+ OS << "USAGE: " << Usage << "\n\n";
+ if (SubCommands.size() > 1) {
+ OS << "SUBCOMMANDS:\n\n";
+ for (const auto &C : SubCommands)
+ OS << C.Name << " - " << C.HelpText << "\n";
+ OS << "\n";
+ }
+ }
+
+ auto DoesOptionBelongToSubcommand = [&](const Info &CandidateInfo) {
+ // Retrieve the SubCommandIDs registered to the given current CandidateInfo
+ // Option.
+ ArrayRef<unsigned> SubCommandIDs =
+ CandidateInfo.getSubCommandIDs(SubCommandIDsTable);
+
+ // If no registered subcommands, then only global options are to be printed.
+ // If no valid SubCommand (empty) in commandline then print the current
+ // global CandidateInfo Option.
+ if (SubCommandIDs.empty())
+ return SubCommand.empty();
+
+ // Handle CandidateInfo Option which has at least one registered SubCommand.
+ // If no valid SubCommand (empty) in commandline, this CandidateInfo option
+ // should not be printed.
+ if (SubCommand.empty())
+ return false;
+
+ // Find the ID of the valid subcommand passed in commandline (its index in
+ // the SubCommands table which contains all subcommands).
+ unsigned ActiveSubCommandID = ActiveSubCommand - &SubCommands[0];
+ // Print if the ActiveSubCommandID is registered with the CandidateInfo
+ // Option.
+ return std::find(SubCommandIDs.begin(), SubCommandIDs.end(),
+ ActiveSubCommandID) != SubCommandIDs.end();
+ };
+
for (unsigned Id = 1, e = getNumOptions() + 1; Id != e; ++Id) {
// FIXME: Split out option groups.
if (getOptionKind(Id) == Option::GroupClass)
@@ -764,6 +814,9 @@ void OptTable::internalPrintHelp(
if (ExcludeOption(CandidateInfo))
continue;
+ if (!DoesOptionBelongToSubcommand(CandidateInfo))
+ continue;
+
// If an alias doesn't have a help text, show a help text for the aliased
// option instead.
const char *HelpText = getOptionHelpText(Id, VisibilityMask);
@@ -791,8 +844,11 @@ void OptTable::internalPrintHelp(
GenericOptTable::GenericOptTable(const StringTable &StrTable,
ArrayRef<StringTable::Offset> PrefixesTable,
- ArrayRef<Info> OptionInfos, bool IgnoreCase)
- : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase) {
+ ArrayRef<Info> OptionInfos, bool IgnoreCase,
+ ArrayRef<SubCommand> SubCommands,
+ ArrayRef<unsigned> SubCommandIDsTable)
+ : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase, SubCommands,
+ SubCommandIDsTable) {
std::set<StringRef> TmpPrefixesUnion;
for (auto const &Info : OptionInfos.drop_front(FirstSearchableIndex))
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c234623..20dcde8 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -240,6 +240,7 @@
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
+#include "llvm/Transforms/Instrumentation/AllocToken.h"
#include "llvm/Transforms/Instrumentation/BoundsChecking.h"
#include "llvm/Transforms/Instrumentation/CGProfile.h"
#include "llvm/Transforms/Instrumentation/ControlHeightReduction.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 7069e8d..119caea 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1960,6 +1960,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
// is fixed.
MPM.addPass(WholeProgramDevirtPass(ExportSummary, nullptr));
+ MPM.addPass(NoRecurseLTOInferencePass());
// Stop here at -O1.
if (Level == OptimizationLevel::O1) {
// The LowerTypeTestsPass needs to run to lower type metadata and the
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index f0e7d36..c5c0d64 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -119,11 +119,13 @@ MODULE_PASS("metarenamer", MetaRenamerPass())
MODULE_PASS("module-inline", ModuleInlinerPass())
MODULE_PASS("name-anon-globals", NameAnonGlobalPass())
MODULE_PASS("no-op-module", NoOpModulePass())
+MODULE_PASS("norecurse-lto-inference", NoRecurseLTOInferencePass())
MODULE_PASS("nsan", NumericalStabilitySanitizerPass())
MODULE_PASS("openmp-opt", OpenMPOptPass())
MODULE_PASS("openmp-opt-postlink",
OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
MODULE_PASS("partial-inliner", PartialInlinerPass())
+MODULE_PASS("alloc-token", AllocTokenPass())
MODULE_PASS("pgo-icall-prom", PGOIndirectCallPromotion())
MODULE_PASS("pgo-instr-gen", PGOInstrumentationGen())
MODULE_PASS("pgo-instr-use", PGOInstrumentationUse())
diff --git a/llvm/lib/Support/GlobPattern.cpp b/llvm/lib/Support/GlobPattern.cpp
index 7004adf..0ecf47d 100644
--- a/llvm/lib/Support/GlobPattern.cpp
+++ b/llvm/lib/Support/GlobPattern.cpp
@@ -143,6 +143,15 @@ GlobPattern::create(StringRef S, std::optional<size_t> MaxSubPatterns) {
return Pat;
S = S.substr(PrefixSize);
+ // Just in case we stop on unmatched opening brackets.
+ size_t SuffixStart = S.find_last_of("?*[]{}\\");
+ assert(SuffixStart != std::string::npos);
+ if (S[SuffixStart] == '\\')
+ ++SuffixStart;
+ ++SuffixStart;
+ Pat.Suffix = S.substr(SuffixStart);
+ S = S.substr(0, SuffixStart);
+
SmallVector<std::string, 1> SubPats;
if (auto Err = parseBraceExpansions(S, MaxSubPatterns).moveInto(SubPats))
return std::move(Err);
@@ -193,6 +202,8 @@ GlobPattern::SubGlobPattern::create(StringRef S) {
bool GlobPattern::match(StringRef S) const {
if (!S.consume_front(Prefix))
return false;
+ if (!S.consume_back(Suffix))
+ return false;
if (SubGlobs.empty() && S.empty())
return true;
for (auto &Glob : SubGlobs)
diff --git a/llvm/lib/Support/SpecialCaseList.cpp b/llvm/lib/Support/SpecialCaseList.cpp
index 8d4e043..4b03885 100644
--- a/llvm/lib/Support/SpecialCaseList.cpp
+++ b/llvm/lib/Support/SpecialCaseList.cpp
@@ -135,7 +135,7 @@ SpecialCaseList::addSection(StringRef SectionStr, unsigned FileNo,
Sections.emplace_back(SectionStr, FileNo);
auto &Section = Sections.back();
- if (auto Err = Section.SectionMatcher->insert(SectionStr, LineNo, UseGlobs)) {
+ if (auto Err = Section.SectionMatcher.insert(SectionStr, LineNo, UseGlobs)) {
return createStringError(errc::invalid_argument,
"malformed section at line " + Twine(LineNo) +
": '" + SectionStr +
@@ -218,7 +218,7 @@ std::pair<unsigned, unsigned>
SpecialCaseList::inSectionBlame(StringRef Section, StringRef Prefix,
StringRef Query, StringRef Category) const {
for (const auto &S : reverse(Sections)) {
- if (S.SectionMatcher->match(Section)) {
+ if (S.SectionMatcher.match(Section)) {
unsigned Blame = inSectionBlame(S.Entries, Prefix, Query, Category);
if (Blame)
return {S.FileIdx, Blame};
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 50a8754..479e345 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5666,18 +5666,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
VectorType *AccumVectorType =
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
// We don't yet support all kinds of legalization.
- auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
- EVT::getEVT(AccumVectorType));
- switch (TA) {
+ auto TC = TLI->getTypeConversion(AccumVectorType->getContext(),
+ EVT::getEVT(AccumVectorType));
+ switch (TC.first) {
default:
return Invalid;
case TargetLowering::TypeLegal:
case TargetLowering::TypePromoteInteger:
case TargetLowering::TypeSplitVector:
+ // The legalised type (e.g. after splitting) must be legal too.
+ if (TLI->getTypeAction(AccumVectorType->getContext(), TC.second) !=
+ TargetLowering::TypeLegal)
+ return Invalid;
break;
}
- // Check what kind of type-legalisation happens.
std::pair<InstructionCost, MVT> AccumLT =
getTypeLegalizationCost(AccumVectorType);
std::pair<InstructionCost, MVT> InputLT =
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 6b3c151..1a697f7 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -1448,10 +1448,10 @@ def Feature45BitNumRecordsBufferResource : SubtargetFeature< "45-bit-num-records
"The buffer resource (V#) supports 45-bit num_records"
>;
-def FeatureCluster : SubtargetFeature< "cluster",
- "HasCluster",
+def FeatureClusters : SubtargetFeature< "clusters",
+ "HasClusters",
"true",
- "Has cluster support"
+ "Has clusters of workgroups support"
>;
// Dummy feature used to disable assembler instructions.
@@ -2120,7 +2120,7 @@ def FeatureISAVersion12_50 : FeatureSet<
Feature45BitNumRecordsBufferResource,
FeatureSupportsXNACK,
FeatureXNACK,
- FeatureCluster,
+ FeatureClusters,
]>;
def FeatureISAVersion12_51 : FeatureSet<
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index 848d9a5..557d87f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -5043,6 +5043,9 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_mfma_i32_16x16x64_i8:
case Intrinsic::amdgcn_mfma_i32_32x32x32_i8:
case Intrinsic::amdgcn_mfma_f32_16x16x32_bf16: {
+ unsigned DstSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
+ unsigned MinNumRegsRequired = DstSize / 32;
+
// Default for MAI intrinsics.
// srcC can also be an immediate which can be folded later.
// FIXME: Should we eventually add an alternative mapping with AGPR src
@@ -5051,29 +5054,32 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// vdst, srcA, srcB, srcC
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
OpdsMapping[0] =
- Info->mayNeedAGPRs()
+ Info->getMinNumAGPRs() >= MinNumRegsRequired
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
OpdsMapping[4] =
- Info->mayNeedAGPRs()
+ Info->getMinNumAGPRs() >= MinNumRegsRequired
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
break;
}
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
+ unsigned DstSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
+ unsigned MinNumRegsRequired = DstSize / 32;
+
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
OpdsMapping[0] =
- Info->mayNeedAGPRs()
+ Info->getMinNumAGPRs() >= MinNumRegsRequired
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
OpdsMapping[4] =
- Info->mayNeedAGPRs()
+ Info->getMinNumAGPRs() >= MinNumRegsRequired
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a67a7be..d0c0822 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -1944,6 +1944,7 @@ public:
void cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands);
void cvtVINTERP(MCInst &Inst, const OperandVector &Operands);
+ void cvtOpSelHelper(MCInst &Inst, unsigned OpSel);
bool parseDimId(unsigned &Encoding);
ParseStatus parseDim(OperandVector &Operands);
@@ -9239,6 +9240,33 @@ static bool isRegOrImmWithInputMods(const MCInstrDesc &Desc, unsigned OpNum) {
MCOI::OperandConstraint::TIED_TO) == -1;
}
+void AMDGPUAsmParser::cvtOpSelHelper(MCInst &Inst, unsigned OpSel) {
+ unsigned Opc = Inst.getOpcode();
+ constexpr AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1,
+ AMDGPU::OpName::src2};
+ constexpr AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers,
+ AMDGPU::OpName::src1_modifiers,
+ AMDGPU::OpName::src2_modifiers};
+ for (int J = 0; J < 3; ++J) {
+ int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]);
+ if (OpIdx == -1)
+ // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but
+ // no src1. So continue instead of break.
+ continue;
+
+ int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]);
+ uint32_t ModVal = Inst.getOperand(ModIdx).getImm();
+
+ if ((OpSel & (1 << J)) != 0)
+ ModVal |= SISrcMods::OP_SEL_0;
+ // op_sel[3] is encoded in src0_modifiers.
+ if (ModOps[J] == AMDGPU::OpName::src0_modifiers && (OpSel & (1 << 3)) != 0)
+ ModVal |= SISrcMods::DST_OP_SEL;
+
+ Inst.getOperand(ModIdx).setImm(ModVal);
+ }
+}
+
void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands)
{
OptionalImmIndexMap OptionalIdx;
@@ -9275,6 +9303,16 @@ void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands)
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod))
addOptionalImmOperand(Inst, Operands, OptionalIdx,
AMDGPUOperand::ImmTyOModSI);
+
+ // Some v_interp instructions use op_sel[3] for dst.
+ if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::op_sel)) {
+ addOptionalImmOperand(Inst, Operands, OptionalIdx,
+ AMDGPUOperand::ImmTyOpSel);
+ int OpSelIdx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::op_sel);
+ unsigned OpSel = Inst.getOperand(OpSelIdx).getImm();
+
+ cvtOpSelHelper(Inst, OpSel);
+ }
}
void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands)
@@ -9310,31 +9348,10 @@ void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands)
if (OpSelIdx == -1)
return;
- const AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1,
- AMDGPU::OpName::src2};
- const AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers,
- AMDGPU::OpName::src1_modifiers,
- AMDGPU::OpName::src2_modifiers};
-
unsigned OpSel = Inst.getOperand(OpSelIdx).getImm();
-
- for (int J = 0; J < 3; ++J) {
- int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]);
- if (OpIdx == -1)
- break;
-
- int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]);
- uint32_t ModVal = Inst.getOperand(ModIdx).getImm();
-
- if ((OpSel & (1 << J)) != 0)
- ModVal |= SISrcMods::OP_SEL_0;
- if (ModOps[J] == AMDGPU::OpName::src0_modifiers &&
- (OpSel & (1 << 3)) != 0)
- ModVal |= SISrcMods::DST_OP_SEL;
-
- Inst.getOperand(ModIdx).setImm(ModVal);
- }
+ cvtOpSelHelper(Inst, OpSel);
}
+
void AMDGPUAsmParser::cvtScaledMFMA(MCInst &Inst,
const OperandVector &Operands) {
OptionalImmIndexMap OptionalIdx;
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
index 7b94ea3..f291e37 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
@@ -541,7 +541,7 @@ unsigned GCNSubtarget::getMaxNumSGPRs(const Function &F) const {
unsigned GCNSubtarget::getBaseMaxNumVGPRs(
const Function &F, std::pair<unsigned, unsigned> NumVGPRBounds) const {
- const auto &[Min, Max] = NumVGPRBounds;
+ const auto [Min, Max] = NumVGPRBounds;
// Check if maximum number of VGPRs was explicitly requested using
// "amdgpu-num-vgpr" attribute.
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index 879bf5a..c2e6078 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -288,7 +288,7 @@ protected:
bool Has45BitNumRecordsBufferResource = false;
- bool HasCluster = false;
+ bool HasClusters = false;
// Dummy feature to use for assembler in tablegen.
bool FeatureDisable = false;
@@ -1839,7 +1839,7 @@ public:
}
/// \returns true if the subtarget supports clusters of workgroups.
- bool hasClusters() const { return HasCluster; }
+ bool hasClusters() const { return HasClusters; }
/// \returns true if the subtarget requires a wait for xcnt before atomic
/// flat/global stores & rmw.
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index d3b5718..3563caa 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1280,6 +1280,17 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI,
(ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue;
}
+ // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but no
+ // src1.
+ if (NumOps == 1 && AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src2) &&
+ !AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src1)) {
+ Ops[NumOps++] = DefaultValue; // Set src1_modifiers to default.
+ int Mod2Idx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2_modifiers);
+ assert(Mod2Idx != -1);
+ Ops[NumOps++] = MI->getOperand(Mod2Idx).getImm();
+ }
+
const bool HasDst =
(AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::vdst) != -1) ||
(AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::sdst) != -1);
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index e233457..1a686a9 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -17346,74 +17346,24 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
MachineFunction *MF = MI.getParent()->getParent();
MachineRegisterInfo &MRI = MF->getRegInfo();
- SIMachineFunctionInfo *Info = MF->getInfo<SIMachineFunctionInfo>();
if (TII->isVOP3(MI.getOpcode())) {
// Make sure constant bus requirements are respected.
TII->legalizeOperandsVOP3(MRI, MI);
- // Prefer VGPRs over AGPRs in mAI instructions where possible.
- // This saves a chain-copy of registers and better balance register
- // use between vgpr and agpr as agpr tuples tend to be big.
- if (!MI.getDesc().operands().empty()) {
- unsigned Opc = MI.getOpcode();
- bool HasAGPRs = Info->mayNeedAGPRs();
- const SIRegisterInfo *TRI = Subtarget->getRegisterInfo();
- int16_t Src2Idx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2);
- for (auto I :
- {AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src0),
- AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src1), Src2Idx}) {
- if (I == -1)
- break;
- if ((I == Src2Idx) && (HasAGPRs))
- break;
- MachineOperand &Op = MI.getOperand(I);
- if (!Op.isReg() || !Op.getReg().isVirtual())
- continue;
- auto *RC = TRI->getRegClassForReg(MRI, Op.getReg());
- if (!TRI->hasAGPRs(RC))
- continue;
- auto *Src = MRI.getUniqueVRegDef(Op.getReg());
- if (!Src || !Src->isCopy() ||
- !TRI->isSGPRReg(MRI, Src->getOperand(1).getReg()))
- continue;
- auto *NewRC = TRI->getEquivalentVGPRClass(RC);
- // All uses of agpr64 and agpr32 can also accept vgpr except for
- // v_accvgpr_read, but we do not produce agpr reads during selection,
- // so no use checks are needed.
- MRI.setRegClass(Op.getReg(), NewRC);
- }
-
- if (TII->isMAI(MI)) {
- // The ordinary src0, src1, src2 were legalized above.
- //
- // We have to also legalize the appended v_mfma_ld_scale_b32 operands,
- // as a separate instruction.
- int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
- AMDGPU::OpName::scale_src0);
- if (Src0Idx != -1) {
- int Src1Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
- AMDGPU::OpName::scale_src1);
- if (TII->usesConstantBus(MRI, MI, Src0Idx) &&
- TII->usesConstantBus(MRI, MI, Src1Idx))
- TII->legalizeOpWithMove(MI, Src1Idx);
- }
- }
-
- if (!HasAGPRs)
- return;
-
- // Resolve the rest of AV operands to AGPRs.
- if (auto *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2)) {
- if (Src2->isReg() && Src2->getReg().isVirtual()) {
- auto *RC = TRI->getRegClassForReg(MRI, Src2->getReg());
- if (TRI->isVectorSuperClass(RC)) {
- auto *NewRC = TRI->getEquivalentAGPRClass(RC);
- MRI.setRegClass(Src2->getReg(), NewRC);
- if (Src2->isTied())
- MRI.setRegClass(MI.getOperand(0).getReg(), NewRC);
- }
- }
+ if (TII->isMAI(MI)) {
+ // The ordinary src0, src1, src2 were legalized above.
+ //
+ // We have to also legalize the appended v_mfma_ld_scale_b32 operands,
+ // as a separate instruction.
+ int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
+ AMDGPU::OpName::scale_src0);
+ if (Src0Idx != -1) {
+ int Src1Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
+ AMDGPU::OpName::scale_src1);
+ if (TII->usesConstantBus(MRI, MI, Src0Idx) &&
+ TII->usesConstantBus(MRI, MI, Src1Idx))
+ TII->legalizeOpWithMove(MI, Src1Idx);
}
}
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
index 908d856..b398db4 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
@@ -33,17 +33,20 @@ using namespace llvm;
// optimal RC for Opc and Dest of MFMA. In particular, there are high RP cases
// where it is better to produce the VGPR form (e.g. if there are VGPR users
// of the MFMA result).
-static cl::opt<bool> MFMAVGPRForm(
- "amdgpu-mfma-vgpr-form", cl::Hidden,
+static cl::opt<bool, true> MFMAVGPRFormOpt(
+ "amdgpu-mfma-vgpr-form",
cl::desc("Whether to force use VGPR for Opc and Dest of MFMA. If "
"unspecified, default to compiler heuristics"),
- cl::init(false));
+ cl::location(SIMachineFunctionInfo::MFMAVGPRForm), cl::init(false),
+ cl::Hidden);
const GCNTargetMachine &getTM(const GCNSubtarget *STI) {
const SITargetLowering *TLI = STI->getTargetLowering();
return static_cast<const GCNTargetMachine &>(TLI->getTargetMachine());
}
+bool SIMachineFunctionInfo::MFMAVGPRForm = false;
+
SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F,
const GCNSubtarget *STI)
: AMDGPUMachineFunction(F, *STI), Mode(F, *STI), GWSResourcePSV(getTM(STI)),
@@ -81,14 +84,13 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F,
PSInputAddr = AMDGPU::getInitialPSInputAddr(F);
}
- MayNeedAGPRs = ST.hasMAIInsts();
if (ST.hasGFX90AInsts()) {
- // FIXME: MayNeedAGPRs is a misnomer for how this is used. MFMA selection
- // should be separated from availability of AGPRs
- if (MFMAVGPRForm ||
- (ST.getMaxNumVGPRs(F) <= ST.getAddressableNumArchVGPRs() &&
- !mayUseAGPRs(F)))
- MayNeedAGPRs = false; // We will select all MAI with VGPR operands.
+ // FIXME: Extract logic out of getMaxNumVectorRegs; we need to apply the
+ // allocation granule and clamping.
+ auto [MinNumAGPRAttr, MaxNumAGPRAttr] =
+ AMDGPU::getIntegerPairAttribute(F, "amdgpu-agpr-alloc", {~0u, ~0u},
+ /*OnlyFirstRequired=*/true);
+ MinNumAGPRs = MinNumAGPRAttr;
}
if (AMDGPU::isChainCC(CC)) {
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
index 4560615..b7dbb59 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
@@ -509,7 +509,9 @@ private:
// user arguments. This is an offset from the KernargSegmentPtr.
bool ImplicitArgPtr : 1;
- bool MayNeedAGPRs : 1;
+ /// Minimum number of AGPRs required to allocate in the function. Only
+ /// relevant for gfx90a-gfx950. For gfx908, this should be infinite.
+ unsigned MinNumAGPRs = ~0u;
// The hard-wired high half of the address of the global information table
// for AMDPAL OS type. 0xffffffff represents no hard-wired high half, since
@@ -537,6 +539,8 @@ private:
void MRI_NoteCloneVirtualRegister(Register NewReg, Register SrcReg) override;
public:
+ static bool MFMAVGPRForm;
+
struct VGPRSpillToAGPR {
SmallVector<MCPhysReg, 32> Lanes;
bool FullyAllocated = false;
@@ -1196,9 +1200,7 @@ public:
unsigned getMaxMemoryClusterDWords() const { return MaxMemoryClusterDWords; }
- bool mayNeedAGPRs() const {
- return MayNeedAGPRs;
- }
+ unsigned getMinNumAGPRs() const { return MinNumAGPRs; }
// \returns true if a function has a use of AGPRs via inline asm or
// has a call which may use it.
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 3c2dd42..3115579 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -1118,12 +1118,7 @@ SIRegisterInfo::getPointerRegClass(unsigned Kind) const {
const TargetRegisterClass *
SIRegisterInfo::getCrossCopyRegClass(const TargetRegisterClass *RC) const {
- if (isAGPRClass(RC) && !ST.hasGFX90AInsts())
- return getEquivalentVGPRClass(RC);
- if (RC == &AMDGPU::SCC_CLASSRegClass)
- return getWaveMaskRegClass();
-
- return RC;
+ return RC == &AMDGPU::SCC_CLASSRegClass ? &AMDGPU::SReg_32RegClass : RC;
}
static unsigned getNumSubRegsForSpillOp(const MachineInstr &MI,
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 4a2b54d..42ec8ba 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -97,6 +97,7 @@ class VOP3Interp<string OpName, VOPProfile P, list<dag> pattern = []> :
VOP3_Pseudo<OpName, P, pattern> {
let AsmMatchConverter = "cvtVOP3Interp";
let mayRaiseFPException = 0;
+ let VOP3_OPSEL = P.HasOpSel;
}
def VOP3_INTERP : VOPProfile<[f32, f32, i32, untyped]> {
@@ -119,16 +120,17 @@ def VOP3_INTERP_MOV : VOPProfile<[f32, i32, i32, untyped]> {
let HasSrc0Mods = 0;
}
-class getInterp16Asm <bit HasSrc2, bit HasOMod> {
+class getInterp16Asm <bit HasSrc2, bit HasOMod, bit OpSel> {
string src2 = !if(HasSrc2, ", $src2_modifiers", "");
string omod = !if(HasOMod, "$omod", "");
+ string opsel = !if(OpSel, "$op_sel", "");
string ret =
- " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod;
+ " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod#opsel;
}
class getInterp16Ins <bit HasSrc2, bit HasOMod,
- Operand Src0Mod, Operand Src2Mod> {
- dag ret = !if(HasSrc2,
+ Operand Src0Mod, Operand Src2Mod, bit OpSel> {
+ dag ret1 = !if(HasSrc2,
!if(HasOMod,
(ins Src0Mod:$src0_modifiers, VRegSrc_32:$src0,
InterpAttr:$attr, InterpAttrChan:$attrchan,
@@ -143,19 +145,22 @@ class getInterp16Ins <bit HasSrc2, bit HasOMod,
InterpAttr:$attr, InterpAttrChan:$attrchan,
highmod:$high, Clamp0:$clamp, omod0:$omod)
);
+ dag ret2 = !if(OpSel, (ins op_sel0:$op_sel), (ins));
+ dag ret = !con(ret1, ret2);
}
-class VOP3_INTERP16 <list<ValueType> ArgVT> : VOPProfile<ArgVT> {
+class VOP3_INTERP16 <list<ValueType> ArgVT, bit OpSel = 0> : VOPProfile<ArgVT> {
let IsSingle = 1;
let HasOMod = !ne(DstVT.Value, f16.Value);
let HasHigh = 1;
+ let HasOpSel = OpSel;
let Src0Mod = FPVRegInputMods;
let Src2Mod = FPVRegInputMods;
let Outs64 = (outs DstRC.RegClass:$vdst);
- let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod>.ret;
- let Asm64 = getInterp16Asm<HasSrc2, HasOMod>.ret;
+ let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod, OpSel>.ret;
+ let Asm64 = getInterp16Asm<HasSrc2, HasOMod, OpSel>.ret;
}
//===----------------------------------------------------------------------===//
@@ -480,7 +485,7 @@ let SubtargetPredicate = isGFX9Plus in {
defm V_MAD_U16_gfx9 : VOP3Inst_t16 <"v_mad_u16_gfx9", VOP_I16_I16_I16_I16>;
defm V_MAD_I16_gfx9 : VOP3Inst_t16 <"v_mad_i16_gfx9", VOP_I16_I16_I16_I16>;
let OtherPredicates = [isNotGFX90APlus] in
-def V_INTERP_P2_F16_gfx9 : VOP3Interp <"v_interp_p2_f16_gfx9", VOP3_INTERP16<[f16, f32, i32, f32]>>;
+def V_INTERP_P2_F16_opsel : VOP3Interp <"v_interp_p2_f16_opsel", VOP3_INTERP16<[f16, f32, i32, f32], /*OpSel*/ 1>>;
} // End SubtargetPredicate = isGFX9Plus
// This predicate should only apply to the selection pattern. The
@@ -2676,6 +2681,14 @@ multiclass VOP3Interp_F16_Real_gfx9<bits<10> op, string OpName, string AsmName>
}
}
+multiclass VOP3Interp_F16_OpSel_Real_gfx9<bits<10> op, string OpName, string AsmName> {
+ def _gfx9 : VOP3_Real<!cast<VOP3_Pseudo>(OpName), SIEncodingFamily.GFX9>,
+ VOP3Interp_OpSel_gfx9 <op, !cast<VOP3_Pseudo>(OpName).Pfl> {
+ VOP3_Pseudo ps = !cast<VOP3_Pseudo>(OpName);
+ let AsmString = AsmName # ps.AsmOperands;
+ }
+}
+
multiclass VOP3_Real_gfx9<bits<10> op, string AsmName> {
def _gfx9 : VOP3_Real<!cast<VOP_Pseudo>(NAME#"_e64"), SIEncodingFamily.GFX9>,
VOP3e_vi <op, !cast<VOP_Pseudo>(NAME#"_e64").Pfl> {
@@ -2788,7 +2801,7 @@ defm V_MAD_U16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x204, "v_mad_u16">;
defm V_MAD_I16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x205, "v_mad_i16">;
defm V_FMA_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x206, "v_fma_f16">;
defm V_DIV_FIXUP_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x207, "v_div_fixup_f16">;
-defm V_INTERP_P2_F16_gfx9 : VOP3Interp_F16_Real_gfx9 <0x277, "V_INTERP_P2_F16_gfx9", "v_interp_p2_f16">;
+defm V_INTERP_P2_F16_opsel : VOP3Interp_F16_OpSel_Real_gfx9 <0x277, "V_INTERP_P2_F16_opsel", "v_interp_p2_f16">;
defm V_ADD_I32 : VOP3_Real_vi <0x29c>;
defm V_SUB_I32 : VOP3_Real_vi <0x29d>;
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 5daf860..3a0cc35 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -67,7 +67,7 @@ class VOP3P_Mix_Profile<VOPProfile P, VOP3Features Features = VOP3_REGULAR,
class VOP3P_Mix_Profile_t16<VOPProfile P, VOP3Features Features = VOP3_REGULAR>
: VOP3P_Mix_Profile<P, Features, 0> {
let IsTrue16 = 1;
- let IsRealTrue16 = 1;
+ let IsRealTrue16 = 1;
let DstRC64 = getVALUDstForVT<P.DstVT, 1 /*IsTrue16*/, 1 /*IsVOP3Encoding*/>.ret;
}
@@ -950,7 +950,7 @@ class MFMA_F8F6F4_WithSizeTable_Helper<VOP3_Pseudo ps, string F8F8Op> :
}
// Currently assumes scaled instructions never have abid
-class MAIFrag<SDPatternOperator Op, code pred, bit HasAbid = true, bit Scaled = false> : PatFrag <
+class MAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false> : PatFrag <
!if(Scaled, (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$blgp,
node:$src0_modifiers, node:$scale_src0,
node:$src1_modifiers, node:$scale_src1),
@@ -959,37 +959,30 @@ class MAIFrag<SDPatternOperator Op, code pred, bit HasAbid = true, bit Scaled =
(ops node:$blgp))),
!if(Scaled, (Op $src0, $src1, $src2, $cbsz, $blgp, $src0_modifiers, $scale_src0, $src1_modifiers, $scale_src1),
!if(HasAbid, (Op $src0, $src1, $src2, $cbsz, $abid, $blgp),
- (Op $src0, $src1, $src2, $cbsz, $blgp))),
- pred
->;
-
-defvar MayNeedAGPRs = [{
- return MF->getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
-}];
-
-defvar MayNeedAGPRs_gisel = [{
- return MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
-}];
+ (Op $src0, $src1, $src2, $cbsz, $blgp)))>;
-defvar MayNotNeedAGPRs = [{
- return !MF->getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
-}];
+class CanUseAGPR_MAI<ValueType vt> {
+ code PredicateCode = [{
+ return !Subtarget->hasGFX90AInsts() ||
+ (!SIMachineFunctionInfo::MFMAVGPRForm &&
+ MF->getInfo<SIMachineFunctionInfo>()->getMinNumAGPRs() >=
+ }] # !srl(vt.Size, 5) # ");";
-defvar MayNotNeedAGPRs_gisel = [{
- return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
-}];
+ code GISelPredicateCode = [{
+ return !Subtarget->hasGFX90AInsts() ||
+ (!SIMachineFunctionInfo::MFMAVGPRForm &&
+ MF.getInfo<SIMachineFunctionInfo>()->getMinNumAGPRs() >=
+ }] # !srl(vt.Size, 5) # ");";
+}
-class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
+class AgprMAIFrag<SDPatternOperator Op, ValueType vt, bit HasAbid = true,
bit Scaled = false> :
- MAIFrag<Op, MayNeedAGPRs, HasAbid, Scaled> {
- let GISelPredicateCode = MayNeedAGPRs_gisel;
-}
+ MAIFrag<Op, HasAbid, Scaled>,
+ CanUseAGPR_MAI<vt>;
class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
- bit Scaled = false> :
- MAIFrag<Op, MayNotNeedAGPRs, HasAbid, Scaled> {
- let GISelPredicateCode = MayNotNeedAGPRs_gisel;
-}
+ bit Scaled = false> :
+ MAIFrag<Op, HasAbid, Scaled>;
let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
defm V_ACCVGPR_READ_B32 : VOP3Inst<"v_accvgpr_read_b32", VOPProfileAccRead>;
@@ -1037,16 +1030,19 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
bit HasAbid = true,
bit Scaled = false> {
defvar NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap;
+ defvar ProfileAGPR = !cast<VOPProfileMAI>("VOPProfileMAI_" # P);
+ defvar ProfileVGPR = !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD");
+
let isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1 in {
// FP32 denorm mode is respected, rounding mode is not. Exceptions are not supported.
let Constraints = !if(NoDstOverlap, "@earlyclobber $vdst", "") in {
- def _e64 : MAIInst<OpName, !cast<VOPProfileMAI>("VOPProfileMAI_" # P),
- !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
+ def _e64 : MAIInst<OpName, ProfileAGPR,
+ !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, ProfileAGPR.DstVT, HasAbid, Scaled>), Scaled>,
MFMATable<0, "AGPR", NAME # "_e64">;
let OtherPredicates = [isGFX90APlus], Mnemonic = OpName in
- def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
+ def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", ProfileVGPR,
!if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<0, "VGPR", NAME # "_vgprcd_e64", NAME # "_e64">;
}
@@ -1055,12 +1051,12 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
let Constraints = !if(NoDstOverlap, "$vdst = $src2", ""),
isConvertibleToThreeAddress = NoDstOverlap,
Mnemonic = OpName in {
- def "_mac_e64" : MAIInst<OpName # "_mac", !cast<VOPProfileMAI>("VOPProfileMAI_" # P),
- !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
+ def "_mac_e64" : MAIInst<OpName # "_mac", ProfileAGPR,
+ !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, ProfileAGPR.DstVT, HasAbid, Scaled>), Scaled>,
MFMATable<1, "AGPR", NAME # "_e64", NAME # "_mac_e64">;
let OtherPredicates = [isGFX90APlus] in
- def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
+ def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", ProfileVGPR,
!if(!eq(node, null_frag), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<1, "VGPR", NAME # "_vgprcd_e64", NAME # "_mac_e64">;
}
@@ -1074,11 +1070,11 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
defvar UnscaledOpName = UnscaledOpName_#VariantSuffix;
defvar HasAbid = false;
-
- defvar NoDstOverlap = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl).NoDstOverlap;
+ defvar Profile = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl);
+ defvar NoDstOverlap = Profile.NoDstOverlap;
def _e64 : ScaledMAIInst<OpName,
- !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, HasAbid, true>)>,
+ !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, Profile.DstVT, HasAbid, true>)>,
MFMATable<0, "AGPR", NAME # "_e64">;
def _vgprcd_e64 : ScaledMAIInst<OpName # "_vgprcd",
@@ -1090,7 +1086,7 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
isConvertibleToThreeAddress = NoDstOverlap,
Mnemonic = UnscaledOpName_ in {
def _mac_e64 : ScaledMAIInst<OpName # "_mac",
- !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, HasAbid, true>>,
+ !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, Profile.DstVT, HasAbid, true>>,
MFMATable<1, "AGPR", NAME # "_e64">;
def _mac_vgprcd_e64 : ScaledMAIInst<OpName # " _mac_vgprcd",
diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td
index 631f0f3..8325c62 100644
--- a/llvm/lib/Target/AMDGPU/VOPInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td
@@ -419,6 +419,13 @@ class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op,
let Inst{14-11} = scale_sel;
}
+class VOP3Interp_OpSel_gfx9<bits<10> op, VOPProfile p> : VOP3Interp_vi<op, p> {
+ let Inst{11} = src0_modifiers{2};
+ // There's no src1
+ let Inst{13} = src2_modifiers{2};
+ let Inst{14} = !if(p.HasDst, src0_modifiers{3}, 0);
+}
+
class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> {
bits<6> attr;
bits<2> attrchan;
diff --git a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
index 1fc475d..561a9c5 100644
--- a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
+++ b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
@@ -349,32 +349,30 @@ public:
bool isImm() const override {
return Kind == Immediate || Kind == Expression;
}
- bool isU1Imm() const { return Kind == Immediate && isUInt<1>(getImm()); }
- bool isU2Imm() const { return Kind == Immediate && isUInt<2>(getImm()); }
- bool isU3Imm() const { return Kind == Immediate && isUInt<3>(getImm()); }
- bool isU4Imm() const { return Kind == Immediate && isUInt<4>(getImm()); }
- bool isU5Imm() const { return Kind == Immediate && isUInt<5>(getImm()); }
- bool isS5Imm() const { return Kind == Immediate && isInt<5>(getImm()); }
- bool isU6Imm() const { return Kind == Immediate && isUInt<6>(getImm()); }
- bool isU6ImmX2() const { return Kind == Immediate &&
- isUInt<6>(getImm()) &&
- (getImm() & 1) == 0; }
- bool isU7Imm() const { return Kind == Immediate && isUInt<7>(getImm()); }
- bool isU7ImmX4() const { return Kind == Immediate &&
- isUInt<7>(getImm()) &&
- (getImm() & 3) == 0; }
- bool isU8Imm() const { return Kind == Immediate && isUInt<8>(getImm()); }
- bool isU8ImmX8() const { return Kind == Immediate &&
- isUInt<8>(getImm()) &&
- (getImm() & 7) == 0; }
-
- bool isU10Imm() const { return Kind == Immediate && isUInt<10>(getImm()); }
- bool isU12Imm() const { return Kind == Immediate && isUInt<12>(getImm()); }
+
+ template <uint64_t N> bool isUImm() const {
+ return Kind == Immediate && isUInt<N>(getImm());
+ }
+ template <uint64_t N> bool isSImm() const {
+ return Kind == Immediate && isInt<N>(getImm());
+ }
+ bool isU6ImmX2() const { return isUImm<6>() && (getImm() & 1) == 0; }
+ bool isU7ImmX4() const { return isUImm<7>() && (getImm() & 3) == 0; }
+ bool isU8ImmX8() const { return isUImm<8>() && (getImm() & 7) == 0; }
+
bool isU16Imm() const { return isExtImm<16>(/*Signed*/ false, 1); }
bool isS16Imm() const { return isExtImm<16>(/*Signed*/ true, 1); }
bool isS16ImmX4() const { return isExtImm<16>(/*Signed*/ true, 4); }
bool isS16ImmX16() const { return isExtImm<16>(/*Signed*/ true, 16); }
bool isS17Imm() const { return isExtImm<17>(/*Signed*/ true, 1); }
+ bool isS34Imm() const {
+ // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit
+ // ContextImmediate is needed.
+ return Kind == Expression || isSImm<34>();
+ }
+ bool isS34ImmX16() const {
+ return Kind == Expression || (isSImm<34>() && (getImm() & 15) == 0);
+ }
bool isHashImmX8() const {
// The Hash Imm form is used for instructions that check or store a hash.
@@ -384,16 +382,6 @@ public:
(getImm() & 7) == 0);
}
- bool isS34ImmX16() const {
- return Kind == Expression ||
- (Kind == Immediate && isInt<34>(getImm()) && (getImm() & 15) == 0);
- }
- bool isS34Imm() const {
- // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit
- // ContextImmediate is needed.
- return Kind == Expression || (Kind == Immediate && isInt<34>(getImm()));
- }
-
bool isTLSReg() const { return Kind == TLSRegister; }
bool isDirectBr() const {
if (Kind == Expression)
@@ -1637,7 +1625,7 @@ bool PPCAsmParser::parseInstruction(ParseInstructionInfo &Info, StringRef Name,
if (Operands.size() != 5)
return false;
PPCOperand &EHOp = (PPCOperand &)*Operands[4];
- if (EHOp.isU1Imm() && EHOp.getImm() == 0)
+ if (EHOp.isUImm<1>() && EHOp.getImm() == 0)
Operands.pop_back();
}
@@ -1817,7 +1805,7 @@ unsigned PPCAsmParser::validateTargetOperandClass(MCParsedAsmOperand &AsmOp,
}
PPCOperand &Op = static_cast<PPCOperand &>(AsmOp);
- if (Op.isU3Imm() && Op.getImm() == ImmVal)
+ if (Op.isUImm<3>() && Op.getImm() == ImmVal)
return Match_Success;
return Match_InvalidOperand;
diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
index 48c31c9..81d8e94 100644
--- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
+++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
@@ -206,45 +206,24 @@ PPCMCCodeEmitter::getVSRpEvenEncoding(const MCInst &MI, unsigned OpNo,
return RegBits;
}
-unsigned PPCMCCodeEmitter::getImm16Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- const MCOperand &MO = MI.getOperand(OpNo);
- if (MO.isReg() || MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI);
-
- // Add a fixup for the immediate field.
- addFixup(Fixups, IsLittleEndian ? 0 : 2, MO.getExpr(), PPC::fixup_ppc_half16);
- return 0;
-}
-
-uint64_t PPCMCCodeEmitter::getImm34Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI,
- MCFixupKind Fixup) const {
+template <MCFixupKind Fixup>
+uint64_t PPCMCCodeEmitter::getImmEncoding(const MCInst &MI, unsigned OpNo,
+ SmallVectorImpl<MCFixup> &Fixups,
+ const MCSubtargetInfo &STI) const {
const MCOperand &MO = MI.getOperand(OpNo);
assert(!MO.isReg() && "Not expecting a register for this operand.");
if (MO.isImm())
return getMachineOpValue(MI, MO, Fixups, STI);
+ uint32_t Offset = 0;
+ if (Fixup == PPC::fixup_ppc_half16)
+ Offset = IsLittleEndian ? 0 : 2;
+
// Add a fixup for the immediate field.
- addFixup(Fixups, 0, MO.getExpr(), Fixup);
+ addFixup(Fixups, Offset, MO.getExpr(), Fixup);
return 0;
}
-uint64_t
-PPCMCCodeEmitter::getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_imm34);
-}
-
-uint64_t
-PPCMCCodeEmitter::getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_pcrel34);
-}
-
unsigned PPCMCCodeEmitter::getDispRIEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const {
diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
index b574557..3356513 100644
--- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
+++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
@@ -47,19 +47,10 @@ public:
unsigned getAbsCondBrEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const;
- unsigned getImm16Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
- uint64_t getImm34Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI,
- MCFixupKind Fixup) const;
- uint64_t getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
- uint64_t getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
+ template <MCFixupKind Fixup>
+ uint64_t getImmEncoding(const MCInst &MI, unsigned OpNo,
+ SmallVectorImpl<MCFixup> &Fixups,
+ const MCSubtargetInfo &STI) const;
unsigned getDispRIEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const;
diff --git a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
index 60efa4c..fdca5ebc 100644
--- a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
+++ b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
@@ -14,30 +14,6 @@
//===----------------------------------------------------------------------===//
// 64-bit operands.
//
-def s16imm64 : Operand<i64> {
- let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCS16ImmAsmOperand;
- let DecoderMethod = "decodeSImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
-def u16imm64 : Operand<i64> {
- let PrintMethod = "printU16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCU16ImmAsmOperand;
- let DecoderMethod = "decodeUImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
-def s17imm64 : Operand<i64> {
- // This operand type is used for addis/lis to allow the assembler parser
- // to accept immediates in the range -65536..65535 for compatibility with
- // the GNU assembler. The operand is treated as 16-bit otherwise.
- let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCS17ImmAsmOperand;
- let DecoderMethod = "decodeSImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
def tocentry : Operand<iPTR> {
let MIOperandInfo = (ops i64imm:$imm);
}
diff --git a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
index c616db4..23d6d88 100644
--- a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
+++ b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
@@ -30,6 +30,11 @@
// Altivec transformation functions and pattern fragments.
//
+// fneg is not legal, and desugared as an xor.
+def desugared_fneg : PatFrag<(ops node:$x), (v4f32 (bitconvert (xor (bitconvert $x),
+ (int_ppc_altivec_vslw (bitconvert (v16i8 immAllOnesV)),
+ (bitconvert (v16i8 immAllOnesV))))))>;
+
def vpkuhum_shuffle : PatFrag<(ops node:$lhs, node:$rhs),
(vector_shuffle node:$lhs, node:$rhs), [{
return PPC::isVPKUHUMShuffleMask(cast<ShuffleVectorSDNode>(N), 0, *CurDAG);
@@ -467,11 +472,12 @@ def VMADDFP : VAForm_1<46, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB),
[(set v4f32:$RT,
(fma v4f32:$RA, v4f32:$RC, v4f32:$RB))]>;
-// FIXME: The fma+fneg pattern won't match because fneg is not legal.
+// fneg is not legal, hence we have to match on the desugared version.
def VNMSUBFP: VAForm_1<47, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB),
"vnmsubfp $RT, $RA, $RC, $RB", IIC_VecFP,
- [(set v4f32:$RT, (fneg (fma v4f32:$RA, v4f32:$RC,
- (fneg v4f32:$RB))))]>;
+ [(set v4f32:$RT, (desugared_fneg (fma v4f32:$RA, v4f32:$RC,
+ (desugared_fneg v4f32:$RB))))]>;
+
let hasSideEffects = 1 in {
def VMHADDSHS : VA1a_Int_Ty<32, "vmhaddshs", int_ppc_altivec_vmhaddshs, v8i16>;
def VMHRADDSHS : VA1a_Int_Ty<33, "vmhraddshs", int_ppc_altivec_vmhraddshs,
@@ -892,6 +898,13 @@ def : Pat<(mul v8i16:$vA, v8i16:$vB), (VMLADDUHM $vA, $vB, (v8i16(V_SET0H)))>;
// Add
def : Pat<(add (mul v8i16:$vA, v8i16:$vB), v8i16:$vC), (VMLADDUHM $vA, $vB, $vC)>;
+
+// Fused negated multiply-subtract
+def : Pat<(v4f32 (desugared_fneg
+ (int_ppc_altivec_vmaddfp v4f32:$RA, v4f32:$RC,
+ (desugared_fneg v4f32:$RB)))),
+ (VNMSUBFP $RA, $RC, $RB)>;
+
// Saturating adds/subtracts.
def : Pat<(v16i8 (saddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDSBS $vA, $vB))>;
def : Pat<(v16i8 (uaddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDUBS $vA, $vB))>;
diff --git a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
index 6d8c122..65d0484 100644
--- a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
+++ b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
@@ -615,7 +615,8 @@ def spe4rc : RegisterOperand<GPRC> {
}
def PPCU1ImmAsmOperand : AsmOperandClass {
- let Name = "U1Imm"; let PredicateMethod = "isU1Imm";
+ let Name = "U1Imm";
+ let PredicateMethod = "isUImm<1>";
let RenderMethod = "addImmOperands";
}
def u1imm : Operand<i32> {
@@ -626,7 +627,8 @@ def u1imm : Operand<i32> {
}
def PPCU2ImmAsmOperand : AsmOperandClass {
- let Name = "U2Imm"; let PredicateMethod = "isU2Imm";
+ let Name = "U2Imm";
+ let PredicateMethod = "isUImm<2>";
let RenderMethod = "addImmOperands";
}
def u2imm : Operand<i32> {
@@ -647,7 +649,8 @@ def atimm : Operand<i32> {
}
def PPCU3ImmAsmOperand : AsmOperandClass {
- let Name = "U3Imm"; let PredicateMethod = "isU3Imm";
+ let Name = "U3Imm";
+ let PredicateMethod = "isUImm<3>";
let RenderMethod = "addImmOperands";
}
def u3imm : Operand<i32> {
@@ -658,7 +661,8 @@ def u3imm : Operand<i32> {
}
def PPCU4ImmAsmOperand : AsmOperandClass {
- let Name = "U4Imm"; let PredicateMethod = "isU4Imm";
+ let Name = "U4Imm";
+ let PredicateMethod = "isUImm<4>";
let RenderMethod = "addImmOperands";
}
def u4imm : Operand<i32> {
@@ -668,7 +672,8 @@ def u4imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCS5ImmAsmOperand : AsmOperandClass {
- let Name = "S5Imm"; let PredicateMethod = "isS5Imm";
+ let Name = "S5Imm";
+ let PredicateMethod = "isSImm<5>";
let RenderMethod = "addImmOperands";
}
def s5imm : Operand<i32> {
@@ -678,7 +683,8 @@ def s5imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU5ImmAsmOperand : AsmOperandClass {
- let Name = "U5Imm"; let PredicateMethod = "isU5Imm";
+ let Name = "U5Imm";
+ let PredicateMethod = "isUImm<5>";
let RenderMethod = "addImmOperands";
}
def u5imm : Operand<i32> {
@@ -688,7 +694,8 @@ def u5imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU6ImmAsmOperand : AsmOperandClass {
- let Name = "U6Imm"; let PredicateMethod = "isU6Imm";
+ let Name = "U6Imm";
+ let PredicateMethod = "isUImm<6>";
let RenderMethod = "addImmOperands";
}
def u6imm : Operand<i32> {
@@ -698,7 +705,8 @@ def u6imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU7ImmAsmOperand : AsmOperandClass {
- let Name = "U7Imm"; let PredicateMethod = "isU7Imm";
+ let Name = "U7Imm";
+ let PredicateMethod = "isUImm<7>";
let RenderMethod = "addImmOperands";
}
def u7imm : Operand<i32> {
@@ -708,7 +716,8 @@ def u7imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU8ImmAsmOperand : AsmOperandClass {
- let Name = "U8Imm"; let PredicateMethod = "isU8Imm";
+ let Name = "U8Imm";
+ let PredicateMethod = "isUImm<8>";
let RenderMethod = "addImmOperands";
}
def u8imm : Operand<i32> {
@@ -718,7 +727,8 @@ def u8imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU10ImmAsmOperand : AsmOperandClass {
- let Name = "U10Imm"; let PredicateMethod = "isU10Imm";
+ let Name = "U10Imm";
+ let PredicateMethod = "isUImm<10>";
let RenderMethod = "addImmOperands";
}
def u10imm : Operand<i32> {
@@ -728,7 +738,8 @@ def u10imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU12ImmAsmOperand : AsmOperandClass {
- let Name = "U12Imm"; let PredicateMethod = "isU12Imm";
+ let Name = "U12Imm";
+ let PredicateMethod = "isUImm<12>";
let RenderMethod = "addImmOperands";
}
def u12imm : Operand<i32> {
@@ -743,7 +754,14 @@ def PPCS16ImmAsmOperand : AsmOperandClass {
}
def s16imm : Operand<i32> {
let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCS16ImmAsmOperand;
+ let DecoderMethod = "decodeSImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def s16imm64 : Operand<i64> {
+ let PrintMethod = "printS16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCS16ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -754,7 +772,14 @@ def PPCU16ImmAsmOperand : AsmOperandClass {
}
def u16imm : Operand<i32> {
let PrintMethod = "printU16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCU16ImmAsmOperand;
+ let DecoderMethod = "decodeUImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def u16imm64 : Operand<i64> {
+ let PrintMethod = "printU16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCU16ImmAsmOperand;
let DecoderMethod = "decodeUImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -768,7 +793,17 @@ def s17imm : Operand<i32> {
// to accept immediates in the range -65536..65535 for compatibility with
// the GNU assembler. The operand is treated as 16-bit otherwise.
let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCS17ImmAsmOperand;
+ let DecoderMethod = "decodeSImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def s17imm64 : Operand<i64> {
+ // This operand type is used for addis/lis to allow the assembler parser
+ // to accept immediates in the range -65536..65535 for compatibility with
+ // the GNU assembler. The operand is treated as 16-bit otherwise.
+ let PrintMethod = "printS16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCS17ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -780,14 +815,14 @@ def PPCS34ImmAsmOperand : AsmOperandClass {
}
def s34imm : Operand<i64> {
let PrintMethod = "printS34ImmOperand";
- let EncoderMethod = "getImm34EncodingNoPCRel";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_imm34>";
let ParserMatchClass = PPCS34ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<34>";
let OperandType = "OPERAND_IMMEDIATE";
}
def s34imm_pcrel : Operand<i64> {
let PrintMethod = "printS34ImmOperand";
- let EncoderMethod = "getImm34EncodingPCRel";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_pcrel34>";
let ParserMatchClass = PPCS34ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<34>";
let OperandType = "OPERAND_IMMEDIATE";
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 34026ed..ecfb5fe 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -439,18 +439,6 @@ bool RISCVCallLowering::canLowerReturn(MachineFunction &MF,
CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs,
MF.getFunction().getContext());
- const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
-
- std::optional<unsigned> FirstMaskArgument = std::nullopt;
- // Preassign the first mask argument.
- if (Subtarget.hasVInstructions()) {
- for (const auto &ArgIdx : enumerate(Outs)) {
- MVT ArgVT = MVT::getVT(ArgIdx.value().Ty);
- if (ArgVT.isVector() && ArgVT.getVectorElementType() == MVT::i1)
- FirstMaskArgument = ArgIdx.index();
- }
- }
-
for (unsigned I = 0, E = Outs.size(); I < E; ++I) {
MVT VT = MVT::getVT(Outs[I].Ty);
if (CC_RISCV(I, VT, VT, CCValAssign::Full, Outs[I].Flags[0], CCInfo,
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index 597dd12..9f9ae2f 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -324,6 +324,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
OpdsMapping[0] = GPRValueMapping;
+ // Atomics always use GPR destinations. Don't refine any further.
+ if (cast<GLoad>(MI).isAtomic())
+ break;
+
// Use FPR64 for s64 loads on rv32.
if (GPRSize == 32 && Size.getFixedValue() == 64) {
assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
@@ -358,6 +362,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
OpdsMapping[0] = GPRValueMapping;
+ // Atomics always use GPR sources. Don't refine any further.
+ if (cast<GStore>(MI).isAtomic())
+ break;
+
// Use FPR64 for s64 stores on rv32.
if (GPRSize == 32 && Size.getFixedValue() == 64) {
assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index a02de31..27cf057 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -1421,7 +1421,7 @@ def HasVendorXMIPSCMov
: Predicate<"Subtarget->hasVendorXMIPSCMov()">,
AssemblerPredicate<(all_of FeatureVendorXMIPSCMov),
"'Xmipscmov' ('mips.ccmov' instruction)">;
-def UseCCMovInsn : Predicate<"Subtarget->useCCMovInsn()">;
+def UseMIPSCCMovInsn : Predicate<"Subtarget->useMIPSCCMovInsn()">;
def FeatureVendorXMIPSLSP
: RISCVExtension<1, 0, "MIPS optimization for hardware load-store bonding">;
diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 7f5d0af..6d01250 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -190,3 +190,29 @@ let Predicates = [HasStdExtZbkb, NoStdExtZbb, IsRV64] in {
def : Pat<(i64 (zext (i16 GPR:$rs))), (PACKW GPR:$rs, (XLenVT X0))>;
def : Pat<(i32 (zext (i16 GPR:$rs))), (PACKW GPR:$rs, (XLenVT X0))>;
}
+
+//===----------------------------------------------------------------------===//
+// Zalasr patterns not used by SelectionDAG
+//===----------------------------------------------------------------------===//
+
+let Predicates = [HasStdExtZalasr] in {
+ // the sequentially consistent loads use
+ // .aq instead of .aqrl to match the psABI/A.7
+ def : PatLAQ<acquiring_load<atomic_load_aext_8>, LB_AQ, i16>;
+ def : PatLAQ<seq_cst_load<atomic_load_aext_8>, LB_AQ, i16>;
+
+ def : PatLAQ<acquiring_load<atomic_load_nonext_16>, LH_AQ, i16>;
+ def : PatLAQ<seq_cst_load<atomic_load_nonext_16>, LH_AQ, i16>;
+
+ def : PatSRL<releasing_store<atomic_store_8>, SB_RL, i16>;
+ def : PatSRL<seq_cst_store<atomic_store_8>, SB_RL, i16>;
+
+ def : PatSRL<releasing_store<atomic_store_16>, SH_RL, i16>;
+ def : PatSRL<seq_cst_store<atomic_store_16>, SH_RL, i16>;
+}
+
+let Predicates = [HasStdExtZalasr, IsRV64] in {
+ // Load pattern is in RISCVInstrInfoZalasr.td and shared with RV32.
+ def : PatSRL<releasing_store<atomic_store_32>, SW_RL, i32>;
+ def : PatSRL<seq_cst_store<atomic_store_32>, SW_RL, i32>;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index dcce2d2..a3a4cf2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -434,7 +434,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::ABS, MVT::i32, Custom);
}
- if (!Subtarget.useCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov())
+ if (!Subtarget.useMIPSCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov())
setOperationAction(ISD::SELECT, XLenVT, Custom);
if (Subtarget.hasVendorXqcia() && !Subtarget.is64Bit()) {
@@ -16498,43 +16498,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
SDValue X = N->getOperand(0);
if (Subtarget.hasShlAdd(3)) {
- for (uint64_t Divisor : {3, 5, 9}) {
- if (MulAmt % Divisor != 0)
- continue;
- uint64_t MulAmt2 = MulAmt / Divisor;
- // 3/5/9 * 2^N -> shl (shXadd X, X), N
- if (isPowerOf2_64(MulAmt2)) {
- SDLoc DL(N);
- SDValue X = N->getOperand(0);
- // Put the shift first if we can fold a zext into the
- // shift forming a slli.uw.
- if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
- SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
- Shl);
- }
- // Otherwise, put rhe shl second so that it can fold with following
- // instructions (e.g. sext or add).
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
- }
-
- // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
- if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
- SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
- Mul359);
+ int Shift;
+ if (int ShXAmount = isShifted359(MulAmt, Shift)) {
+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
+ SDLoc DL(N);
+ SDValue X = N->getOperand(0);
+ // Put the shift first if we can fold a zext into the shift forming
+ // a slli.uw.
+ if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
+ SDValue Shl =
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
+ DAG.getConstant(ShXAmount, DL, VT), Shl);
}
+ // Otherwise, put the shl second so that it can fold with following
+ // instructions (e.g. sext or add).
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
+ DAG.getConstant(Shift, DL, VT));
+ }
+
+ // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+ int ShX;
+ int ShY;
+ switch (MulAmt) {
+ case 3 * 5:
+ ShY = 1;
+ ShX = 2;
+ break;
+ case 3 * 9:
+ ShY = 1;
+ ShX = 3;
+ break;
+ case 5 * 5:
+ ShX = ShY = 2;
+ break;
+ case 5 * 9:
+ ShY = 2;
+ ShX = 3;
+ break;
+ case 9 * 9:
+ ShX = ShY = 3;
+ break;
+ default:
+ ShX = ShY = 0;
+ break;
+ }
+ if (ShX) {
+ SDLoc DL(N);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShY, DL, VT), X);
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+ DAG.getConstant(ShX, DL, VT), Mul359);
}
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
@@ -16557,18 +16574,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// variants we could implement. e.g.
// (2^(1,2,3) * 3,5,9 + 1) << C2
// 2^(C1>3) * 3,5,9 +/- 1
- for (uint64_t Divisor : {3, 5, 9}) {
- uint64_t C = MulAmt - 1;
- if (C <= Divisor)
- continue;
- unsigned TZ = llvm::countr_zero(C);
- if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+ if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
+ assert(Shift != 0 && "MulAmt=4,6,10 handled before");
+ if (Shift <= 3) {
SDLoc DL(N);
- SDValue Mul359 =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(ShXAmount, DL, VT), X);
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
- DAG.getConstant(TZ, DL, VT), X);
+ DAG.getConstant(Shift, DL, VT), X);
}
}
@@ -16576,7 +16589,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
if (ScaleShift >= 1 && ScaleShift < 4) {
- unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
+ unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
SDLoc DL(N);
SDValue Shift1 =
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
@@ -16589,7 +16602,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
- unsigned ShAmt = Log2_64(MulAmt + Offset);
+ unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
if (ShAmt >= VT.getSizeInBits())
continue;
SDLoc DL(N);
@@ -16608,21 +16621,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt2 = MulAmt / Divisor;
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
// of 25 which happen to be quite common.
- for (uint64_t Divisor2 : {3, 5, 9}) {
- if (MulAmt2 % Divisor2 != 0)
- continue;
- uint64_t MulAmt3 = MulAmt2 / Divisor2;
- if (isPowerOf2_64(MulAmt3)) {
- SDLoc DL(N);
- SDValue Mul359A =
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
- SDValue Mul359B = DAG.getNode(
- RISCVISD::SHL_ADD, DL, VT, Mul359A,
- DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
- return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
- DAG.getConstant(Log2_64(MulAmt3), DL, VT));
- }
+ if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
+ SDLoc DL(N);
+ SDValue Mul359A =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ SDValue Mul359B =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
+ DAG.getConstant(ShBAmount, DL, VT), Mul359A);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
+ DAG.getConstant(Shift, DL, VT));
}
}
}
@@ -25031,8 +25039,17 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
// Mark RVV intrinsic as supported.
- if (RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(II->getIntrinsicID()))
+ if (RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(II->getIntrinsicID())) {
+ // GISel doesn't support tuple types yet.
+ if (Inst.getType()->isRISCVVectorTupleTy())
+ return true;
+
+ for (unsigned i = 0; i < II->arg_size(); ++i)
+ if (II->getArgOperand(i)->getType()->isRISCVVectorTupleTy())
+ return true;
+
return false;
+ }
}
if (Inst.getType()->isScalableTy())
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 7db4832..96e1078 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4586,24 +4586,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
.addReg(DestReg, RegState::Kill)
.addImm(ShiftAmount)
.setMIFlag(Flag);
- } else if (STI.hasShlAdd(3) &&
- ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
- (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
- (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
+ } else if (int ShXAmount, ShiftAmount;
+ STI.hasShlAdd(3) &&
+ (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
unsigned Opc;
- uint32_t ShiftAmount;
- if (Amount % 9 == 0) {
- Opc = RISCV::SH3ADD;
- ShiftAmount = Log2_64(Amount / 9);
- } else if (Amount % 5 == 0) {
- Opc = RISCV::SH2ADD;
- ShiftAmount = Log2_64(Amount / 5);
- } else if (Amount % 3 == 0) {
+ switch (ShXAmount) {
+ case 1:
Opc = RISCV::SH1ADD;
- ShiftAmount = Log2_64(Amount / 3);
- } else {
- llvm_unreachable("implied by if-clause");
+ break;
+ case 2:
+ Opc = RISCV::SH2ADD;
+ break;
+ case 3:
+ Opc = RISCV::SH3ADD;
+ break;
+ default:
+ llvm_unreachable("unexpected result of isShifted359");
}
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 42a0c4c..c5eddb9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,25 @@
namespace llvm {
+// If Value is of the form C1<<C2, where C1 = 3, 5 or 9,
+// returns log2(C1 - 1) and assigns Shift = C2.
+// Otherwise, returns 0.
+template <typename T> int isShifted359(T Value, int &Shift) {
+ if (Value == 0)
+ return 0;
+ Shift = llvm::countr_zero(Value);
+ switch (Value >> Shift) {
+ case 3:
+ return 1;
+ case 5:
+ return 2;
+ case 9:
+ return 3;
+ default:
+ return 0;
+ }
+}
+
class RISCVSubtarget;
static const MachineMemOperand::Flags MONontemporalBit0 =
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
index 115ab38e..0b5bee1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
@@ -175,7 +175,7 @@ def MIPS_CCMOV : RVInstR4<0b11, 0b011, OPC_CUSTOM_0, (outs GPR:$rd),
Sched<[]>;
}
-let Predicates = [UseCCMovInsn] in {
+let Predicates = [UseMIPSCCMovInsn] in {
def : Pat<(select (riscv_setne (XLenVT GPR:$rs2)),
(XLenVT GPR:$rs1), (XLenVT GPR:$rs3)),
(MIPS_CCMOV GPR:$rs1, GPR:$rs2, GPR:$rs3)>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td
index 1dd7332..1deecd2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td
@@ -93,12 +93,11 @@ let Predicates = [HasStdExtZalasr] in {
def : PatSRL<releasing_store<atomic_store_32>, SW_RL>;
def : PatSRL<seq_cst_store<atomic_store_32>, SW_RL>;
-} // Predicates = [HasStdExtZalasr]
-let Predicates = [HasStdExtZalasr, IsRV32] in {
- def : PatLAQ<acquiring_load<atomic_load_nonext_32>, LW_AQ>;
- def : PatLAQ<seq_cst_load<atomic_load_nonext_32>, LW_AQ>;
-} // Predicates = [HasStdExtZalasr, IsRV32]
+ // Used by GISel for RV32 and RV64.
+ def : PatLAQ<acquiring_load<atomic_load_nonext_32>, LW_AQ, i32>;
+ def : PatLAQ<seq_cst_load<atomic_load_nonext_32>, LW_AQ, i32>;
+} // Predicates = [HasStdExtZalasr]
let Predicates = [HasStdExtZalasr, IsRV64] in {
def : PatLAQ<acquiring_load<atomic_load_asext_32>, LW_AQ, i64>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index ce21d83..8d9b777 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -808,9 +808,9 @@ multiclass Sh2Add_UWPat<Instruction sh2add_uw> {
}
multiclass Sh3Add_UWPat<Instruction sh3add_uw> {
- def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0xFFFFFFF8),
+ def : Pat<(i64 (add_like_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF),
(XLenVT GPR:$rs2))),
- (sh3add_uw (XLenVT (SRLIW GPR:$rs1, 3)), GPR:$rs2)>;
+ (sh3add_uw GPR:$rs1, GPR:$rs2)>;
// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8),
(XLenVT GPR:$rs2))),
diff --git a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
index c81a20b..115a96e 100644
--- a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
@@ -92,7 +92,7 @@ bool RISCVLoadStoreOpt::runOnMachineFunction(MachineFunction &Fn) {
if (skipFunction(Fn.getFunction()))
return false;
const RISCVSubtarget &Subtarget = Fn.getSubtarget<RISCVSubtarget>();
- if (!Subtarget.useLoadStorePairs())
+ if (!Subtarget.useMIPSLoadStorePairs())
return false;
bool MadeChange = false;
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index e35ffaf..715ac4c 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -65,9 +65,9 @@ static cl::opt<bool> UseMIPSLoadStorePairsOpt(
cl::desc("Enable the load/store pair optimization pass"), cl::init(false),
cl::Hidden);
-static cl::opt<bool> UseCCMovInsn("use-riscv-ccmov",
- cl::desc("Use 'mips.ccmov' instruction"),
- cl::init(true), cl::Hidden);
+static cl::opt<bool> UseMIPSCCMovInsn("use-riscv-mips-ccmov",
+ cl::desc("Use 'mips.ccmov' instruction"),
+ cl::init(true), cl::Hidden);
void RISCVSubtarget::anchor() {}
@@ -246,10 +246,10 @@ void RISCVSubtarget::overridePostRASchedPolicy(
}
}
-bool RISCVSubtarget::useLoadStorePairs() const {
+bool RISCVSubtarget::useMIPSLoadStorePairs() const {
return UseMIPSLoadStorePairsOpt && HasVendorXMIPSLSP;
}
-bool RISCVSubtarget::useCCMovInsn() const {
- return UseCCMovInsn && HasVendorXMIPSCMov;
+bool RISCVSubtarget::useMIPSCCMovInsn() const {
+ return UseMIPSCCMovInsn && HasVendorXMIPSCMov;
}
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index 7dffa63..6acf799 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -227,8 +227,8 @@ public:
unsigned getXLen() const {
return is64Bit() ? 64 : 32;
}
- bool useLoadStorePairs() const;
- bool useCCMovInsn() const;
+ bool useMIPSLoadStorePairs() const;
+ bool useMIPSCCMovInsn() const;
unsigned getFLen() const {
if (HasStdExtD)
return 64;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index ee25f69..7bc0b5b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -2747,20 +2747,72 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
Intrinsic::ID IID = Inst->getIntrinsicID();
LLVMContext &C = Inst->getContext();
bool HasMask = false;
+
+ auto getSegNum = [](const IntrinsicInst *II, unsigned PtrOperandNo,
+ bool IsWrite) -> int64_t {
+ if (auto *TarExtTy =
+ dyn_cast<TargetExtType>(II->getArgOperand(0)->getType()))
+ return TarExtTy->getIntParameter(0);
+
+ return 1;
+ };
+
switch (IID) {
case Intrinsic::riscv_vle_mask:
case Intrinsic::riscv_vse_mask:
+ case Intrinsic::riscv_vlseg2_mask:
+ case Intrinsic::riscv_vlseg3_mask:
+ case Intrinsic::riscv_vlseg4_mask:
+ case Intrinsic::riscv_vlseg5_mask:
+ case Intrinsic::riscv_vlseg6_mask:
+ case Intrinsic::riscv_vlseg7_mask:
+ case Intrinsic::riscv_vlseg8_mask:
+ case Intrinsic::riscv_vsseg2_mask:
+ case Intrinsic::riscv_vsseg3_mask:
+ case Intrinsic::riscv_vsseg4_mask:
+ case Intrinsic::riscv_vsseg5_mask:
+ case Intrinsic::riscv_vsseg6_mask:
+ case Intrinsic::riscv_vsseg7_mask:
+ case Intrinsic::riscv_vsseg8_mask:
HasMask = true;
[[fallthrough]];
case Intrinsic::riscv_vle:
- case Intrinsic::riscv_vse: {
+ case Intrinsic::riscv_vse:
+ case Intrinsic::riscv_vlseg2:
+ case Intrinsic::riscv_vlseg3:
+ case Intrinsic::riscv_vlseg4:
+ case Intrinsic::riscv_vlseg5:
+ case Intrinsic::riscv_vlseg6:
+ case Intrinsic::riscv_vlseg7:
+ case Intrinsic::riscv_vlseg8:
+ case Intrinsic::riscv_vsseg2:
+ case Intrinsic::riscv_vsseg3:
+ case Intrinsic::riscv_vsseg4:
+ case Intrinsic::riscv_vsseg5:
+ case Intrinsic::riscv_vsseg6:
+ case Intrinsic::riscv_vsseg7:
+ case Intrinsic::riscv_vsseg8: {
// Intrinsic interface:
// riscv_vle(merge, ptr, vl)
// riscv_vle_mask(merge, ptr, mask, vl, policy)
// riscv_vse(val, ptr, vl)
// riscv_vse_mask(val, ptr, mask, vl, policy)
+ // riscv_vlseg#(merge, ptr, vl, sew)
+ // riscv_vlseg#_mask(merge, ptr, mask, vl, policy, sew)
+ // riscv_vsseg#(val, ptr, vl, sew)
+ // riscv_vsseg#_mask(val, ptr, mask, vl, sew)
bool IsWrite = Inst->getType()->isVoidTy();
Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType();
+ // The results of segment loads are TargetExtType.
+ if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) {
+ unsigned SEW =
+ 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1))
+ ->getZExtValue();
+ Ty = TarExtTy->getTypeParameter(0U);
+ Ty = ScalableVectorType::get(
+ IntegerType::get(C, SEW),
+ cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW);
+ }
const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID);
unsigned VLIndex = RVVIInfo->VLOperand;
unsigned PtrOperandNo = VLIndex - 1 - HasMask;
@@ -2771,23 +2823,72 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
if (HasMask)
Mask = Inst->getArgOperand(VLIndex - 1);
Value *EVL = Inst->getArgOperand(VLIndex);
+ unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite);
+ // RVV uses contiguous elements as a segment.
+ if (SegNum > 1) {
+ unsigned ElemSize = Ty->getScalarSizeInBits();
+ auto *SegTy = IntegerType::get(C, ElemSize * SegNum);
+ Ty = VectorType::get(SegTy, cast<VectorType>(Ty));
+ }
Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty,
Alignment, Mask, EVL);
return true;
}
case Intrinsic::riscv_vlse_mask:
case Intrinsic::riscv_vsse_mask:
+ case Intrinsic::riscv_vlsseg2_mask:
+ case Intrinsic::riscv_vlsseg3_mask:
+ case Intrinsic::riscv_vlsseg4_mask:
+ case Intrinsic::riscv_vlsseg5_mask:
+ case Intrinsic::riscv_vlsseg6_mask:
+ case Intrinsic::riscv_vlsseg7_mask:
+ case Intrinsic::riscv_vlsseg8_mask:
+ case Intrinsic::riscv_vssseg2_mask:
+ case Intrinsic::riscv_vssseg3_mask:
+ case Intrinsic::riscv_vssseg4_mask:
+ case Intrinsic::riscv_vssseg5_mask:
+ case Intrinsic::riscv_vssseg6_mask:
+ case Intrinsic::riscv_vssseg7_mask:
+ case Intrinsic::riscv_vssseg8_mask:
HasMask = true;
[[fallthrough]];
case Intrinsic::riscv_vlse:
- case Intrinsic::riscv_vsse: {
+ case Intrinsic::riscv_vsse:
+ case Intrinsic::riscv_vlsseg2:
+ case Intrinsic::riscv_vlsseg3:
+ case Intrinsic::riscv_vlsseg4:
+ case Intrinsic::riscv_vlsseg5:
+ case Intrinsic::riscv_vlsseg6:
+ case Intrinsic::riscv_vlsseg7:
+ case Intrinsic::riscv_vlsseg8:
+ case Intrinsic::riscv_vssseg2:
+ case Intrinsic::riscv_vssseg3:
+ case Intrinsic::riscv_vssseg4:
+ case Intrinsic::riscv_vssseg5:
+ case Intrinsic::riscv_vssseg6:
+ case Intrinsic::riscv_vssseg7:
+ case Intrinsic::riscv_vssseg8: {
// Intrinsic interface:
// riscv_vlse(merge, ptr, stride, vl)
// riscv_vlse_mask(merge, ptr, stride, mask, vl, policy)
// riscv_vsse(val, ptr, stride, vl)
// riscv_vsse_mask(val, ptr, stride, mask, vl, policy)
+ // riscv_vlsseg#(merge, ptr, offset, vl, sew)
+ // riscv_vlsseg#_mask(merge, ptr, offset, mask, vl, policy, sew)
+ // riscv_vssseg#(val, ptr, offset, vl, sew)
+ // riscv_vssseg#_mask(val, ptr, offset, mask, vl, sew)
bool IsWrite = Inst->getType()->isVoidTy();
Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType();
+ // The results of segment loads are TargetExtType.
+ if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) {
+ unsigned SEW =
+ 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1))
+ ->getZExtValue();
+ Ty = TarExtTy->getTypeParameter(0U);
+ Ty = ScalableVectorType::get(
+ IntegerType::get(C, SEW),
+ cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW);
+ }
const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID);
unsigned VLIndex = RVVIInfo->VLOperand;
unsigned PtrOperandNo = VLIndex - 2 - HasMask;
@@ -2809,6 +2910,13 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
if (HasMask)
Mask = Inst->getArgOperand(VLIndex - 1);
Value *EVL = Inst->getArgOperand(VLIndex);
+ unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite);
+ // RVV uses contiguous elements as a segment.
+ if (SegNum > 1) {
+ unsigned ElemSize = Ty->getScalarSizeInBits();
+ auto *SegTy = IntegerType::get(C, ElemSize * SegNum);
+ Ty = VectorType::get(SegTy, cast<VectorType>(Ty));
+ }
Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty,
Alignment, Mask, EVL, Stride);
return true;
@@ -2817,19 +2925,89 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
case Intrinsic::riscv_vluxei_mask:
case Intrinsic::riscv_vsoxei_mask:
case Intrinsic::riscv_vsuxei_mask:
+ case Intrinsic::riscv_vloxseg2_mask:
+ case Intrinsic::riscv_vloxseg3_mask:
+ case Intrinsic::riscv_vloxseg4_mask:
+ case Intrinsic::riscv_vloxseg5_mask:
+ case Intrinsic::riscv_vloxseg6_mask:
+ case Intrinsic::riscv_vloxseg7_mask:
+ case Intrinsic::riscv_vloxseg8_mask:
+ case Intrinsic::riscv_vluxseg2_mask:
+ case Intrinsic::riscv_vluxseg3_mask:
+ case Intrinsic::riscv_vluxseg4_mask:
+ case Intrinsic::riscv_vluxseg5_mask:
+ case Intrinsic::riscv_vluxseg6_mask:
+ case Intrinsic::riscv_vluxseg7_mask:
+ case Intrinsic::riscv_vluxseg8_mask:
+ case Intrinsic::riscv_vsoxseg2_mask:
+ case Intrinsic::riscv_vsoxseg3_mask:
+ case Intrinsic::riscv_vsoxseg4_mask:
+ case Intrinsic::riscv_vsoxseg5_mask:
+ case Intrinsic::riscv_vsoxseg6_mask:
+ case Intrinsic::riscv_vsoxseg7_mask:
+ case Intrinsic::riscv_vsoxseg8_mask:
+ case Intrinsic::riscv_vsuxseg2_mask:
+ case Intrinsic::riscv_vsuxseg3_mask:
+ case Intrinsic::riscv_vsuxseg4_mask:
+ case Intrinsic::riscv_vsuxseg5_mask:
+ case Intrinsic::riscv_vsuxseg6_mask:
+ case Intrinsic::riscv_vsuxseg7_mask:
+ case Intrinsic::riscv_vsuxseg8_mask:
HasMask = true;
[[fallthrough]];
case Intrinsic::riscv_vloxei:
case Intrinsic::riscv_vluxei:
case Intrinsic::riscv_vsoxei:
- case Intrinsic::riscv_vsuxei: {
+ case Intrinsic::riscv_vsuxei:
+ case Intrinsic::riscv_vloxseg2:
+ case Intrinsic::riscv_vloxseg3:
+ case Intrinsic::riscv_vloxseg4:
+ case Intrinsic::riscv_vloxseg5:
+ case Intrinsic::riscv_vloxseg6:
+ case Intrinsic::riscv_vloxseg7:
+ case Intrinsic::riscv_vloxseg8:
+ case Intrinsic::riscv_vluxseg2:
+ case Intrinsic::riscv_vluxseg3:
+ case Intrinsic::riscv_vluxseg4:
+ case Intrinsic::riscv_vluxseg5:
+ case Intrinsic::riscv_vluxseg6:
+ case Intrinsic::riscv_vluxseg7:
+ case Intrinsic::riscv_vluxseg8:
+ case Intrinsic::riscv_vsoxseg2:
+ case Intrinsic::riscv_vsoxseg3:
+ case Intrinsic::riscv_vsoxseg4:
+ case Intrinsic::riscv_vsoxseg5:
+ case Intrinsic::riscv_vsoxseg6:
+ case Intrinsic::riscv_vsoxseg7:
+ case Intrinsic::riscv_vsoxseg8:
+ case Intrinsic::riscv_vsuxseg2:
+ case Intrinsic::riscv_vsuxseg3:
+ case Intrinsic::riscv_vsuxseg4:
+ case Intrinsic::riscv_vsuxseg5:
+ case Intrinsic::riscv_vsuxseg6:
+ case Intrinsic::riscv_vsuxseg7:
+ case Intrinsic::riscv_vsuxseg8: {
// Intrinsic interface (only listed ordered version):
// riscv_vloxei(merge, ptr, index, vl)
// riscv_vloxei_mask(merge, ptr, index, mask, vl, policy)
// riscv_vsoxei(val, ptr, index, vl)
// riscv_vsoxei_mask(val, ptr, index, mask, vl, policy)
+ // riscv_vloxseg#(merge, ptr, index, vl, sew)
+ // riscv_vloxseg#_mask(merge, ptr, index, mask, vl, policy, sew)
+ // riscv_vsoxseg#(val, ptr, index, vl, sew)
+ // riscv_vsoxseg#_mask(val, ptr, index, mask, vl, sew)
bool IsWrite = Inst->getType()->isVoidTy();
Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType();
+ // The results of segment loads are TargetExtType.
+ if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) {
+ unsigned SEW =
+ 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1))
+ ->getZExtValue();
+ Ty = TarExtTy->getTypeParameter(0U);
+ Ty = ScalableVectorType::get(
+ IntegerType::get(C, SEW),
+ cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW);
+ }
const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID);
unsigned VLIndex = RVVIInfo->VLOperand;
unsigned PtrOperandNo = VLIndex - 2 - HasMask;
@@ -2845,6 +3023,13 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
Mask = ConstantInt::getTrue(MaskType);
}
Value *EVL = Inst->getArgOperand(VLIndex);
+ unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite);
+ // RVV uses contiguous elements as a segment.
+ if (SegNum > 1) {
+ unsigned ElemSize = Ty->getScalarSizeInBits();
+ auto *SegTy = IntegerType::get(C, ElemSize * SegNum);
+ Ty = VectorType::get(SegTy, cast<VectorType>(Ty));
+ }
Value *OffsetOp = Inst->getArgOperand(PtrOperandNo + 1);
Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty,
Align(1), Mask, EVL,
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 9f2e075..e16c8f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -2811,9 +2811,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
GetElementPtrInst *NewGEP = simplifyZeroLengthArrayGepInst(Ref);
if (NewGEP) {
Ref->replaceAllUsesWith(NewGEP);
- if (isInstructionTriviallyDead(Ref))
- DeadInsts.insert(Ref);
-
+ DeadInsts.insert(Ref);
Ref = NewGEP;
}
if (Type *GepTy = getGEPType(Ref))
diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp
index b906690..62a3c88 100644
--- a/llvm/lib/TargetParser/TargetParser.cpp
+++ b/llvm/lib/TargetParser/TargetParser.cpp
@@ -444,7 +444,7 @@ static void fillAMDGCNFeatureMap(StringRef GPU, const Triple &T,
Features["atomic-fmin-fmax-global-f32"] = true;
Features["atomic-fmin-fmax-global-f64"] = true;
Features["wavefrontsize32"] = true;
- Features["cluster"] = true;
+ Features["clusters"] = true;
break;
case GK_GFX1201:
case GK_GFX1200:
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 8d9a0e7..50130da 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -2067,6 +2067,36 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes,
AI.run(SCCNodes, Changed);
}
+// Determines if the function 'F' can be marked 'norecurse'.
+// It returns true if any call within 'F' could lead to a recursive
+// call back to 'F', and false otherwise.
+// The 'AnyFunctionsAddressIsTaken' parameter is a module-wide flag
+// that is true if any function's address is taken, or if any function
+// has external linkage. This is used to determine the safety of
+// external/library calls.
+static bool mayHaveRecursiveCallee(Function &F,
+ bool AnyFunctionsAddressIsTaken = true) {
+ for (const auto &BB : F) {
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ if (const auto *CB = dyn_cast<CallBase>(&I)) {
+ const Function *Callee = CB->getCalledFunction();
+ if (!Callee || Callee == &F)
+ return true;
+
+ if (Callee->doesNotRecurse())
+ continue;
+
+ if (!AnyFunctionsAddressIsTaken ||
+ (Callee->isDeclaration() &&
+ Callee->hasFnAttribute(Attribute::NoCallback)))
+ continue;
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes,
SmallPtrSet<Function *, 8> &Changed) {
// Try and identify functions that do not recurse.
@@ -2078,28 +2108,14 @@ static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes,
Function *F = *SCCNodes.begin();
if (!F || !F->hasExactDefinition() || F->doesNotRecurse())
return;
-
- // If all of the calls in F are identifiable and are to norecurse functions, F
- // is norecurse. This check also detects self-recursion as F is not currently
- // marked norecurse, so any called from F to F will not be marked norecurse.
- for (auto &BB : *F)
- for (auto &I : BB.instructionsWithoutDebug())
- if (auto *CB = dyn_cast<CallBase>(&I)) {
- Function *Callee = CB->getCalledFunction();
- if (!Callee || Callee == F ||
- (!Callee->doesNotRecurse() &&
- !(Callee->isDeclaration() &&
- Callee->hasFnAttribute(Attribute::NoCallback))))
- // Function calls a potentially recursive function.
- return;
- }
-
- // Every call was to a non-recursive function other than this function, and
- // we have no indirect recursion as the SCC size is one. This function cannot
- // recurse.
- F->setDoesNotRecurse();
- ++NumNoRecurse;
- Changed.insert(F);
+ if (!mayHaveRecursiveCallee(*F)) {
+ // Every call was to a non-recursive function other than this function, and
+ // we have no indirect recursion as the SCC size is one. This function
+ // cannot recurse.
+ F->setDoesNotRecurse();
+ ++NumNoRecurse;
+ Changed.insert(F);
+ }
}
// Set the noreturn function attribute if possible.
@@ -2429,3 +2445,62 @@ ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) {
PA.preserve<LazyCallGraphAnalysis>();
return PA;
}
+
+PreservedAnalyses NoRecurseLTOInferencePass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+
+ // Check if any function in the whole program has its address taken or has
+ // potentially external linkage.
+ // We use this information when inferring norecurse attribute: If there is
+ // no function whose address is taken and all functions have internal
+ // linkage, there is no path for a callback to any user function.
+ bool AnyFunctionsAddressIsTaken = false;
+ for (Function &F : M) {
+ if (F.isDeclaration() || F.doesNotRecurse())
+ continue;
+ if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
+ AnyFunctionsAddressIsTaken = true;
+ break;
+ }
+ }
+
+ // Run norecurse inference on all RefSCCs in the LazyCallGraph for this
+ // module.
+ bool Changed = false;
+ LazyCallGraph &CG = MAM.getResult<LazyCallGraphAnalysis>(M);
+ CG.buildRefSCCs();
+
+ for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) {
+ // Skip any RefSCC that is part of a call cycle. A RefSCC containing more
+ // than one SCC indicates a recursive relationship involving indirect calls.
+ if (RC.size() > 1)
+ continue;
+
+ // RefSCC contains a single-SCC. SCC size > 1 indicates mutually recursive
+ // functions. Ex: foo1 -> foo2 -> foo3 -> foo1.
+ LazyCallGraph::SCC &S = *RC.begin();
+ if (S.size() > 1)
+ continue;
+
+ // Get the single function from this SCC.
+ Function &F = S.begin()->getFunction();
+ if (!F.hasExactDefinition() || F.doesNotRecurse())
+ continue;
+
+ // If the analysis confirms that this function has no recursive calls
+ // (either direct, indirect, or through external linkages),
+ // we can safely apply the norecurse attribute.
+ if (!mayHaveRecursiveCallee(F, AnyFunctionsAddressIsTaken)) {
+ F.setDoesNotRecurse();
+ ++NumNoRecurse;
+ Changed = true;
+ }
+ }
+
+ PreservedAnalyses PA;
+ if (Changed)
+ PA.preserve<LazyCallGraphAnalysis>();
+ else
+ PA = PreservedAnalyses::all();
+ return PA;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 9ca8194..56194fe 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -137,13 +137,10 @@ InstCombinerImpl::isEliminableCastPair(const CastInst *CI1,
Instruction::CastOps secondOp = CI2->getOpcode();
Type *SrcIntPtrTy =
SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr;
- Type *MidIntPtrTy =
- MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr;
Type *DstIntPtrTy =
DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr;
unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy,
- DstTy, SrcIntPtrTy, MidIntPtrTy,
- DstIntPtrTy);
+ DstTy, &DL);
// We don't want to form an inttoptr or ptrtoint that converts to an integer
// type that differs from the pointer size.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8f60e50..8c8fc69 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3356,7 +3356,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
impliesPoisonOrCond(FalseVal, B, /*Expected=*/false)) {
// (A || B) || C --> A || (B | C)
return replaceInstUsesWith(
- SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal)));
+ SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal), "",
+ ProfcheckDisableMetadataFixes
+ ? nullptr
+ : cast<SelectInst>(CondVal)));
}
// (A && B) || (C && B) --> (A || C) && B
@@ -3398,7 +3401,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
impliesPoisonOrCond(TrueVal, B, /*Expected=*/true)) {
// (A && B) && C --> A && (B & C)
return replaceInstUsesWith(
- SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal)));
+ SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal), "",
+ ProfcheckDisableMetadataFixes
+ ? nullptr
+ : cast<SelectInst>(CondVal)));
}
// (A || B) && (C || B) --> (A && C) || B
diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index cdae9a7..3704ad7 100644
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -2662,7 +2662,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB,
G->eraseFromParent();
NewGlobals[i] = NewGlobal;
- Constant *ODRIndicator = ConstantPointerNull::get(PtrTy);
+ Constant *ODRIndicator = Constant::getNullValue(IntptrTy);
GlobalValue *InstrumentedGlobal = NewGlobal;
bool CanUsePrivateAliases =
@@ -2677,8 +2677,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB,
// ODR should not happen for local linkage.
if (NewGlobal->hasLocalLinkage()) {
- ODRIndicator =
- ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), PtrTy);
+ ODRIndicator = ConstantInt::get(IntptrTy, -1);
} else if (UseOdrIndicator) {
// With local aliases, we need to provide another externally visible
// symbol __odr_asan_XXX to detect ODR violation.
@@ -2692,7 +2691,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB,
ODRIndicatorSym->setVisibility(NewGlobal->getVisibility());
ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass());
ODRIndicatorSym->setAlignment(Align(1));
- ODRIndicator = ODRIndicatorSym;
+ ODRIndicator = ConstantExpr::getPtrToInt(ODRIndicatorSym, IntptrTy);
}
Constant *Initializer = ConstantStruct::get(
@@ -2703,8 +2702,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB,
ConstantExpr::getPointerCast(Name, IntptrTy),
ConstantExpr::getPointerCast(getOrCreateModuleName(), IntptrTy),
ConstantInt::get(IntptrTy, MD.IsDynInit),
- Constant::getNullValue(IntptrTy),
- ConstantExpr::getPointerCast(ODRIndicator, IntptrTy));
+ Constant::getNullValue(IntptrTy), ODRIndicator);
LLVM_DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n");
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
new file mode 100644
index 0000000..782d5a1
--- /dev/null
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -0,0 +1,494 @@
+//===- AllocToken.cpp - Allocation token instrumentation ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements AllocToken, an instrumentation pass that
+// replaces allocation calls with token-enabled versions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/AllocToken.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/RandomNumberGenerator.h"
+#include "llvm/Support/SipHash.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "alloc-token"
+
+namespace {
+
+//===--- Constants --------------------------------------------------------===//
+
+enum class TokenMode : unsigned {
+ /// Incrementally increasing token ID.
+ Increment = 0,
+
+ /// Simple mode that returns a statically-assigned random token ID.
+ Random = 1,
+
+ /// Token ID based on allocated type hash.
+ TypeHash = 2,
+};
+
+//===--- Command-line options ---------------------------------------------===//
+
+cl::opt<TokenMode>
+ ClMode("alloc-token-mode", cl::Hidden, cl::desc("Token assignment mode"),
+ cl::init(TokenMode::TypeHash),
+ cl::values(clEnumValN(TokenMode::Increment, "increment",
+ "Incrementally increasing token ID"),
+ clEnumValN(TokenMode::Random, "random",
+ "Statically-assigned random token ID"),
+ clEnumValN(TokenMode::TypeHash, "typehash",
+ "Token ID based on allocated type hash")));
+
+cl::opt<std::string> ClFuncPrefix("alloc-token-prefix",
+ cl::desc("The allocation function prefix"),
+ cl::Hidden, cl::init("__alloc_token_"));
+
+cl::opt<uint64_t> ClMaxTokens("alloc-token-max",
+ cl::desc("Maximum number of tokens (0 = no max)"),
+ cl::Hidden, cl::init(0));
+
+cl::opt<bool>
+ ClFastABI("alloc-token-fast-abi",
+ cl::desc("The token ID is encoded in the function name"),
+ cl::Hidden, cl::init(false));
+
+// Instrument libcalls only by default - compatible allocators only need to take
+// care of providing standard allocation functions. With extended coverage, also
+// instrument non-libcall allocation function calls with !alloc_token
+// metadata.
+cl::opt<bool>
+ ClExtended("alloc-token-extended",
+ cl::desc("Extend coverage to custom allocation functions"),
+ cl::Hidden, cl::init(false));
+
+// C++ defines ::operator new (and variants) as replaceable (vs. standard
+// library versions), which are nobuiltin, and are therefore not covered by
+// isAllocationFn(). Cover by default, as users of AllocToken are already
+// required to provide token-aware allocation functions (no defaults).
+cl::opt<bool> ClCoverReplaceableNew("alloc-token-cover-replaceable-new",
+ cl::desc("Cover replaceable operator new"),
+ cl::Hidden, cl::init(true));
+
+cl::opt<uint64_t> ClFallbackToken(
+ "alloc-token-fallback",
+ cl::desc("The default fallback token where none could be determined"),
+ cl::Hidden, cl::init(0));
+
+//===--- Statistics -------------------------------------------------------===//
+
+STATISTIC(NumFunctionsInstrumented, "Functions instrumented");
+STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
+
+//===----------------------------------------------------------------------===//
+
+/// Returns the !alloc_token metadata if available.
+///
+/// Expected format is: !{<type-name>}
+MDNode *getAllocTokenMetadata(const CallBase &CB) {
+ MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
+ if (!Ret)
+ return nullptr;
+ assert(Ret->getNumOperands() == 1 && "bad !alloc_token");
+ assert(isa<MDString>(Ret->getOperand(0)));
+ return Ret;
+}
+
+class ModeBase {
+public:
+ explicit ModeBase(const IntegerType &TokenTy, uint64_t MaxTokens)
+ : MaxTokens(MaxTokens ? MaxTokens : TokenTy.getBitMask()) {
+ assert(MaxTokens <= TokenTy.getBitMask());
+ }
+
+protected:
+ uint64_t boundedToken(uint64_t Val) const {
+ assert(MaxTokens != 0);
+ return Val % MaxTokens;
+ }
+
+ const uint64_t MaxTokens;
+};
+
+/// Implementation for TokenMode::Increment.
+class IncrementMode : public ModeBase {
+public:
+ using ModeBase::ModeBase;
+
+ uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &) {
+ return boundedToken(Counter++);
+ }
+
+private:
+ uint64_t Counter = 0;
+};
+
+/// Implementation for TokenMode::Random.
+class RandomMode : public ModeBase {
+public:
+ RandomMode(const IntegerType &TokenTy, uint64_t MaxTokens,
+ std::unique_ptr<RandomNumberGenerator> RNG)
+ : ModeBase(TokenTy, MaxTokens), RNG(std::move(RNG)) {}
+ uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &) {
+ return boundedToken((*RNG)());
+ }
+
+private:
+ std::unique_ptr<RandomNumberGenerator> RNG;
+};
+
+/// Implementation for TokenMode::TypeHash. The implementation ensures
+/// hashes are stable across different compiler invocations. Uses SipHash as the
+/// hash function.
+class TypeHashMode : public ModeBase {
+public:
+ using ModeBase::ModeBase;
+
+ uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
+ if (MDNode *N = getAllocTokenMetadata(CB)) {
+ MDString *S = cast<MDString>(N->getOperand(0));
+ return boundedToken(getStableSipHash(S->getString()));
+ }
+ remarkNoMetadata(CB, ORE);
+ return ClFallbackToken;
+ }
+
+ /// Remark that there was no precise type information.
+ static void remarkNoMetadata(const CallBase &CB,
+ OptimizationRemarkEmitter &ORE) {
+ ORE.emit([&] {
+ ore::NV FuncNV("Function", CB.getParent()->getParent());
+ const Function *Callee = CB.getCalledFunction();
+ ore::NV CalleeNV("Callee", Callee ? Callee->getName() : "<unknown>");
+ return OptimizationRemark(DEBUG_TYPE, "NoAllocToken", &CB)
+ << "Call to '" << CalleeNV << "' in '" << FuncNV
+ << "' without source-level type token";
+ });
+ }
+};
+
+// Apply opt overrides.
+AllocTokenOptions transformOptionsFromCl(AllocTokenOptions Opts) {
+ if (!Opts.MaxTokens.has_value())
+ Opts.MaxTokens = ClMaxTokens;
+ Opts.FastABI |= ClFastABI;
+ Opts.Extended |= ClExtended;
+ return Opts;
+}
+
+class AllocToken {
+public:
+ explicit AllocToken(AllocTokenOptions Opts, Module &M,
+ ModuleAnalysisManager &MAM)
+ : Options(transformOptionsFromCl(std::move(Opts))), Mod(M),
+ FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
+ Mode(IncrementMode(*IntPtrTy, *Options.MaxTokens)) {
+ switch (ClMode.getValue()) {
+ case TokenMode::Increment:
+ break;
+ case TokenMode::Random:
+ Mode.emplace<RandomMode>(*IntPtrTy, *Options.MaxTokens,
+ M.createRNG(DEBUG_TYPE));
+ break;
+ case TokenMode::TypeHash:
+ Mode.emplace<TypeHashMode>(*IntPtrTy, *Options.MaxTokens);
+ break;
+ }
+ }
+
+ bool instrumentFunction(Function &F);
+
+private:
+ /// Returns the LibFunc (or NotLibFunc) if this call should be instrumented.
+ std::optional<LibFunc>
+ shouldInstrumentCall(const CallBase &CB, const TargetLibraryInfo &TLI) const;
+
+ /// Returns true for functions that are eligible for instrumentation.
+ static bool isInstrumentableLibFunc(LibFunc Func, const CallBase &CB,
+ const TargetLibraryInfo &TLI);
+
+ /// Returns true for isAllocationFn() functions that we should ignore.
+ static bool ignoreInstrumentableLibFunc(LibFunc Func);
+
+ /// Replace a call/invoke with a call/invoke to the allocation function
+ /// with token ID.
+ bool replaceAllocationCall(CallBase *CB, LibFunc Func,
+ OptimizationRemarkEmitter &ORE,
+ const TargetLibraryInfo &TLI);
+
+ /// Return replacement function for a LibFunc that takes a token ID.
+ FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID,
+ LibFunc OriginalFunc);
+
+ /// Return the token ID from metadata in the call.
+ uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
+ return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode);
+ }
+
+ const AllocTokenOptions Options;
+ Module &Mod;
+ IntegerType *IntPtrTy = Mod.getDataLayout().getIntPtrType(Mod.getContext());
+ FunctionAnalysisManager &FAM;
+ // Cache for replacement functions.
+ DenseMap<std::pair<LibFunc, uint64_t>, FunctionCallee> TokenAllocFunctions;
+ // Selected mode.
+ std::variant<IncrementMode, RandomMode, TypeHashMode> Mode;
+};
+
+bool AllocToken::instrumentFunction(Function &F) {
+ // Do not apply any instrumentation for naked functions.
+ if (F.hasFnAttribute(Attribute::Naked))
+ return false;
+ if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation))
+ return false;
+ // Don't touch available_externally functions, their actual body is elsewhere.
+ if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
+ return false;
+ // Only instrument functions that have the sanitize_alloc_token attribute.
+ if (!F.hasFnAttribute(Attribute::SanitizeAllocToken))
+ return false;
+
+ auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
+ SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls;
+
+ // Collect all allocation calls to avoid iterator invalidation.
+ for (Instruction &I : instructions(F)) {
+ auto *CB = dyn_cast<CallBase>(&I);
+ if (!CB)
+ continue;
+ if (std::optional<LibFunc> Func = shouldInstrumentCall(*CB, TLI))
+ AllocCalls.emplace_back(CB, Func.value());
+ }
+
+ bool Modified = false;
+ for (auto &[CB, Func] : AllocCalls)
+ Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
+
+ if (Modified)
+ NumFunctionsInstrumented++;
+ return Modified;
+}
+
+std::optional<LibFunc>
+AllocToken::shouldInstrumentCall(const CallBase &CB,
+ const TargetLibraryInfo &TLI) const {
+ const Function *Callee = CB.getCalledFunction();
+ if (!Callee)
+ return std::nullopt;
+
+ // Ignore nobuiltin of the CallBase, so that we can cover nobuiltin libcalls
+ // if requested via isInstrumentableLibFunc(). Note that isAllocationFn() is
+ // returning false for nobuiltin calls.
+ LibFunc Func;
+ if (TLI.getLibFunc(*Callee, Func)) {
+ if (isInstrumentableLibFunc(Func, CB, TLI))
+ return Func;
+ } else if (Options.Extended && getAllocTokenMetadata(CB)) {
+ return NotLibFunc;
+ }
+
+ return std::nullopt;
+}
+
+bool AllocToken::isInstrumentableLibFunc(LibFunc Func, const CallBase &CB,
+ const TargetLibraryInfo &TLI) {
+ if (ignoreInstrumentableLibFunc(Func))
+ return false;
+
+ if (isAllocationFn(&CB, &TLI))
+ return true;
+
+ switch (Func) {
+ // These libfuncs don't return normal pointers, and are therefore not handled
+ // by isAllocationFn().
+ case LibFunc_posix_memalign:
+ case LibFunc_size_returning_new:
+ case LibFunc_size_returning_new_hot_cold:
+ case LibFunc_size_returning_new_aligned:
+ case LibFunc_size_returning_new_aligned_hot_cold:
+ return true;
+
+ // See comment above ClCoverReplaceableNew.
+ case LibFunc_Znwj:
+ case LibFunc_ZnwjRKSt9nothrow_t:
+ case LibFunc_ZnwjSt11align_val_t:
+ case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t:
+ case LibFunc_Znwm:
+ case LibFunc_Znwm12__hot_cold_t:
+ case LibFunc_ZnwmRKSt9nothrow_t:
+ case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t:
+ case LibFunc_ZnwmSt11align_val_t:
+ case LibFunc_ZnwmSt11align_val_t12__hot_cold_t:
+ case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
+ case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
+ case LibFunc_Znaj:
+ case LibFunc_ZnajRKSt9nothrow_t:
+ case LibFunc_ZnajSt11align_val_t:
+ case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t:
+ case LibFunc_Znam:
+ case LibFunc_Znam12__hot_cold_t:
+ case LibFunc_ZnamRKSt9nothrow_t:
+ case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t:
+ case LibFunc_ZnamSt11align_val_t:
+ case LibFunc_ZnamSt11align_val_t12__hot_cold_t:
+ case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
+ case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
+ return ClCoverReplaceableNew;
+
+ default:
+ return false;
+ }
+}
+
+bool AllocToken::ignoreInstrumentableLibFunc(LibFunc Func) {
+ switch (Func) {
+ case LibFunc_strdup:
+ case LibFunc_dunder_strdup:
+ case LibFunc_strndup:
+ case LibFunc_dunder_strndup:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool AllocToken::replaceAllocationCall(CallBase *CB, LibFunc Func,
+ OptimizationRemarkEmitter &ORE,
+ const TargetLibraryInfo &TLI) {
+ uint64_t TokenID = getToken(*CB, ORE);
+
+ FunctionCallee TokenAlloc = getTokenAllocFunction(*CB, TokenID, Func);
+ if (!TokenAlloc)
+ return false;
+ NumAllocationsInstrumented++;
+
+ if (Options.FastABI) {
+ assert(TokenAlloc.getFunctionType()->getNumParams() == CB->arg_size());
+ CB->setCalledFunction(TokenAlloc);
+ return true;
+ }
+
+ IRBuilder<> IRB(CB);
+ // Original args.
+ SmallVector<Value *, 4> NewArgs{CB->args()};
+ // Add token ID, truncated to IntPtrTy width.
+ NewArgs.push_back(ConstantInt::get(IntPtrTy, TokenID));
+ assert(TokenAlloc.getFunctionType()->getNumParams() == NewArgs.size());
+
+ // Preserve invoke vs call semantics for exception handling.
+ CallBase *NewCall;
+ if (auto *II = dyn_cast<InvokeInst>(CB)) {
+ NewCall = IRB.CreateInvoke(TokenAlloc, II->getNormalDest(),
+ II->getUnwindDest(), NewArgs);
+ } else {
+ NewCall = IRB.CreateCall(TokenAlloc, NewArgs);
+ cast<CallInst>(NewCall)->setTailCall(CB->isTailCall());
+ }
+ NewCall->setCallingConv(CB->getCallingConv());
+ NewCall->copyMetadata(*CB);
+ NewCall->setAttributes(CB->getAttributes());
+
+ // Replace all uses and delete the old call.
+ CB->replaceAllUsesWith(NewCall);
+ CB->eraseFromParent();
+ return true;
+}
+
+FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB,
+ uint64_t TokenID,
+ LibFunc OriginalFunc) {
+ std::optional<std::pair<LibFunc, uint64_t>> Key;
+ if (OriginalFunc != NotLibFunc) {
+ Key = std::make_pair(OriginalFunc, Options.FastABI ? TokenID : 0);
+ auto It = TokenAllocFunctions.find(*Key);
+ if (It != TokenAllocFunctions.end())
+ return It->second;
+ }
+
+ const Function *Callee = CB.getCalledFunction();
+ if (!Callee)
+ return FunctionCallee();
+ const FunctionType *OldFTy = Callee->getFunctionType();
+ if (OldFTy->isVarArg())
+ return FunctionCallee();
+ // Copy params, and append token ID type.
+ Type *RetTy = OldFTy->getReturnType();
+ SmallVector<Type *, 4> NewParams{OldFTy->params()};
+ std::string TokenAllocName = ClFuncPrefix;
+ if (Options.FastABI)
+ TokenAllocName += utostr(TokenID) + "_";
+ else
+ NewParams.push_back(IntPtrTy); // token ID
+ TokenAllocName += Callee->getName();
+ FunctionType *NewFTy = FunctionType::get(RetTy, NewParams, false);
+ FunctionCallee TokenAlloc = Mod.getOrInsertFunction(TokenAllocName, NewFTy);
+ if (Function *F = dyn_cast<Function>(TokenAlloc.getCallee()))
+ F->copyAttributesFrom(Callee); // preserve attrs
+
+ if (Key.has_value())
+ TokenAllocFunctions[*Key] = TokenAlloc;
+ return TokenAlloc;
+}
+
+} // namespace
+
+AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts)
+ : Options(std::move(Opts)) {}
+
+PreservedAnalyses AllocTokenPass::run(Module &M, ModuleAnalysisManager &MAM) {
+ AllocToken Pass(Options, M, MAM);
+ bool Modified = false;
+
+ for (Function &F : M) {
+ if (F.empty())
+ continue; // declaration
+ Modified |= Pass.instrumentFunction(F);
+ }
+
+ return Modified ? PreservedAnalyses::none().preserveSet<CFGAnalyses>()
+ : PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
index 15fd421..80576c6 100644
--- a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -1,5 +1,6 @@
add_llvm_component_library(LLVMInstrumentation
AddressSanitizer.cpp
+ AllocToken.cpp
BoundsChecking.cpp
CGProfile.cpp
ControlHeightReduction.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index 480ff4a..5ba2167 100644
--- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -261,6 +261,11 @@ static cl::opt<bool> ClIgnorePersonalityRoutine(
"list, do not create a wrapper for it."),
cl::Hidden, cl::init(false));
+static cl::opt<bool> ClAddGlobalNameSuffix(
+ "dfsan-add-global-name-suffix",
+ cl::desc("Whether to add .dfsan suffix to global names"), cl::Hidden,
+ cl::init(true));
+
static StringRef getGlobalTypeString(const GlobalValue &G) {
// Types of GlobalVariables are always pointer types.
Type *GType = G.getValueType();
@@ -1256,6 +1261,9 @@ DataFlowSanitizer::WrapperKind DataFlowSanitizer::getWrapperKind(Function *F) {
}
void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) {
+ if (!ClAddGlobalNameSuffix)
+ return;
+
std::string GVName = std::string(GV->getName()), Suffix = ".dfsan";
GV->setName(GVName + Suffix);
@@ -1784,10 +1792,8 @@ bool DataFlowSanitizer::runImpl(
}
Value *DFSanFunction::getArgTLS(Type *T, unsigned ArgOffset, IRBuilder<> &IRB) {
- Value *Base = IRB.CreatePointerCast(DFS.ArgTLS, DFS.IntptrTy);
- if (ArgOffset)
- Base = IRB.CreateAdd(Base, ConstantInt::get(DFS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(*DFS.Ctx, 0), "_dfsarg");
+ return IRB.CreatePtrAdd(DFS.ArgTLS, ConstantInt::get(DFS.IntptrTy, ArgOffset),
+ "_dfsarg");
}
Value *DFSanFunction::getRetvalTLS(Type *T, IRBuilder<> &IRB) {
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index e9a3e98..ae34b4e 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -89,6 +89,7 @@ STATISTIC(NumTransforms, "Number of transformations done");
STATISTIC(NumCloned, "Number of blocks cloned");
STATISTIC(NumPaths, "Number of individual paths threaded");
+namespace llvm {
static cl::opt<bool>
ClViewCfgBefore("dfa-jump-view-cfg-before",
cl::desc("View the CFG before DFA Jump Threading"),
@@ -119,9 +120,15 @@ static cl::opt<unsigned>
CostThreshold("dfa-cost-threshold",
cl::desc("Maximum cost accepted for the transformation"),
cl::Hidden, cl::init(50));
+} // namespace llvm
-namespace {
+static cl::opt<double> MaxClonedRate(
+ "dfa-max-cloned-rate",
+ cl::desc(
+ "Maximum cloned instructions rate accepted for the transformation"),
+ cl::Hidden, cl::init(7.5));
+namespace {
class SelectInstToUnfold {
SelectInst *SI;
PHINode *SIUse;
@@ -135,10 +142,6 @@ public:
explicit operator bool() const { return SI && SIUse; }
};
-void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
- std::vector<SelectInstToUnfold> *NewSIsToUnfold,
- std::vector<BasicBlock *> *NewBBs);
-
class DFAJumpThreading {
public:
DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
@@ -152,7 +155,8 @@ private:
void
unfoldSelectInstrs(DominatorTree *DT,
const SmallVector<SelectInstToUnfold, 4> &SelectInsts) {
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ // TODO: Have everything use a single lazy DTU
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
SmallVector<SelectInstToUnfold, 4> Stack(SelectInsts);
while (!Stack.empty()) {
@@ -167,16 +171,18 @@ private:
}
}
+ static void unfold(DomTreeUpdater *DTU, LoopInfo *LI,
+ SelectInstToUnfold SIToUnfold,
+ std::vector<SelectInstToUnfold> *NewSIsToUnfold,
+ std::vector<BasicBlock *> *NewBBs);
+
AssumptionCache *AC;
DominatorTree *DT;
LoopInfo *LI;
TargetTransformInfo *TTI;
OptimizationRemarkEmitter *ORE;
};
-
-} // end anonymous namespace
-
-namespace {
+} // namespace
/// Unfold the select instruction held in \p SIToUnfold by replacing it with
/// control flow.
@@ -185,9 +191,10 @@ namespace {
/// created basic blocks into \p NewBBs.
///
/// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible.
-void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
- std::vector<SelectInstToUnfold> *NewSIsToUnfold,
- std::vector<BasicBlock *> *NewBBs) {
+void DFAJumpThreading::unfold(DomTreeUpdater *DTU, LoopInfo *LI,
+ SelectInstToUnfold SIToUnfold,
+ std::vector<SelectInstToUnfold> *NewSIsToUnfold,
+ std::vector<BasicBlock *> *NewBBs) {
SelectInst *SI = SIToUnfold.getInst();
PHINode *SIUse = SIToUnfold.getUse();
assert(SI->hasOneUse());
@@ -342,10 +349,12 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
SI->eraseFromParent();
}
+namespace {
struct ClonedBlock {
BasicBlock *BB;
APInt State; ///< \p State corresponds to the next value of a switch stmnt.
};
+} // namespace
typedef std::deque<BasicBlock *> PathType;
typedef std::vector<PathType> PathsType;
@@ -375,6 +384,7 @@ inline raw_ostream &operator<<(raw_ostream &OS, const PathType &Path) {
return OS;
}
+namespace {
/// ThreadingPath is a path in the control flow of a loop that can be threaded
/// by cloning necessary basic blocks and replacing conditional branches with
/// unconditional ones. A threading path includes a list of basic blocks, the
@@ -814,11 +824,13 @@ struct TransformDFA {
: SwitchPaths(SwitchPaths), DT(DT), AC(AC), TTI(TTI), ORE(ORE),
EphValues(EphValues) {}
- void run() {
+ bool run() {
if (isLegalAndProfitableToTransform()) {
createAllExitPaths();
NumTransforms++;
+ return true;
}
+ return false;
}
private:
@@ -828,6 +840,7 @@ private:
/// also returns false if it is illegal to clone some required block.
bool isLegalAndProfitableToTransform() {
CodeMetrics Metrics;
+ uint64_t NumClonedInst = 0;
SwitchInst *Switch = SwitchPaths->getSwitchInst();
// Don't thread switch without multiple successors.
@@ -837,7 +850,6 @@ private:
// Note that DuplicateBlockMap is not being used as intended here. It is
// just being used to ensure (BB, State) pairs are only counted once.
DuplicateBlockMap DuplicateMap;
-
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
PathType PathBBs = TPath.getPath();
APInt NextState = TPath.getExitValue();
@@ -848,6 +860,7 @@ private:
BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
if (!VisitedBB) {
Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
+ NumClonedInst += BB->sizeWithoutDebug();
DuplicateMap[BB].push_back({BB, NextState});
}
@@ -865,6 +878,7 @@ private:
if (VisitedBB)
continue;
Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
+ NumClonedInst += BB->sizeWithoutDebug();
DuplicateMap[BB].push_back({BB, NextState});
}
@@ -901,6 +915,22 @@ private:
}
}
+ // Too much cloned instructions slow down later optimizations, especially
+ // SLPVectorizer.
+ // TODO: Thread the switch partially before reaching the threshold.
+ uint64_t NumOrigInst = 0;
+ for (auto *BB : DuplicateMap.keys())
+ NumOrigInst += BB->sizeWithoutDebug();
+ if (double(NumClonedInst) / double(NumOrigInst) > MaxClonedRate) {
+ LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, too much "
+ "instructions wll be cloned\n");
+ ORE->emit([&]() {
+ return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch)
+ << "Too much instructions will be cloned.";
+ });
+ return false;
+ }
+
InstructionCost DuplicationCost = 0;
unsigned JumpTableSize = 0;
@@ -951,8 +981,6 @@ private:
/// Transform each threading path to effectively jump thread the DFA.
void createAllExitPaths() {
- DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
-
// Move the switch block to the end of the path, since it will be duplicated
BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
@@ -969,15 +997,18 @@ private:
SmallPtrSet<BasicBlock *, 16> BlocksToClean;
BlocksToClean.insert_range(successors(SwitchBlock));
- for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
- createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
- NumPaths++;
- }
+ {
+ DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
+ createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
+ NumPaths++;
+ }
- // After all paths are cloned, now update the last successor of the cloned
- // path so it skips over the switch statement
- for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
- updateLastSuccessor(TPath, DuplicateMap, &DTU);
+ // After all paths are cloned, now update the last successor of the cloned
+ // path so it skips over the switch statement
+ for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
+ updateLastSuccessor(TPath, DuplicateMap, &DTU);
+ }
// For each instruction that was cloned and used outside, update its uses
updateSSA(NewDefs);
@@ -993,7 +1024,7 @@ private:
/// To remember the correct destination, we have to duplicate blocks
/// corresponding to each state. Also update the terminating instruction of
/// the predecessors, and phis in the successor blocks.
- void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
+ void createExitPath(DefMap &NewDefs, const ThreadingPath &Path,
DuplicateBlockMap &DuplicateMap,
SmallPtrSet<BasicBlock *, 16> &BlocksToClean,
DomTreeUpdater *DTU) {
@@ -1239,7 +1270,7 @@ private:
///
/// Note that this is an optional step and would have been done in later
/// optimizations, but it makes the CFG significantly easier to work with.
- void updateLastSuccessor(ThreadingPath &TPath,
+ void updateLastSuccessor(const ThreadingPath &TPath,
DuplicateBlockMap &DuplicateMap,
DomTreeUpdater *DTU) {
APInt NextState = TPath.getExitValue();
@@ -1336,6 +1367,7 @@ private:
SmallPtrSet<const Value *, 32> EphValues;
std::vector<ThreadingPath> TPaths;
};
+} // namespace
bool DFAJumpThreading::run(Function &F) {
LLVM_DEBUG(dbgs() << "\nDFA Jump threading: " << F.getName() << "\n");
@@ -1402,9 +1434,8 @@ bool DFAJumpThreading::run(Function &F) {
for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
- Transform.run();
- MadeChanges = true;
- LoopInfoBroken = true;
+ if (Transform.run())
+ MadeChanges = LoopInfoBroken = true;
}
#ifdef EXPENSIVE_CHECKS
@@ -1415,8 +1446,6 @@ bool DFAJumpThreading::run(Function &F) {
return MadeChanges;
}
-} // end anonymous namespace
-
/// Integrate with the new Pass Manager
PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
FunctionAnalysisManager &AM) {
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index bbd1ed6..5ba6f95f 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -970,6 +970,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
case Attribute::SanitizeMemTag:
case Attribute::SanitizeRealtime:
case Attribute::SanitizeRealtimeBlocking:
+ case Attribute::SanitizeAllocToken:
case Attribute::SpeculativeLoadHardening:
case Attribute::StackProtect:
case Attribute::StackProtectReq:
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 21b2652..b6ca52e 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3031,6 +3031,13 @@ static void combineMetadata(Instruction *K, const Instruction *J,
K->getContext(), MDNode::toCaptureComponents(JMD) |
MDNode::toCaptureComponents(KMD)));
break;
+ case LLVMContext::MD_alloc_token:
+ // Preserve !alloc_token if both K and J have it, and they are equal.
+ if (KMD == JMD)
+ K->setMetadata(Kind, JMD);
+ else
+ K->setMetadata(Kind, nullptr);
+ break;
}
}
// Set !invariant.group from J if J has it. If both instructions have it
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index bf882d7..6312831 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -201,18 +201,27 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
/// unroll count is non-zero.
///
/// This function performs the following:
-/// - Update PHI nodes at the unrolling loop exit and epilog loop exit
-/// - Create PHI nodes at the unrolling loop exit to combine
-/// values that exit the unrolling loop code and jump around it.
+/// - Update PHI nodes at the epilog loop exit
+/// - Create PHI nodes at the unrolling loop exit and epilog preheader to
+/// combine values that exit the unrolling loop code and jump around it.
/// - Update PHI operands in the epilog loop by the new PHI nodes
-/// - Branch around the epilog loop if extra iters (ModVal) is zero.
+/// - At the unrolling loop exit, branch around the epilog loop if extra iters
+// (ModVal) is zero.
+/// - At the epilog preheader, add an llvm.assume call that extra iters is
+/// non-zero. If the unrolling loop exit is the predecessor, the above new
+/// branch guarantees that assumption. If the unrolling loop preheader is the
+/// predecessor, then the required first iteration from the original loop has
+/// yet to be executed, so it must be executed in the epilog loop. If we
+/// later unroll the epilog loop, that llvm.assume call somehow enables
+/// ScalarEvolution to compute a epilog loop maximum trip count, which enables
+/// eliminating the branch at the end of the final unrolled epilog iteration.
///
static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
BasicBlock *Exit, BasicBlock *PreHeader,
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
ValueToValueMapTy &VMap, DominatorTree *DT,
LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
- unsigned Count) {
+ unsigned Count, AssumptionCache &AC) {
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Loop must have a latch");
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -231,7 +240,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
// EpilogLatch
// Exit (EpilogPN)
- // Update PHI nodes at NewExit and Exit.
+ // Update PHI nodes at Exit.
for (PHINode &PN : NewExit->phis()) {
// PN should be used in another PHI located in Exit block as
// Exit was split by SplitBlockPredecessors into Exit and NewExit
@@ -246,15 +255,11 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
// epilogue edges have already been added.
//
// There is EpilogPreHeader incoming block instead of NewExit as
- // NewExit was spilt 1 more time to get EpilogPreHeader.
+ // NewExit was split 1 more time to get EpilogPreHeader.
assert(PN.hasOneUse() && "The phi should have 1 use");
PHINode *EpilogPN = cast<PHINode>(PN.use_begin()->getUser());
assert(EpilogPN->getParent() == Exit && "EpilogPN should be in Exit block");
- // Add incoming PreHeader from branch around the Loop
- PN.addIncoming(PoisonValue::get(PN.getType()), PreHeader);
- SE.forgetValue(&PN);
-
Value *V = PN.getIncomingValueForBlock(Latch);
Instruction *I = dyn_cast<Instruction>(V);
if (I && L->contains(I))
@@ -271,35 +276,52 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
NewExit);
// Now PHIs should look like:
// NewExit:
- // PN = PHI [I, Latch], [poison, PreHeader]
+ // PN = PHI [I, Latch]
// ...
// Exit:
// EpilogPN = PHI [PN, NewExit], [VMap[I], EpilogLatch]
}
- // Create PHI nodes at NewExit (from the unrolling loop Latch and PreHeader).
- // Update corresponding PHI nodes in epilog loop.
+ // Create PHI nodes at NewExit (from the unrolling loop Latch) and at
+ // EpilogPreHeader (from PreHeader and NewExit). Update corresponding PHI
+ // nodes in epilog loop.
for (BasicBlock *Succ : successors(Latch)) {
// Skip this as we already updated phis in exit blocks.
if (!L->contains(Succ))
continue;
+
+ // Succ here appears to always be just L->getHeader(). Otherwise, how do we
+ // know its corresponding epilog block (from VMap) is EpilogHeader and thus
+ // EpilogPreHeader is the right incoming block for VPN, as set below?
+ // TODO: Can we thus avoid the enclosing loop over successors?
+ assert(Succ == L->getHeader() &&
+ "Expect the only in-loop successor of latch to be the loop header");
+
for (PHINode &PN : Succ->phis()) {
- // Add new PHI nodes to the loop exit block and update epilog
- // PHIs with the new PHI values.
- PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr");
- NewPN->insertBefore(NewExit->getFirstNonPHIIt());
- // Adding a value to the new PHI node from the unrolling loop preheader.
- NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader);
- // Adding a value to the new PHI node from the unrolling loop latch.
- NewPN->addIncoming(PN.getIncomingValueForBlock(Latch), Latch);
+ // Add new PHI nodes to the loop exit block.
+ PHINode *NewPN0 = PHINode::Create(PN.getType(), /*NumReservedValues=*/1,
+ PN.getName() + ".unr");
+ NewPN0->insertBefore(NewExit->getFirstNonPHIIt());
+ // Add value to the new PHI node from the unrolling loop latch.
+ NewPN0->addIncoming(PN.getIncomingValueForBlock(Latch), Latch);
+
+ // Add new PHI nodes to EpilogPreHeader.
+ PHINode *NewPN1 = PHINode::Create(PN.getType(), /*NumReservedValues=*/2,
+ PN.getName() + ".epil.init");
+ NewPN1->insertBefore(EpilogPreHeader->getFirstNonPHIIt());
+ // Add value to the new PHI node from the unrolling loop preheader.
+ NewPN1->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader);
+ // Add value to the new PHI node from the epilog loop guard.
+ NewPN1->addIncoming(NewPN0, NewExit);
// Update the existing PHI node operand with the value from the new PHI
// node. Corresponding instruction in epilog loop should be PHI.
PHINode *VPN = cast<PHINode>(VMap[&PN]);
- VPN->setIncomingValueForBlock(EpilogPreHeader, NewPN);
+ VPN->setIncomingValueForBlock(EpilogPreHeader, NewPN1);
}
}
+ // In NewExit, branch around the epilog loop if no extra iters.
Instruction *InsertPt = NewExit->getTerminator();
IRBuilder<> B(InsertPt);
Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod");
@@ -308,7 +330,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
SmallVector<BasicBlock*, 4> Preds(predecessors(Exit));
SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr,
PreserveLCSSA);
- // Add the branch to the exit block (around the unrolling loop)
+ // Add the branch to the exit block (around the epilog loop)
MDNode *BranchWeights = nullptr;
if (hasBranchWeightMD(*Latch->getTerminator())) {
// Assume equal distribution in interval [0, Count).
@@ -322,10 +344,11 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
DT->changeImmediateDominator(Exit, NewDom);
}
- // Split the main loop exit to maintain canonicalization guarantees.
- SmallVector<BasicBlock*, 4> NewExitPreds{Latch};
- SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, nullptr,
- PreserveLCSSA);
+ // In EpilogPreHeader, assume extra iters is non-zero.
+ IRBuilder<> B2(EpilogPreHeader, EpilogPreHeader->getFirstNonPHIIt());
+ Value *ModIsNotNull = B2.CreateIsNotNull(ModVal, "lcmp.mod");
+ AssumeInst *AI = cast<AssumeInst>(B2.CreateAssumption(ModIsNotNull));
+ AC.registerAssumption(AI);
}
/// Create a clone of the blocks in a loop and connect them together. A new
@@ -795,7 +818,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
ConstantInt::get(BECount->getType(),
Count - 1)) :
B.CreateIsNotNull(ModVal, "lcmp.mod");
- BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader;
+ BasicBlock *RemainderLoop =
+ UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
// Branch to either remainder (extra iterations) loop or unrolling loop.
MDNode *BranchWeights = nullptr;
@@ -808,7 +832,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
PreHeaderBR->eraseFromParent();
if (DT) {
if (UseEpilogRemainder)
- DT->changeImmediateDominator(NewExit, PreHeader);
+ DT->changeImmediateDominator(EpilogPreHeader, PreHeader);
else
DT->changeImmediateDominator(PrologExit, PreHeader);
}
@@ -880,7 +904,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
// from both the original loop and the remainder code reaching the exit
// blocks. While the IDom of these exit blocks were from the original loop,
// now the IDom is the preheader (which decides whether the original loop or
- // remainder code should run).
+ // remainder code should run) unless the block still has just the original
+ // predecessor (such as NewExit in the case of an epilog remainder).
if (DT && !L->getExitingBlock()) {
SmallVector<BasicBlock *, 16> ChildrenToUpdate;
// NB! We have to examine the dom children of all loop blocks, not just
@@ -891,7 +916,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
auto *DomNodeBB = DT->getNode(BB);
for (auto *DomChild : DomNodeBB->children()) {
auto *DomChildBB = DomChild->getBlock();
- if (!L->contains(LI->getLoopFor(DomChildBB)))
+ if (!L->contains(LI->getLoopFor(DomChildBB)) &&
+ DomChildBB->getUniquePredecessor() != BB)
ChildrenToUpdate.push_back(DomChildBB);
}
}
@@ -930,7 +956,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
// Connect the epilog code to the original loop and update the
// PHI functions.
ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count);
+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
// Update counter in loop for unrolling.
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 148bfa8..b8cfe3a 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm,
// We found both of the successors we were looking for.
// Create a conditional branch sharing the condition of the select.
BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
- if (TrueWeight != FalseWeight)
- setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
- /*IsExpected=*/false, /*ElideAllZero=*/true);
+ setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+ /*IsExpected=*/false, /*ElideAllZero=*/true);
}
} else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
// Neither of the selected blocks were successors, so this
@@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI,
BasicBlock *TrueBB = TBA->getBasicBlock();
BasicBlock *FalseBB = FBA->getBasicBlock();
+ // The select's profile becomes the profile of the conditional branch that
+ // replaces the indirect branch.
+ SmallVector<uint32_t> SelectBranchWeights(2);
+ if (!ProfcheckDisableMetadataFixes)
+ extractBranchWeights(*SI, SelectBranchWeights);
// Perform the actual simplification.
- return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0,
- 0);
+ return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB,
+ SelectBranchWeights[0],
+ SelectBranchWeights[1]);
}
/// This is called when we find an icmp instruction
@@ -7952,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
BasicBlock *BB = IBI->getParent();
bool Changed = false;
+ SmallVector<uint32_t> BranchWeights;
+ const bool HasBranchWeights = !ProfcheckDisableMetadataFixes &&
+ extractBranchWeights(*IBI, BranchWeights);
+
+ DenseMap<const BasicBlock *, uint64_t> TargetWeight;
+ if (HasBranchWeights)
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ TargetWeight[IBI->getDestination(I)] += BranchWeights[I];
// Eliminate redundant destinations.
SmallPtrSet<Value *, 8> Succs;
SmallSetVector<BasicBlock *, 8> RemovedSuccs;
- for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) {
- BasicBlock *Dest = IBI->getDestination(i);
+ for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) {
+ BasicBlock *Dest = IBI->getDestination(I);
if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) {
if (!Dest->hasAddressTaken())
RemovedSuccs.insert(Dest);
Dest->removePredecessor(BB);
- IBI->removeDestination(i);
- --i;
- --e;
+ IBI->removeDestination(I);
+ --I;
+ --E;
Changed = true;
}
}
@@ -7990,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
eraseTerminatorAndDCECond(IBI);
return true;
}
-
+ if (HasBranchWeights) {
+ SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations());
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second;
+ setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false);
+ }
if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) {
if (simplifyIndirectBrOnSelect(IBI, SI))
return requestResimplify();
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e434e73..cee08ef 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3903,7 +3903,8 @@ void LoopVectorizationPlanner::emitInvalidCostRemarks(
if (VF.isScalar())
continue;
- VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind);
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
+ *CM.PSE.getSE());
precomputeCosts(*Plan, VF, CostCtx);
auto Iter = vp_depth_first_deep(Plan->getVectorLoopRegion()->getEntry());
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) {
@@ -4160,7 +4161,8 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() {
// Add on other costs that are modelled in VPlan, but not in the legacy
// cost model.
- VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind);
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind,
+ *CM.PSE.getSE());
VPRegionBlock *VectorRegion = P->getVectorLoopRegion();
assert(VectorRegion && "Expected to have a vector region!");
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
@@ -6852,7 +6854,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
ElementCount VF) const {
- VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind);
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE());
InstructionCost Cost = precomputeCosts(Plan, VF, CostCtx);
// Now compute and add the VPlan-based cost.
@@ -7085,7 +7087,8 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
// simplifications not accounted for in the legacy cost model. If that's the
// case, don't trigger the assertion, as the extra simplifications may cause a
// different VF to be picked by the VPlan-based cost model.
- VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind);
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind,
+ *CM.PSE.getSE());
precomputeCosts(BestPlan, BestFactor.Width, CostCtx);
// Verify that the VPlan-based and legacy cost models agree, except for VPlans
// with early exits and plans with additional VPlan simplifications. The
@@ -8393,11 +8396,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
R->setOperand(1, WideIV->getStepValue());
}
- VPlanTransforms::runPass(
- VPlanTransforms::addExitUsersForFirstOrderRecurrences, *Plan, Range);
+ // TODO: We can't call runPass on these transforms yet, due to verifier
+ // failures.
+ VPlanTransforms::addExitUsersForFirstOrderRecurrences(*Plan, Range);
DenseMap<VPValue *, VPValue *> IVEndValues;
- VPlanTransforms::runPass(VPlanTransforms::addScalarResumePhis, *Plan,
- RecipeBuilder, IVEndValues);
+ VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues);
// ---------------------------------------------------------------------------
// Transform initial VPlan: Apply previously taken decisions, in order, to
@@ -8418,7 +8421,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// TODO: Enable following transform when the EVL-version of extended-reduction
// and mulacc-reduction are implemented.
if (!CM.foldTailWithEVL()) {
- VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind);
+ VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
+ *CM.PSE.getSE());
VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
CostCtx, Range);
}
@@ -8508,8 +8512,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
DenseMap<VPValue *, VPValue *> IVEndValues;
// TODO: IVEndValues are not used yet in the native path, to optimize exit
// values.
- VPlanTransforms::runPass(VPlanTransforms::addScalarResumePhis, *Plan,
- RecipeBuilder, IVEndValues);
+ // TODO: We can't call runPass on the transform yet, due to verifier
+ // failures.
+ VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues);
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
return Plan;
@@ -9873,7 +9878,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
bool ForceVectorization =
Hints.getForce() == LoopVectorizeHints::FK_Enabled;
VPCostContext CostCtx(CM.TTI, *CM.TLI, LVP.getPlanFor(VF.Width), CM,
- CM.CostKind);
+ CM.CostKind, *CM.PSE.getSE());
if (!ForceVectorization &&
!isOutsideLoopWorkProfitable(Checks, VF, L, PSE, CostCtx,
LVP.getPlanFor(VF.Width), SEL,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 07b191a..2555ebe 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1772,7 +1772,8 @@ VPCostContext::getOperandInfo(VPValue *V) const {
}
InstructionCost VPCostContext::getScalarizationOverhead(
- Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) {
+ Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF,
+ bool AlwaysIncludeReplicatingR) {
if (VF.isScalar())
return 0;
@@ -1792,7 +1793,11 @@ InstructionCost VPCostContext::getScalarizationOverhead(
SmallPtrSet<const VPValue *, 4> UniqueOperands;
SmallVector<Type *> Tys;
for (auto *Op : Operands) {
- if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
+ if (Op->isLiveIn() ||
+ (!AlwaysIncludeReplicatingR &&
+ isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op)) ||
+ (isa<VPReplicateRecipe>(Op) &&
+ cast<VPReplicateRecipe>(Op)->getOpcode() == Instruction::Load) ||
!UniqueOperands.insert(Op).second)
continue;
Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF));
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c167dd7..fb696be 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2263,8 +2263,7 @@ public:
/// debug location \p DL.
VPWidenPHIRecipe(PHINode *Phi, VPValue *Start = nullptr,
DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "")
- : VPSingleDefRecipe(VPDef::VPWidenPHISC, ArrayRef<VPValue *>(), Phi, DL),
- Name(Name.str()) {
+ : VPSingleDefRecipe(VPDef::VPWidenPHISC, {}, Phi, DL), Name(Name.str()) {
if (Start)
addOperand(Start);
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index fc1a09e..1580a3b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -349,12 +349,14 @@ struct VPCostContext {
LoopVectorizationCostModel &CM;
SmallPtrSet<Instruction *, 8> SkipCostComputation;
TargetTransformInfo::TargetCostKind CostKind;
+ ScalarEvolution &SE;
VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
const VPlan &Plan, LoopVectorizationCostModel &CM,
- TargetTransformInfo::TargetCostKind CostKind)
+ TargetTransformInfo::TargetCostKind CostKind,
+ ScalarEvolution &SE)
: TTI(TTI), TLI(TLI), Types(Plan), LLVMCtx(Plan.getContext()), CM(CM),
- CostKind(CostKind) {}
+ CostKind(CostKind), SE(SE) {}
/// Return the cost for \p UI with \p VF using the legacy cost model as
/// fallback until computing the cost of all recipes migrates to VPlan.
@@ -374,10 +376,12 @@ struct VPCostContext {
/// Estimate the overhead of scalarizing a recipe with result type \p ResultTy
/// and \p Operands with \p VF. This is a convenience wrapper for the
- /// type-based getScalarizationOverhead API.
- InstructionCost getScalarizationOverhead(Type *ResultTy,
- ArrayRef<const VPValue *> Operands,
- ElementCount VF);
+ /// type-based getScalarizationOverhead API. If \p AlwaysIncludeReplicatingR
+ /// is true, always compute the cost of scalarizing replicating operands.
+ InstructionCost
+ getScalarizationOverhead(Type *ResultTy, ArrayRef<const VPValue *> Operands,
+ ElementCount VF,
+ bool AlwaysIncludeReplicatingR = false);
};
/// This class can be used to assign names to VPValues. For VPValues without
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 67b9244..94e2628 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -40,6 +40,7 @@
#include <cassert>
using namespace llvm;
+using namespace llvm::VPlanPatternMatch;
using VectorParts = SmallVector<Value *, 2>;
@@ -303,7 +304,6 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
VPRecipeBase *OpR = Op->getDefiningRecipe();
// If the partial reduction is predicated, a select will be operand 0
- using namespace llvm::VPlanPatternMatch;
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
OpR = Op->getDefiningRecipe();
}
@@ -1963,7 +1963,6 @@ InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
VPValue *Op0, *Op1;
- using namespace llvm::VPlanPatternMatch;
if (!ScalarCond && ScalarTy->getScalarSizeInBits() == 1 &&
(match(this, m_LogicalAnd(m_VPValue(Op0), m_VPValue(Op1))) ||
match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1))))) {
@@ -2778,7 +2777,7 @@ VPExpressionRecipe::VPExpressionRecipe(
// Recipes in the expression, except the last one, must only be used by
// (other) recipes inside the expression. If there are other users, external
// to the expression, use a clone of the recipe for external users.
- for (VPSingleDefRecipe *R : ExpressionRecipes) {
+ for (VPSingleDefRecipe *R : reverse(ExpressionRecipes)) {
if (R != ExpressionRecipes.back() &&
any_of(R->users(), [&ExpressionRecipesAsSetOfUsers](VPUser *U) {
return !ExpressionRecipesAsSetOfUsers.contains(U);
@@ -3111,6 +3110,62 @@ bool VPReplicateRecipe::shouldPack() const {
});
}
+/// Returns true if \p Ptr is a pointer computation for which the legacy cost
+/// model computes a SCEV expression when computing the address cost.
+static bool shouldUseAddressAccessSCEV(const VPValue *Ptr) {
+ auto *PtrR = Ptr->getDefiningRecipe();
+ if (!PtrR || !((isa<VPReplicateRecipe>(PtrR) &&
+ cast<VPReplicateRecipe>(PtrR)->getOpcode() ==
+ Instruction::GetElementPtr) ||
+ isa<VPWidenGEPRecipe>(PtrR) ||
+ match(Ptr, m_GetElementPtr(m_VPValue(), m_VPValue()))))
+ return false;
+
+ // We are looking for a GEP where all indices are either loop invariant or
+ // inductions.
+ for (VPValue *Opd : drop_begin(PtrR->operands())) {
+ if (!Opd->isDefinedOutsideLoopRegions() &&
+ !isa<VPScalarIVStepsRecipe, VPWidenIntOrFpInductionRecipe>(Opd))
+ return false;
+ }
+
+ return true;
+}
+
+/// Returns true if \p V is used as part of the address of another load or
+/// store.
+static bool isUsedByLoadStoreAddress(const VPUser *V) {
+ SmallPtrSet<const VPUser *, 4> Seen;
+ SmallVector<const VPUser *> WorkList = {V};
+
+ while (!WorkList.empty()) {
+ auto *Cur = dyn_cast<VPSingleDefRecipe>(WorkList.pop_back_val());
+ if (!Cur || !Seen.insert(Cur).second)
+ continue;
+
+ for (VPUser *U : Cur->users()) {
+ if (auto *InterleaveR = dyn_cast<VPInterleaveBase>(U))
+ if (InterleaveR->getAddr() == Cur)
+ return true;
+ if (auto *RepR = dyn_cast<VPReplicateRecipe>(U)) {
+ if (RepR->getOpcode() == Instruction::Load &&
+ RepR->getOperand(0) == Cur)
+ return true;
+ if (RepR->getOpcode() == Instruction::Store &&
+ RepR->getOperand(1) == Cur)
+ return true;
+ }
+ if (auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U)) {
+ if (MemR->getAddr() == Cur && MemR->isConsecutive())
+ return true;
+ }
+ }
+
+ append_range(WorkList, cast<VPSingleDefRecipe>(Cur)->users());
+ }
+ return false;
+}
+
InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
Instruction *UI = cast<Instruction>(getUnderlyingValue());
@@ -3218,21 +3273,60 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
}
case Instruction::Load:
case Instruction::Store: {
- if (isSingleScalar()) {
- bool IsLoad = UI->getOpcode() == Instruction::Load;
- Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0));
- Type *ScalarPtrTy = Ctx.Types.inferScalarType(getOperand(IsLoad ? 0 : 1));
- const Align Alignment = getLoadStoreAlignment(UI);
- unsigned AS = getLoadStoreAddressSpace(UI);
- TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0));
- InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost(
- UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo, UI);
- return ScalarMemOpCost + Ctx.TTI.getAddressComputationCost(
- ScalarPtrTy, nullptr, nullptr, Ctx.CostKind);
- }
+ if (VF.isScalable() && !isSingleScalar())
+ return InstructionCost::getInvalid();
+
// TODO: See getMemInstScalarizationCost for how to handle replicating and
// predicated cases.
- break;
+ const VPRegionBlock *ParentRegion = getParent()->getParent();
+ if (ParentRegion && ParentRegion->isReplicator())
+ break;
+
+ bool IsLoad = UI->getOpcode() == Instruction::Load;
+ const VPValue *PtrOp = getOperand(!IsLoad);
+ // TODO: Handle cases where we need to pass a SCEV to
+ // getAddressComputationCost.
+ if (shouldUseAddressAccessSCEV(PtrOp))
+ break;
+
+ Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0));
+ Type *ScalarPtrTy = Ctx.Types.inferScalarType(PtrOp);
+ const Align Alignment = getLoadStoreAlignment(UI);
+ unsigned AS = getLoadStoreAddressSpace(UI);
+ TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0));
+ InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost(
+ UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo);
+
+ Type *PtrTy = isSingleScalar() ? ScalarPtrTy : toVectorTy(ScalarPtrTy, VF);
+ bool PreferVectorizedAddressing = Ctx.TTI.prefersVectorizedAddressing();
+ bool UsedByLoadStoreAddress =
+ !PreferVectorizedAddressing && isUsedByLoadStoreAddress(this);
+ InstructionCost ScalarCost =
+ ScalarMemOpCost + Ctx.TTI.getAddressComputationCost(
+ PtrTy, UsedByLoadStoreAddress ? nullptr : &Ctx.SE,
+ nullptr, Ctx.CostKind);
+ if (isSingleScalar())
+ return ScalarCost;
+
+ SmallVector<const VPValue *> OpsToScalarize;
+ Type *ResultTy = Type::getVoidTy(PtrTy->getContext());
+ // Set ResultTy and OpsToScalarize, if scalarization is needed. Currently we
+ // don't assign scalarization overhead in general, if the target prefers
+ // vectorized addressing or the loaded value is used as part of an address
+ // of another load or store.
+ if (!UsedByLoadStoreAddress) {
+ bool EfficientVectorLoadStore =
+ Ctx.TTI.supportsEfficientVectorElementLoadStore();
+ if (!(IsLoad && !PreferVectorizedAddressing) &&
+ !(!IsLoad && EfficientVectorLoadStore))
+ append_range(OpsToScalarize, operands());
+
+ if (!EfficientVectorLoadStore)
+ ResultTy = Ctx.Types.inferScalarType(this);
+ }
+
+ return (ScalarCost * VF.getFixedValue()) +
+ Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, true);
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ebf833e..c8a2d84 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3180,9 +3180,8 @@ expandVPWidenIntOrFpInduction(VPWidenIntOrFpInductionRecipe *WidenIVR,
DebugLoc::getUnknown(), "induction");
// Create the widened phi of the vector IV.
- auto *WidePHI = new VPWidenPHIRecipe(WidenIVR->getPHINode(), nullptr,
+ auto *WidePHI = new VPWidenPHIRecipe(WidenIVR->getPHINode(), Init,
WidenIVR->getDebugLoc(), "vec.ind");
- WidePHI->addOperand(Init);
WidePHI->insertBefore(WidenIVR);
// Create the backedge value for the vector IV.
@@ -3545,8 +3544,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPValue *A, *B;
VPValue *Tmp = nullptr;
// Sub reductions could have a sub between the add reduction and vec op.
- if (match(VecOp,
- m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Tmp)))) {
+ if (match(VecOp, m_Sub(m_ZeroInt(), m_VPValue(Tmp)))) {
Sub = VecOp->getDefiningRecipe();
VecOp = Tmp;
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 0599930..66748c5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -71,8 +71,8 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) {
m_Specific(&Plan.getVF()))) ||
IsWideCanonicalIV(A));
- return match(V, m_Binary<Instruction::ICmp>(m_VPValue(A), m_VPValue(B))) &&
- IsWideCanonicalIV(A) && B == Plan.getOrCreateBackedgeTakenCount();
+ return match(V, m_ICmp(m_VPValue(A), m_VPValue(B))) && IsWideCanonicalIV(A) &&
+ B == Plan.getOrCreateBackedgeTakenCount();
}
const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {