diff options
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
-rw-r--r-- | clang/lib/Sema/SemaHLSL.cpp | 292 |
1 files changed, 167 insertions, 125 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 9276554..8536e04 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -39,6 +39,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" +#include "llvm/Frontend/HLSL/HLSLBinding.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DXILABI.h" @@ -596,8 +597,9 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { // create buffer layout struct createHostLayoutStructForBuffer(SemaRef, BufDecl); + HLSLVkBindingAttr *VkBinding = Dcl->getAttr<HLSLVkBindingAttr>(); HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>(); - if (!RBA || !RBA->hasRegisterSlot()) { + if (!VkBinding && (!RBA || !RBA->hasRegisterSlot())) { SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding); // Use HLSLResourceBindingAttr to transfer implicit binding order_ID // to codegen. If it does not exist, create an implicit attribute. @@ -1083,6 +1085,102 @@ void SemaHLSL::ActOnFinishRootSignatureDecl( SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope()); } +namespace { + +struct PerVisibilityBindingChecker { + SemaHLSL *S; + // We need one builder per `llvm::dxbc::ShaderVisibility` value. + std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders; + + struct ElemInfo { + const hlsl::RootSignatureElement *Elem; + llvm::dxbc::ShaderVisibility Vis; + bool Diagnosed; + }; + llvm::SmallVector<ElemInfo> ElemInfoMap; + + PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {} + + void trackBinding(llvm::dxbc::ShaderVisibility Visibility, + llvm::dxil::ResourceClass RC, uint32_t Space, + uint32_t LowerBound, uint32_t UpperBound, + const hlsl::RootSignatureElement *Elem) { + uint32_t BuilderIndex = llvm::to_underlying(Visibility); + assert(BuilderIndex < Builders.size() && + "Not enough builders for visibility type"); + Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound, + static_cast<const void *>(Elem)); + + static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0, + "'All' visibility must come first"); + if (Visibility == llvm::dxbc::ShaderVisibility::All) + for (size_t I = 1, E = Builders.size(); I < E; ++I) + Builders[I].trackBinding(RC, Space, LowerBound, UpperBound, + static_cast<const void *>(Elem)); + + ElemInfoMap.push_back({Elem, Visibility, false}); + } + + ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) { + auto It = llvm::lower_bound( + ElemInfoMap, Elem, + [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; }); + assert(It->Elem == Elem && "Element not in map"); + return *It; + } + + bool checkOverlap() { + llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) { + return LHS.Elem < RHS.Elem; + }); + + bool HadOverlap = false; + + using llvm::hlsl::BindingInfoBuilder; + auto ReportOverlap = [this, &HadOverlap]( + const BindingInfoBuilder &Builder, + const BindingInfoBuilder::Binding &Reported) { + HadOverlap = true; + + const auto *Elem = + static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie); + const BindingInfoBuilder::Binding &Previous = + Builder.findOverlapping(Reported); + const auto *PrevElem = + static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie); + + ElemInfo &Info = getInfo(Elem); + // We will have already diagnosed this binding if there's overlap in the + // "All" visibility as well as any particular visibility. + if (Info.Diagnosed) + return; + Info.Diagnosed = true; + + ElemInfo &PrevInfo = getInfo(PrevElem); + llvm::dxbc::ShaderVisibility CommonVis = + Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis + : Info.Vis; + + this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap) + << llvm::to_underlying(Reported.RC) << Reported.LowerBound + << Reported.isUnbounded() << Reported.UpperBound + << llvm::to_underlying(Previous.RC) << Previous.LowerBound + << Previous.isUnbounded() << Previous.UpperBound << Reported.Space + << CommonVis; + + this->S->Diag(PrevElem->getLocation(), + diag::note_hlsl_resource_range_here); + }; + + for (BindingInfoBuilder &Builder : Builders) + Builder.calculateBindingInfo(ReportOverlap); + + return HadOverlap; + } +}; + +} // end anonymous namespace + bool SemaHLSL::handleRootSignatureElements( ArrayRef<hlsl::RootSignatureElement> Elements) { // Define some common error handling functions @@ -1171,147 +1269,67 @@ bool SemaHLSL::handleRootSignatureElements( } } - using RangeInfo = llvm::hlsl::rootsig::RangeInfo; - using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges; - using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>; + PerVisibilityBindingChecker BindingChecker(this); + SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *, + const hlsl::RootSignatureElement *>> + UnboundClauses; - // 1. Collect RangeInfos - llvm::SmallVector<InfoPairT> InfoPairs; for (const hlsl::RootSignatureElement &RootSigElem : Elements) { const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); if (const auto *Descriptor = std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { - RangeInfo Info; - Info.LowerBound = Descriptor->Reg.Number; - Info.UpperBound = Info.LowerBound; // use inclusive ranges [] + uint32_t LowerBound(Descriptor->Reg.Number); + uint32_t UpperBound(LowerBound); // inclusive range - Info.Class = - llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type)); - Info.Space = Descriptor->Space; - Info.Visibility = Descriptor->Visibility; - - InfoPairs.push_back({Info, &RootSigElem}); + BindingChecker.trackBinding( + Descriptor->Visibility, + static_cast<llvm::dxil::ResourceClass>(Descriptor->Type), + Descriptor->Space, LowerBound, UpperBound, &RootSigElem); } else if (const auto *Constants = std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { - RangeInfo Info; - Info.LowerBound = Constants->Reg.Number; - Info.UpperBound = Info.LowerBound; // use inclusive ranges [] - - Info.Class = llvm::dxil::ResourceClass::CBuffer; - Info.Space = Constants->Space; - Info.Visibility = Constants->Visibility; + uint32_t LowerBound(Constants->Reg.Number); + uint32_t UpperBound(LowerBound); // inclusive range - InfoPairs.push_back({Info, &RootSigElem}); + BindingChecker.trackBinding( + Constants->Visibility, llvm::dxil::ResourceClass::CBuffer, + Constants->Space, LowerBound, UpperBound, &RootSigElem); } else if (const auto *Sampler = std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { - RangeInfo Info; - Info.LowerBound = Sampler->Reg.Number; - Info.UpperBound = Info.LowerBound; // use inclusive ranges [] + uint32_t LowerBound(Sampler->Reg.Number); + uint32_t UpperBound(LowerBound); // inclusive range - Info.Class = llvm::dxil::ResourceClass::Sampler; - Info.Space = Sampler->Space; - Info.Visibility = Sampler->Visibility; - - InfoPairs.push_back({Info, &RootSigElem}); + BindingChecker.trackBinding( + Sampler->Visibility, llvm::dxil::ResourceClass::Sampler, + Sampler->Space, LowerBound, UpperBound, &RootSigElem); } else if (const auto *Clause = std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( &Elem)) { - RangeInfo Info; - Info.LowerBound = Clause->Reg.Number; - // Relevant error will have already been reported above and needs to be - // fixed before we can conduct range analysis, so shortcut error return - if (Clause->NumDescriptors == 0) - return true; - Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded - ? RangeInfo::Unbounded - : Info.LowerBound + Clause->NumDescriptors - - 1; // use inclusive ranges [] - - Info.Class = Clause->Type; - Info.Space = Clause->Space; - - // Note: Clause does not hold the visibility this will need to - InfoPairs.push_back({Info, &RootSigElem}); + // We'll process these once we see the table element. + UnboundClauses.emplace_back(Clause, &RootSigElem); } else if (const auto *Table = std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) { - // Table holds the Visibility of all owned Clauses in Table, so iterate - // owned Clauses and update their corresponding RangeInfo - assert(Table->NumClauses <= InfoPairs.size() && "RootElement"); - // The last Table->NumClauses elements of Infos are the owned Clauses - // generated RangeInfo - auto TableInfos = - MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses); - for (InfoPairT &Pair : TableInfos) - Pair.first.Visibility = Table->Visibility; - } - } - - // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping - llvm::sort(InfoPairs, - [](InfoPairT A, InfoPairT B) { return A.first < B.first; }); - - llvm::SmallVector<RangeInfo> Infos; - for (const InfoPairT &Pair : InfoPairs) - Infos.push_back(Pair.first); - - // Helpers to report diagnostics - uint32_t DuplicateCounter = 0; - using ElemPair = std::pair<const hlsl::RootSignatureElement *, - const hlsl::RootSignatureElement *>; - auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter]( - OverlappingRanges Overlap) -> ElemPair { - // Given we sorted the InfoPairs (and by implication) Infos, and, - // that Overlap.B is the item retrieved from the ResourceRange. Then it is - // guarenteed that Overlap.B <= Overlap.A. - // - // So we will find Overlap.B first and then continue to find Overlap.A - // after - auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B); - auto DistB = std::distance(Infos.begin(), InfoB); - auto PairB = InfoPairs.begin(); - std::advance(PairB, DistB); - - auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A); - // Similarily, from the property that we have sorted the RangeInfos, - // all duplicates will be processed one after the other. So - // DuplicateCounter can be re-used for each set of duplicates we - // encounter as we handle incoming errors - DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0; - auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter; - auto PairA = PairB; - std::advance(PairA, DistA); - - return {PairA->second, PairB->second}; - }; - - auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) { - auto Pair = GetElemPair(Overlap); - const RangeInfo *Info = Overlap.A; - const hlsl::RootSignatureElement *Elem = Pair.first; - const RangeInfo *OInfo = Overlap.B; - - auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All - ? OInfo->Visibility - : Info->Visibility; - this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap) - << llvm::to_underlying(Info->Class) << Info->LowerBound - << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded) - << Info->UpperBound << llvm::to_underlying(OInfo->Class) - << OInfo->LowerBound - << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded) - << OInfo->UpperBound << Info->Space << CommonVis; - - const hlsl::RootSignatureElement *OElem = Pair.second; - this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here); - }; - - // 3. Invoke find overlapping ranges - llvm::SmallVector<OverlappingRanges> Overlaps = - llvm::hlsl::rootsig::findOverlappingRanges(Infos); - for (OverlappingRanges Overlap : Overlaps) - ReportOverlap(Overlap); + assert(UnboundClauses.size() == Table->NumClauses && + "Number of unbound elements must match the number of clauses"); + for (const auto &[Clause, ClauseElem] : UnboundClauses) { + uint32_t LowerBound(Clause->Reg.Number); + // Relevant error will have already been reported above and needs to be + // fixed before we can conduct range analysis, so shortcut error return + if (Clause->NumDescriptors == 0) + return true; + uint32_t UpperBound = Clause->NumDescriptors == ~0u + ? ~0u + : LowerBound + Clause->NumDescriptors - 1; + + BindingChecker.trackBinding( + Table->Visibility, + static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space, + LowerBound, UpperBound, ClauseElem); + } + UnboundClauses.clear(); + } + } - return Overlaps.size() != 0; + return BindingChecker.checkOverlap(); } void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) { @@ -1479,6 +1497,23 @@ void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } +void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) { + // The vk::binding attribute only applies to SPIR-V. + if (!getASTContext().getTargetInfo().getTriple().isSPIRV()) + return; + + uint32_t Binding = 0; + if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding)) + return; + uint32_t Set = 0; + if (AL.getNumArgs() > 1 && + !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set)) + return; + + D->addAttr(::new (getASTContext()) + HLSLVkBindingAttr(getASTContext(), AL, Binding, Set)); +} + bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) { const auto *VT = T->getAs<VectorType>(); @@ -3643,8 +3678,12 @@ static bool initVarDeclWithCtor(Sema &S, VarDecl *VD, bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { std::optional<uint32_t> RegisterSlot; uint32_t SpaceNo = 0; + HLSLVkBindingAttr *VkBinding = VD->getAttr<HLSLVkBindingAttr>(); HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>(); - if (RBA) { + if (VkBinding) { + RegisterSlot = VkBinding->getBinding(); + SpaceNo = VkBinding->getSet(); + } else if (RBA) { if (RBA->hasRegisterSlot()) RegisterSlot = RBA->getSlotNumber(); SpaceNo = RBA->getSpaceNumber(); @@ -3747,6 +3786,9 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) { bool HasBinding = false; for (Attr *A : VD->attrs()) { + if (isa<HLSLVkBindingAttr>(A)) + HasBinding = true; + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); if (!RBA || !RBA->hasRegisterSlot()) continue; |