aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Sema/SemaHLSL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
-rw-r--r--clang/lib/Sema/SemaHLSL.cpp292
1 files changed, 167 insertions, 125 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9276554..8536e04 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -39,6 +39,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Frontend/HLSL/HLSLBinding.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DXILABI.h"
@@ -596,8 +597,9 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
// create buffer layout struct
createHostLayoutStructForBuffer(SemaRef, BufDecl);
+ HLSLVkBindingAttr *VkBinding = Dcl->getAttr<HLSLVkBindingAttr>();
HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
- if (!RBA || !RBA->hasRegisterSlot()) {
+ if (!VkBinding && (!RBA || !RBA->hasRegisterSlot())) {
SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
// Use HLSLResourceBindingAttr to transfer implicit binding order_ID
// to codegen. If it does not exist, create an implicit attribute.
@@ -1083,6 +1085,102 @@ void SemaHLSL::ActOnFinishRootSignatureDecl(
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
}
+namespace {
+
+struct PerVisibilityBindingChecker {
+ SemaHLSL *S;
+ // We need one builder per `llvm::dxbc::ShaderVisibility` value.
+ std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
+
+ struct ElemInfo {
+ const hlsl::RootSignatureElement *Elem;
+ llvm::dxbc::ShaderVisibility Vis;
+ bool Diagnosed;
+ };
+ llvm::SmallVector<ElemInfo> ElemInfoMap;
+
+ PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
+
+ void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
+ llvm::dxil::ResourceClass RC, uint32_t Space,
+ uint32_t LowerBound, uint32_t UpperBound,
+ const hlsl::RootSignatureElement *Elem) {
+ uint32_t BuilderIndex = llvm::to_underlying(Visibility);
+ assert(BuilderIndex < Builders.size() &&
+ "Not enough builders for visibility type");
+ Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
+ static_cast<const void *>(Elem));
+
+ static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
+ "'All' visibility must come first");
+ if (Visibility == llvm::dxbc::ShaderVisibility::All)
+ for (size_t I = 1, E = Builders.size(); I < E; ++I)
+ Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
+ static_cast<const void *>(Elem));
+
+ ElemInfoMap.push_back({Elem, Visibility, false});
+ }
+
+ ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
+ auto It = llvm::lower_bound(
+ ElemInfoMap, Elem,
+ [](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
+ assert(It->Elem == Elem && "Element not in map");
+ return *It;
+ }
+
+ bool checkOverlap() {
+ llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
+ return LHS.Elem < RHS.Elem;
+ });
+
+ bool HadOverlap = false;
+
+ using llvm::hlsl::BindingInfoBuilder;
+ auto ReportOverlap = [this, &HadOverlap](
+ const BindingInfoBuilder &Builder,
+ const BindingInfoBuilder::Binding &Reported) {
+ HadOverlap = true;
+
+ const auto *Elem =
+ static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
+ const BindingInfoBuilder::Binding &Previous =
+ Builder.findOverlapping(Reported);
+ const auto *PrevElem =
+ static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
+
+ ElemInfo &Info = getInfo(Elem);
+ // We will have already diagnosed this binding if there's overlap in the
+ // "All" visibility as well as any particular visibility.
+ if (Info.Diagnosed)
+ return;
+ Info.Diagnosed = true;
+
+ ElemInfo &PrevInfo = getInfo(PrevElem);
+ llvm::dxbc::ShaderVisibility CommonVis =
+ Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
+ : Info.Vis;
+
+ this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
+ << llvm::to_underlying(Reported.RC) << Reported.LowerBound
+ << Reported.isUnbounded() << Reported.UpperBound
+ << llvm::to_underlying(Previous.RC) << Previous.LowerBound
+ << Previous.isUnbounded() << Previous.UpperBound << Reported.Space
+ << CommonVis;
+
+ this->S->Diag(PrevElem->getLocation(),
+ diag::note_hlsl_resource_range_here);
+ };
+
+ for (BindingInfoBuilder &Builder : Builders)
+ Builder.calculateBindingInfo(ReportOverlap);
+
+ return HadOverlap;
+ }
+};
+
+} // end anonymous namespace
+
bool SemaHLSL::handleRootSignatureElements(
ArrayRef<hlsl::RootSignatureElement> Elements) {
// Define some common error handling functions
@@ -1171,147 +1269,67 @@ bool SemaHLSL::handleRootSignatureElements(
}
}
- using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
- using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
- using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
+ PerVisibilityBindingChecker BindingChecker(this);
+ SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
+ const hlsl::RootSignatureElement *>>
+ UnboundClauses;
- // 1. Collect RangeInfos
- llvm::SmallVector<InfoPairT> InfoPairs;
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
if (const auto *Descriptor =
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
- RangeInfo Info;
- Info.LowerBound = Descriptor->Reg.Number;
- Info.UpperBound = Info.LowerBound; // use inclusive ranges []
+ uint32_t LowerBound(Descriptor->Reg.Number);
+ uint32_t UpperBound(LowerBound); // inclusive range
- Info.Class =
- llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
- Info.Space = Descriptor->Space;
- Info.Visibility = Descriptor->Visibility;
-
- InfoPairs.push_back({Info, &RootSigElem});
+ BindingChecker.trackBinding(
+ Descriptor->Visibility,
+ static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
+ Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
- RangeInfo Info;
- Info.LowerBound = Constants->Reg.Number;
- Info.UpperBound = Info.LowerBound; // use inclusive ranges []
-
- Info.Class = llvm::dxil::ResourceClass::CBuffer;
- Info.Space = Constants->Space;
- Info.Visibility = Constants->Visibility;
+ uint32_t LowerBound(Constants->Reg.Number);
+ uint32_t UpperBound(LowerBound); // inclusive range
- InfoPairs.push_back({Info, &RootSigElem});
+ BindingChecker.trackBinding(
+ Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
+ Constants->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Sampler =
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
- RangeInfo Info;
- Info.LowerBound = Sampler->Reg.Number;
- Info.UpperBound = Info.LowerBound; // use inclusive ranges []
+ uint32_t LowerBound(Sampler->Reg.Number);
+ uint32_t UpperBound(LowerBound); // inclusive range
- Info.Class = llvm::dxil::ResourceClass::Sampler;
- Info.Space = Sampler->Space;
- Info.Visibility = Sampler->Visibility;
-
- InfoPairs.push_back({Info, &RootSigElem});
+ BindingChecker.trackBinding(
+ Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
+ Sampler->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Clause =
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
&Elem)) {
- RangeInfo Info;
- Info.LowerBound = Clause->Reg.Number;
- // Relevant error will have already been reported above and needs to be
- // fixed before we can conduct range analysis, so shortcut error return
- if (Clause->NumDescriptors == 0)
- return true;
- Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
- ? RangeInfo::Unbounded
- : Info.LowerBound + Clause->NumDescriptors -
- 1; // use inclusive ranges []
-
- Info.Class = Clause->Type;
- Info.Space = Clause->Space;
-
- // Note: Clause does not hold the visibility this will need to
- InfoPairs.push_back({Info, &RootSigElem});
+ // We'll process these once we see the table element.
+ UnboundClauses.emplace_back(Clause, &RootSigElem);
} else if (const auto *Table =
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
- // Table holds the Visibility of all owned Clauses in Table, so iterate
- // owned Clauses and update their corresponding RangeInfo
- assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
- // The last Table->NumClauses elements of Infos are the owned Clauses
- // generated RangeInfo
- auto TableInfos =
- MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
- for (InfoPairT &Pair : TableInfos)
- Pair.first.Visibility = Table->Visibility;
- }
- }
-
- // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
- llvm::sort(InfoPairs,
- [](InfoPairT A, InfoPairT B) { return A.first < B.first; });
-
- llvm::SmallVector<RangeInfo> Infos;
- for (const InfoPairT &Pair : InfoPairs)
- Infos.push_back(Pair.first);
-
- // Helpers to report diagnostics
- uint32_t DuplicateCounter = 0;
- using ElemPair = std::pair<const hlsl::RootSignatureElement *,
- const hlsl::RootSignatureElement *>;
- auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
- OverlappingRanges Overlap) -> ElemPair {
- // Given we sorted the InfoPairs (and by implication) Infos, and,
- // that Overlap.B is the item retrieved from the ResourceRange. Then it is
- // guarenteed that Overlap.B <= Overlap.A.
- //
- // So we will find Overlap.B first and then continue to find Overlap.A
- // after
- auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
- auto DistB = std::distance(Infos.begin(), InfoB);
- auto PairB = InfoPairs.begin();
- std::advance(PairB, DistB);
-
- auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
- // Similarily, from the property that we have sorted the RangeInfos,
- // all duplicates will be processed one after the other. So
- // DuplicateCounter can be re-used for each set of duplicates we
- // encounter as we handle incoming errors
- DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
- auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
- auto PairA = PairB;
- std::advance(PairA, DistA);
-
- return {PairA->second, PairB->second};
- };
-
- auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
- auto Pair = GetElemPair(Overlap);
- const RangeInfo *Info = Overlap.A;
- const hlsl::RootSignatureElement *Elem = Pair.first;
- const RangeInfo *OInfo = Overlap.B;
-
- auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
- ? OInfo->Visibility
- : Info->Visibility;
- this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
- << llvm::to_underlying(Info->Class) << Info->LowerBound
- << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
- << Info->UpperBound << llvm::to_underlying(OInfo->Class)
- << OInfo->LowerBound
- << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
- << OInfo->UpperBound << Info->Space << CommonVis;
-
- const hlsl::RootSignatureElement *OElem = Pair.second;
- this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
- };
-
- // 3. Invoke find overlapping ranges
- llvm::SmallVector<OverlappingRanges> Overlaps =
- llvm::hlsl::rootsig::findOverlappingRanges(Infos);
- for (OverlappingRanges Overlap : Overlaps)
- ReportOverlap(Overlap);
+ assert(UnboundClauses.size() == Table->NumClauses &&
+ "Number of unbound elements must match the number of clauses");
+ for (const auto &[Clause, ClauseElem] : UnboundClauses) {
+ uint32_t LowerBound(Clause->Reg.Number);
+ // Relevant error will have already been reported above and needs to be
+ // fixed before we can conduct range analysis, so shortcut error return
+ if (Clause->NumDescriptors == 0)
+ return true;
+ uint32_t UpperBound = Clause->NumDescriptors == ~0u
+ ? ~0u
+ : LowerBound + Clause->NumDescriptors - 1;
+
+ BindingChecker.trackBinding(
+ Table->Visibility,
+ static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
+ LowerBound, UpperBound, ClauseElem);
+ }
+ UnboundClauses.clear();
+ }
+ }
- return Overlaps.size() != 0;
+ return BindingChecker.checkOverlap();
}
void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
@@ -1479,6 +1497,23 @@ void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
+void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
+ // The vk::binding attribute only applies to SPIR-V.
+ if (!getASTContext().getTargetInfo().getTriple().isSPIRV())
+ return;
+
+ uint32_t Binding = 0;
+ if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Binding))
+ return;
+ uint32_t Set = 0;
+ if (AL.getNumArgs() > 1 &&
+ !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Set))
+ return;
+
+ D->addAttr(::new (getASTContext())
+ HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
+}
+
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
const auto *VT = T->getAs<VectorType>();
@@ -3643,8 +3678,12 @@ static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
std::optional<uint32_t> RegisterSlot;
uint32_t SpaceNo = 0;
+ HLSLVkBindingAttr *VkBinding = VD->getAttr<HLSLVkBindingAttr>();
HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
- if (RBA) {
+ if (VkBinding) {
+ RegisterSlot = VkBinding->getBinding();
+ SpaceNo = VkBinding->getSet();
+ } else if (RBA) {
if (RBA->hasRegisterSlot())
RegisterSlot = RBA->getSlotNumber();
SpaceNo = RBA->getSpaceNumber();
@@ -3747,6 +3786,9 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
bool HasBinding = false;
for (Attr *A : VD->attrs()) {
+ if (isa<HLSLVkBindingAttr>(A))
+ HasBinding = true;
+
HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
if (!RBA || !RBA->hasRegisterSlot())
continue;