diff options
Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r-- | clang/lib/CodeGen/BackendUtil.cpp | 20 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGCall.cpp | 3 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGDebugInfo.cpp | 23 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGExpr.cpp | 210 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGExprCXX.cpp | 27 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGExprScalar.cpp | 5 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenFunction.cpp | 2 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenFunction.h | 10 | ||||
-rw-r--r-- | clang/lib/CodeGen/Targets/SPIR.cpp | 26 |
9 files changed, 304 insertions, 22 deletions
diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp index 64f1917..2d95982 100644 --- a/clang/lib/CodeGen/BackendUtil.cpp +++ b/clang/lib/CodeGen/BackendUtil.cpp @@ -60,11 +60,13 @@ #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/HipStdPar/HipStdPar.h" #include "llvm/Transforms/IPO/EmbedBitcodePass.h" +#include "llvm/Transforms/IPO/InferFunctionAttrs.h" #include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" +#include "llvm/Transforms/Instrumentation/AllocToken.h" #include "llvm/Transforms/Instrumentation/BoundsChecking.h" #include "llvm/Transforms/Instrumentation/DataFlowSanitizer.h" #include "llvm/Transforms/Instrumentation/GCOVProfiler.h" @@ -232,6 +234,14 @@ public: }; } // namespace +static AllocTokenOptions getAllocTokenOptions(const CodeGenOptions &CGOpts) { + AllocTokenOptions Opts; + Opts.MaxTokens = CGOpts.AllocTokenMax; + Opts.Extended = CGOpts.SanitizeAllocTokenExtended; + Opts.FastABI = CGOpts.SanitizeAllocTokenFastABI; + return Opts; +} + static SanitizerCoverageOptions getSancovOptsFromCGOpts(const CodeGenOptions &CGOpts) { SanitizerCoverageOptions Opts; @@ -789,6 +799,16 @@ static void addSanitizers(const Triple &TargetTriple, MPM.addPass(DataFlowSanitizerPass(LangOpts.NoSanitizeFiles, PB.getVirtualFileSystemPtr())); } + + if (LangOpts.Sanitize.has(SanitizerKind::AllocToken)) { + if (Level == OptimizationLevel::O0) { + // The default pass builder only infers libcall function attrs when + // optimizing, so we insert it here because we need it for accurate + // memory allocation function detection. + MPM.addPass(InferFunctionAttrsPass()); + } + MPM.addPass(AllocTokenPass(getAllocTokenOptions(CodeGenOpts))); + } }; if (ClSanitizeOnOptimizerEarlyEP) { PB.registerOptimizerEarlyEPCallback( diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp index a931ce4..c5371e4 100644 --- a/clang/lib/CodeGen/CGCall.cpp +++ b/clang/lib/CodeGen/CGCall.cpp @@ -3018,8 +3018,7 @@ void CodeGenModule::ConstructAttributeList(StringRef Name, ArgNo = 0; if (AddedPotentialArgAccess && MemAttrForPtrArgs) { - llvm::FunctionType *FunctionType = FunctionType = - getTypes().GetFunctionType(FI); + llvm::FunctionType *FunctionType = getTypes().GetFunctionType(FI); for (CGFunctionInfo::const_arg_iterator I = FI.arg_begin(), E = FI.arg_end(); I != E; ++I, ++ArgNo) { diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp index fee6bc0..9fe9a13 100644 --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -787,7 +787,8 @@ void CGDebugInfo::CreateCompileUnit() { // Create new compile unit. TheCU = DBuilder.createCompileUnit( - LangTag, CUFile, CGOpts.EmitVersionIdentMetadata ? Producer : "", + llvm::DISourceLanguageName(LangTag), CUFile, + CGOpts.EmitVersionIdentMetadata ? Producer : "", CGOpts.OptimizationLevel != 0 || CGOpts.PrepareForLTO || CGOpts.PrepareForThinLTO, CGOpts.DwarfDebugFlags, RuntimeVers, CGOpts.SplitDwarfFile, EmissionKind, @@ -899,10 +900,13 @@ llvm::DIType *CGDebugInfo::CreateType(const BuiltinType *BT) { assert((BT->getKind() != BuiltinType::SveCount || Info.NumVectors == 1) && "Unsupported number of vectors for svcount_t"); - // Debuggers can't extract 1bit from a vector, so will display a - // bitpattern for predicates instead. unsigned NumElems = Info.EC.getKnownMinValue() * Info.NumVectors; - if (Info.ElementType == CGM.getContext().BoolTy) { + llvm::Metadata *BitStride = nullptr; + if (BT->getKind() == BuiltinType::SveBool) { + Info.ElementType = CGM.getContext().UnsignedCharTy; + BitStride = llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned( + llvm::Type::getInt64Ty(CGM.getLLVMContext()), 1)); + } else if (BT->getKind() == BuiltinType::SveCount) { NumElems /= 8; Info.ElementType = CGM.getContext().UnsignedCharTy; } @@ -928,7 +932,7 @@ llvm::DIType *CGDebugInfo::CreateType(const BuiltinType *BT) { getOrCreateType(Info.ElementType, TheCU->getFile()); auto Align = getTypeAlignIfRequired(BT, CGM.getContext()); return DBuilder.createVectorType(/*Size*/ 0, Align, ElemTy, - SubscriptArray); + SubscriptArray, BitStride); } // It doesn't make sense to generate debug info for PowerPC MMA vector types. // So we return a safe type here to avoid generating an error. @@ -1232,7 +1236,7 @@ llvm::DIType *CGDebugInfo::CreateType(const PointerType *Ty, /// \return whether a C++ mangling exists for the type defined by TD. static bool hasCXXMangling(const TagDecl *TD, llvm::DICompileUnit *TheCU) { - switch (TheCU->getSourceLanguage()) { + switch (TheCU->getSourceLanguage().getUnversionedName()) { case llvm::dwarf::DW_LANG_C_plus_plus: case llvm::dwarf::DW_LANG_C_plus_plus_11: case llvm::dwarf::DW_LANG_C_plus_plus_14: @@ -3211,8 +3215,8 @@ llvm::DIType *CGDebugInfo::CreateType(const ObjCInterfaceType *Ty, if (!ID) return nullptr; - auto RuntimeLang = - static_cast<llvm::dwarf::SourceLanguage>(TheCU->getSourceLanguage()); + auto RuntimeLang = static_cast<llvm::dwarf::SourceLanguage>( + TheCU->getSourceLanguage().getUnversionedName()); // Return a forward declaration if this type was imported from a clang module, // and this is not the compile unit with the implementation of the type (which @@ -3348,7 +3352,8 @@ llvm::DIType *CGDebugInfo::CreateTypeDefinition(const ObjCInterfaceType *Ty, ObjCInterfaceDecl *ID = Ty->getDecl(); llvm::DIFile *DefUnit = getOrCreateFile(ID->getLocation()); unsigned Line = getLineNumber(ID->getLocation()); - unsigned RuntimeLang = TheCU->getSourceLanguage(); + + unsigned RuntimeLang = TheCU->getSourceLanguage().getUnversionedName(); // Bit size, align and offset of the type. uint64_t Size = CGM.getContext().getTypeSize(Ty); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 9f30287..e8255b0 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -30,6 +30,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclObjC.h" #include "clang/AST/NSAPI.h" +#include "clang/AST/ParentMapContext.h" #include "clang/AST/StmtVisitor.h" #include "clang/Basic/Builtins.h" #include "clang/Basic/CodeGenOptions.h" @@ -1272,6 +1273,196 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound, EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index); } +static bool +typeContainsPointer(QualType T, + llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD, + bool &IncompleteType) { + QualType CanonicalType = T.getCanonicalType(); + if (CanonicalType->isPointerType()) + return true; // base case + + // Look through typedef chain to check for special types. + for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>(); + CurrentT = TT->getDecl()->getUnderlyingType()) { + const IdentifierInfo *II = TT->getDecl()->getIdentifier(); + // Special Case: Syntactically uintptr_t is not a pointer; semantically, + // however, very likely used as such. Therefore, classify uintptr_t as a + // pointer, too. + if (II && II->isStr("uintptr_t")) + return true; + } + + // The type is an array; check the element type. + if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType)) + return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType); + // The type is a struct, class, or union. + if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) { + if (!RD->isCompleteDefinition()) { + IncompleteType = true; + return false; + } + if (!VisitedRD.insert(RD).second) + return false; // already visited + // Check all fields. + for (const FieldDecl *Field : RD->fields()) { + if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType)) + return true; + } + // For C++ classes, also check base classes. + if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) { + // Polymorphic types require a vptr. + if (CXXRD->isDynamicClass()) + return true; + for (const CXXBaseSpecifier &Base : CXXRD->bases()) { + if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType)) + return true; + } + } + } + return false; +} + +void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) { + assert(SanOpts.has(SanitizerKind::AllocToken) && + "Only needed with -fsanitize=alloc-token"); + + llvm::MDBuilder MDB(getLLVMContext()); + + // Get unique type name. + PrintingPolicy Policy(CGM.getContext().getLangOpts()); + Policy.SuppressTagKeyword = true; + Policy.FullyQualifiedName = true; + SmallString<64> TypeName; + llvm::raw_svector_ostream TypeNameOS(TypeName); + AllocType.getCanonicalType().print(TypeNameOS, Policy); + auto *TypeNameMD = MDB.createString(TypeNameOS.str()); + + // Check if QualType contains a pointer. Implements a simple DFS to + // recursively check if a type contains a pointer type. + llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD; + bool IncompleteType = false; + const bool ContainsPtr = + typeContainsPointer(AllocType, VisitedRD, IncompleteType); + if (!ContainsPtr && IncompleteType) + return; + auto *ContainsPtrC = Builder.getInt1(ContainsPtr); + auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC); + + // Format: !{<type-name>, <contains-pointer>} + auto *MDN = + llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD}); + CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN); +} + +namespace { +/// Infer type from a simple sizeof expression. +QualType inferTypeFromSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) { + if (UET->getKind() == UETT_SizeOf) { + if (UET->isArgumentType()) + return UET->getArgumentTypeInfo()->getType(); + else + return UET->getArgumentExpr()->getType(); + } + } + return QualType(); +} + +/// Infer type from an arithmetic expression involving a sizeof. For example: +/// +/// malloc(sizeof(MyType) + padding); // infers 'MyType' +/// malloc(sizeof(MyType) * 32); // infers 'MyType' +/// malloc(32 * sizeof(MyType)); // infers 'MyType' +/// malloc(sizeof(MyType) << 1); // infers 'MyType' +/// ... +/// +/// More complex arithmetic expressions are supported, but are a heuristic, e.g. +/// when considering allocations for structs with flexible array members: +/// +/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray' +/// +QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + // The argument is a lone sizeof expression. + if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull()) + return T; + if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) { + // Argument is an arithmetic expression. Cover common arithmetic patterns + // involving sizeof. + switch (BO->getOpcode()) { + case BO_Add: + case BO_Div: + case BO_Mul: + case BO_Shl: + case BO_Shr: + case BO_Sub: + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS()); + !T.isNull()) + return T; + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS()); + !T.isNull()) + return T; + break; + default: + break; + } + } + return QualType(); +} + +/// If the expression E is a reference to a variable, infer the type from a +/// variable's initializer if it contains a sizeof. Beware, this is a heuristic +/// and ignores if a variable is later reassigned. For example: +/// +/// size_t my_size = sizeof(MyType); +/// void *x = malloc(my_size); // infers 'MyType' +/// +QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) { + if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { + if (const Expr *Init = VD->getInit()) + return inferPossibleTypeFromArithSizeofExpr(Init); + } + } + return QualType(); +} + +/// Deduces the allocated type by checking if the allocation call's result +/// is immediately used in a cast expression. For example: +/// +/// MyType *x = (MyType *)malloc(4096); // infers 'MyType' +/// +QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE, + const CastExpr *CastE) { + if (!CastE) + return QualType(); + QualType PtrType = CastE->getType(); + if (PtrType->isPointerType()) + return PtrType->getPointeeType(); + return QualType(); +} +} // end anonymous namespace + +void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) { + QualType AllocType; + // First check arguments. + for (const Expr *Arg : E->arguments()) { + AllocType = inferPossibleTypeFromArithSizeofExpr(Arg); + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg); + if (!AllocType.isNull()) + break; + } + // Then check later casts. + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromCastExpr(E, CurCast); + // Emit if we were able to infer the type. + if (!AllocType.isNull()) + EmitAllocToken(CB, AllocType); +} + CodeGenFunction::ComplexPairTy CodeGenFunction:: EmitComplexPrePostIncDec(const UnaryOperator *E, LValue LV, bool isInc, bool isPre) { @@ -5642,6 +5833,9 @@ LValue CodeGenFunction::EmitConditionalOperatorLValue( /// are permitted with aggregate result, including noop aggregate casts, and /// cast from scalar to union. LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) { + auto RestoreCurCast = + llvm::make_scope_exit([this, Prev = CurCast] { CurCast = Prev; }); + CurCast = E; switch (E->getCastKind()) { case CK_ToVoid: case CK_BitCast: @@ -6587,16 +6781,24 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); - // Generate function declaration DISuprogram in order to be used - // in debug info about call sites. - if (CGDebugInfo *DI = getDebugInfo()) { - if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { + if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { + // Generate function declaration DISuprogram in order to be used + // in debug info about call sites. + if (CGDebugInfo *DI = getDebugInfo()) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } + if (CalleeDecl->hasAttr<RestrictAttr>() || + CalleeDecl->hasAttr<AllocSizeAttr>()) { + // Function has 'malloc' (aka. 'restrict') or 'alloc_size' attribute. + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(LocalCallOrInvoke, E); + } + } } if (CallOrInvoke) *CallOrInvoke = LocalCallOrInvoke; diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index c52526c..31ac266 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -1371,8 +1371,16 @@ RValue CodeGenFunction::EmitBuiltinNewDeleteCall(const FunctionProtoType *Type, for (auto *Decl : Ctx.getTranslationUnitDecl()->lookup(Name)) if (auto *FD = dyn_cast<FunctionDecl>(Decl)) - if (Ctx.hasSameType(FD->getType(), QualType(Type, 0))) - return EmitNewDeleteCall(*this, FD, Type, Args); + if (Ctx.hasSameType(FD->getType(), QualType(Type, 0))) { + RValue RV = EmitNewDeleteCall(*this, FD, Type, Args); + if (auto *CB = dyn_cast_if_present<llvm::CallBase>(RV.getScalarVal())) { + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(CB, TheCall); + } + } + return RV; + } llvm_unreachable("predeclared global operator new/delete is missing"); } @@ -1655,11 +1663,16 @@ llvm::Value *CodeGenFunction::EmitCXXNewExpr(const CXXNewExpr *E) { RValue RV = EmitNewDeleteCall(*this, allocator, allocatorType, allocatorArgs); - // Set !heapallocsite metadata on the call to operator new. - if (getDebugInfo()) - if (auto *newCall = dyn_cast<llvm::CallBase>(RV.getScalarVal())) - getDebugInfo()->addHeapAllocSiteMetadata(newCall, allocType, - E->getExprLoc()); + if (auto *newCall = dyn_cast<llvm::CallBase>(RV.getScalarVal())) { + if (auto *CGDI = getDebugInfo()) { + // Set !heapallocsite metadata on the call to operator new. + CGDI->addHeapAllocSiteMetadata(newCall, allocType, E->getExprLoc()); + } + if (SanOpts.has(SanitizerKind::AllocToken)) { + // Set !alloc_token metadata. + EmitAllocToken(newCall, allocType); + } + } // If this was a call to a global replaceable allocation function that does // not take an alignment argument, the allocator is known to produce diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 06d9d81..715160d 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -33,6 +33,7 @@ #include "clang/Basic/DiagnosticTrap.h" #include "clang/Basic/TargetInfo.h" #include "llvm/ADT/APFixedPoint.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/IR/Argument.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -2434,6 +2435,10 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal, // have to handle a more broad range of conversions than explicit casts, as they // handle things like function to ptr-to-function decay etc. Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { + auto RestoreCurCast = + llvm::make_scope_exit([this, Prev = CGF.CurCast] { CGF.CurCast = Prev; }); + CGF.CurCast = CE; + Expr *E = CE->getSubExpr(); QualType DestTy = CE->getType(); CastKind Kind = CE->getCastKind(); diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index b2fe917..acf8de4 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -846,6 +846,8 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy, Fn->addFnAttr(llvm::Attribute::SanitizeNumericalStability); if (SanOpts.hasOneOf(SanitizerKind::Memory | SanitizerKind::KernelMemory)) Fn->addFnAttr(llvm::Attribute::SanitizeMemory); + if (SanOpts.has(SanitizerKind::AllocToken)) + Fn->addFnAttr(llvm::Attribute::SanitizeAllocToken); } if (SanOpts.has(SanitizerKind::SafeStack)) Fn->addFnAttr(llvm::Attribute::SafeStack); diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 99de6e1..1f0be2d 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -346,6 +346,10 @@ public: QualType FnRetTy; llvm::Function *CurFn = nullptr; + /// If a cast expression is being visited, this holds the current cast's + /// expression. + const CastExpr *CurCast = nullptr; + /// Save Parameter Decl for coroutine. llvm::SmallVector<const ParmVarDecl *, 4> FnArgs; @@ -3348,6 +3352,12 @@ public: SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals, SanitizerHandler Handler); + /// Emit additional metadata used by the AllocToken instrumentation. + void EmitAllocToken(llvm::CallBase *CB, QualType AllocType); + /// Emit additional metadata used by the AllocToken instrumentation, + /// inferring the type from an allocation call expression. + void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E); + llvm::Value *GetCountedByFieldExprGEP(const Expr *Base, const FieldDecl *FD, const FieldDecl *CountDecl); diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp index 4aa6314..3f6d4e0 100644 --- a/clang/lib/CodeGen/Targets/SPIR.cpp +++ b/clang/lib/CodeGen/Targets/SPIR.cpp @@ -61,6 +61,9 @@ public: QualType SampledType, CodeGenModule &CGM) const; void setOCLKernelStubCallingConvention(const FunctionType *&FT) const override; + llvm::Constant *getNullPointer(const CodeGen::CodeGenModule &CGM, + llvm::PointerType *T, + QualType QT) const override; }; class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo { public: @@ -240,6 +243,29 @@ void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention( FT, FT->getExtInfo().withCallingConv(CC_SpirFunction)); } +// LLVM currently assumes a null pointer has the bit pattern 0, but some GPU +// targets use a non-zero encoding for null in certain address spaces. +// Because SPIR(-V) is a generic target and the bit pattern of null in +// non-generic AS is unspecified, materialize null in non-generic AS via an +// addrspacecast from null in generic AS. This allows later lowering to +// substitute the target's real sentinel value. +llvm::Constant * +CommonSPIRTargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM, + llvm::PointerType *PT, + QualType QT) const { + LangAS AS = QT->getUnqualifiedDesugaredType()->isNullPtrType() + ? LangAS::Default + : QT->getPointeeType().getAddressSpace(); + if (AS == LangAS::Default || AS == LangAS::opencl_generic) + return llvm::ConstantPointerNull::get(PT); + + auto &Ctx = CGM.getContext(); + auto NPT = llvm::PointerType::get( + PT->getContext(), Ctx.getTargetAddressSpace(LangAS::opencl_generic)); + return llvm::ConstantExpr::getAddrSpaceCast( + llvm::ConstantPointerNull::get(NPT), PT); +} + LangAS SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM, const VarDecl *D) const { |