diff options
Diffstat (limited to 'llvm/lib/Frontend/HLSL')
-rw-r--r-- | llvm/lib/Frontend/HLSL/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/HLSLBinding.cpp | 142 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp | 55 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 546 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp | 136 |
5 files changed, 451 insertions, 429 deletions
diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 5343469..3d22577 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp + HLSLBinding.cpp HLSLResource.cpp HLSLRootSignature.cpp RootSignatureMetadata.cpp diff --git a/llvm/lib/Frontend/HLSL/HLSLBinding.cpp b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp new file mode 100644 index 0000000..401402f --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp @@ -0,0 +1,142 @@ +//===- HLSLBinding.cpp - Representation for resource bindings in HLSL -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLBinding.h" +#include "llvm/ADT/STLExtras.h" + +using namespace llvm; +using namespace hlsl; + +std::optional<uint32_t> +BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, + int32_t Size) { + BindingSpaces &BS = getBindingSpaces(RC); + RegisterSpace &RS = BS.getOrInsertSpace(Space); + return RS.findAvailableBinding(Size); +} + +BindingInfo::RegisterSpace & +BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { + for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) { + if (It->Space == Space) + return *It; + if (It->Space < Space) + continue; + return *Spaces.insert(It, Space); + } + return Spaces.emplace_back(Space); +} + +std::optional<uint32_t> +BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { + assert((Size == -1 || Size > 0) && "invalid size"); + + if (FreeRanges.empty()) + return std::nullopt; + + // unbounded array + if (Size == -1) { + BindingRange &Last = FreeRanges.back(); + if (Last.UpperBound != ~0u) + // this space is already occupied by an unbounded array + return std::nullopt; + uint32_t RegSlot = Last.LowerBound; + FreeRanges.pop_back(); + return RegSlot; + } + + // single resource or fixed-size array + for (BindingRange &R : FreeRanges) { + // compare the size as uint64_t to prevent overflow for range (0, ~0u) + if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) + continue; + uint32_t RegSlot = R.LowerBound; + // This might create a range where (LowerBound == UpperBound + 1). When + // that happens, the next time this function is called the range will + // skipped over by the check above (at this point Size is always > 0). + R.LowerBound += Size; + return RegSlot; + } + + return std::nullopt; +} + +BindingInfo BindingInfoBuilder::calculateBindingInfo( + llvm::function_ref<void(const BindingInfoBuilder &Builder, + const Binding &Overlapping)> + ReportOverlap) { + // sort all the collected bindings + llvm::stable_sort(Bindings); + + // remove duplicates + Binding *NewEnd = llvm::unique(Bindings); + if (NewEnd != Bindings.end()) + Bindings.erase(NewEnd, Bindings.end()); + + BindingInfo Info; + + // Go over the sorted bindings and build up lists of free register ranges + // for each binding type and used spaces. Bindings are sorted by resource + // class, space, and lower bound register slot. + BindingInfo::BindingSpaces *BS = + &Info.getBindingSpaces(dxil::ResourceClass::SRV); + for (const Binding &B : Bindings) { + if (BS->RC != B.RC) + // move to the next resource class spaces + BS = &Info.getBindingSpaces(B.RC); + + BindingInfo::RegisterSpace *S = BS->Spaces.empty() + ? &BS->Spaces.emplace_back(B.Space) + : &BS->Spaces.back(); + assert(S->Space <= B.Space && "bindings not sorted correctly?"); + if (B.Space != S->Space) + // add new space + S = &BS->Spaces.emplace_back(B.Space); + + // The space is full - there are no free slots left, or the rest of the + // slots are taken by an unbounded array. Report the overlapping to the + // caller. + if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) { + ReportOverlap(*this, B); + continue; + } + // adjust the last free range lower bound, split it in two, or remove it + BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back(); + if (LastFreeRange.LowerBound == B.LowerBound) { + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = B.UpperBound + 1; + else + S->FreeRanges.pop_back(); + } else if (LastFreeRange.LowerBound < B.LowerBound) { + LastFreeRange.UpperBound = B.LowerBound - 1; + if (B.UpperBound < ~0u) + S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u); + } else { + // We don't have room here. Report the overlapping binding to the caller + // and mark any extra space this binding would use as unavailable. + ReportOverlap(*this, B); + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = + std::max(LastFreeRange.LowerBound, B.UpperBound + 1); + else + S->FreeRanges.pop_back(); + } + } + + return Info; +} + +const Binding & +BindingInfoBuilder::findOverlapping(const Binding &ReportedBinding) const { + for (const Binding &Other : Bindings) + if (ReportedBinding.LowerBound <= Other.UpperBound && + Other.LowerBound <= ReportedBinding.UpperBound) + return Other; + + llvm_unreachable("Searching for overlap for binding that does not overlap"); +} diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp index 78c20a6..92c62b8 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "llvm/Support/DXILABI.h" #include "llvm/Support/ScopedPrinter.h" namespace llvm { @@ -18,24 +19,6 @@ namespace hlsl { namespace rootsig { template <typename T> -static std::optional<StringRef> getEnumName(const T Value, - ArrayRef<EnumEntry<T>> Enums) { - for (const auto &EnumItem : Enums) - if (EnumItem.Value == Value) - return EnumItem.Name; - return std::nullopt; -} - -template <typename T> -static raw_ostream &printEnum(raw_ostream &OS, const T Value, - ArrayRef<EnumEntry<T>> Enums) { - auto MaybeName = getEnumName(Value, Enums); - if (MaybeName) - OS << *MaybeName; - return OS; -} - -template <typename T> static raw_ostream &printFlags(raw_ostream &OS, const T Value, ArrayRef<EnumEntry<T>> Flags) { bool FlagSet = false; @@ -46,9 +29,9 @@ static raw_ostream &printFlags(raw_ostream &OS, const T Value, if (FlagSet) OS << " | "; - auto MaybeFlag = getEnumName(T(Bit), Flags); - if (MaybeFlag) - OS << *MaybeFlag; + StringRef MaybeFlag = enumToStringRef(T(Bit), Flags); + if (!MaybeFlag.empty()) + OS << MaybeFlag; else OS << "invalid: " << Bit; @@ -70,58 +53,49 @@ static const EnumEntry<RegisterType> RegisterNames[] = { }; static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { - printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); - OS << Reg.Number; + OS << enumToStringRef(Reg.ViewType, ArrayRef(RegisterNames)) << Reg.Number; return OS; } static raw_ostream &operator<<(raw_ostream &OS, const llvm::dxbc::ShaderVisibility &Visibility) { - printEnum(OS, Visibility, dxbc::getShaderVisibility()); + OS << enumToStringRef(Visibility, dxbc::getShaderVisibility()); return OS; } static raw_ostream &operator<<(raw_ostream &OS, const llvm::dxbc::SamplerFilter &Filter) { - printEnum(OS, Filter, dxbc::getSamplerFilters()); + OS << enumToStringRef(Filter, dxbc::getSamplerFilters()); return OS; } static raw_ostream &operator<<(raw_ostream &OS, const dxbc::TextureAddressMode &Address) { - printEnum(OS, Address, dxbc::getTextureAddressModes()); + OS << enumToStringRef(Address, dxbc::getTextureAddressModes()); return OS; } static raw_ostream &operator<<(raw_ostream &OS, const dxbc::ComparisonFunc &CompFunc) { - printEnum(OS, CompFunc, dxbc::getComparisonFuncs()); + OS << enumToStringRef(CompFunc, dxbc::getComparisonFuncs()); return OS; } static raw_ostream &operator<<(raw_ostream &OS, const dxbc::StaticBorderColor &BorderColor) { - printEnum(OS, BorderColor, dxbc::getStaticBorderColors()); + OS << enumToStringRef(BorderColor, dxbc::getStaticBorderColors()); return OS; } -static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { - {"CBV", dxil::ResourceClass::CBuffer}, - {"SRV", dxil::ResourceClass::SRV}, - {"UAV", dxil::ResourceClass::UAV}, - {"Sampler", dxil::ResourceClass::Sampler}, -}; - -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { - printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), - ArrayRef(ResourceClassNames)); - +static raw_ostream &operator<<(raw_ostream &OS, + const dxil::ResourceClass &Type) { + OS << dxil::getResourceClassName(Type); return OS; } @@ -179,8 +153,7 @@ raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) { } raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) { - ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type)); - OS << "Root" << Type << "(" << Descriptor.Reg + OS << "Root" << Descriptor.Type << "(" << Descriptor.Reg << ", space = " << Descriptor.Space << ", visibility = " << Descriptor.Visibility << ", flags = " << Descriptor.Flags << ")"; diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 53f5934..a5a92cb 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -13,15 +13,22 @@ #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/DXILABI.h" #include "llvm/Support/ScopedPrinter.h" +using namespace llvm; + namespace llvm { namespace hlsl { namespace rootsig { +char GenericRSMetadataError::ID; +char InvalidRSMetadataFormat::ID; +char InvalidRSMetadataValue::ID; +template <typename T> char RootSignatureValidationError<T>::ID; + static std::optional<uint32_t> extractMdIntValue(MDNode *Node, unsigned int OpId) { if (auto *CI = @@ -45,31 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } -static bool reportError(LLVMContext *Ctx, Twine Message, - DiagnosticSeverity Severity = DS_Error) { - Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); - return true; -} - -static bool reportValueError(LLVMContext *Ctx, Twine ParamName, - uint32_t Value) { - Ctx->diagnose(DiagnosticInfoGeneric( - "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); - return true; -} - -static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { - {"CBV", dxil::ResourceClass::CBuffer}, - {"SRV", dxil::ResourceClass::SRV}, - {"UAV", dxil::ResourceClass::UAV}, - {"Sampler", dxil::ResourceClass::Sampler}, -}; - -static std::optional<StringRef> getResourceName(dxil::ResourceClass Class) { - for (const auto &ClassEnum : ResourceClassNames) - if (ClassEnum.Value == Class) - return ClassEnum.Name; - return std::nullopt; +static Expected<dxbc::ShaderVisibility> +extractShaderVisibility(MDNode *Node, unsigned int OpId) { + if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { + if (!dxbc::isValidShaderVisibility(*Val)) + return make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", *Val); + return dxbc::ShaderVisibility(*Val); + } + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); } namespace { @@ -120,7 +111,7 @@ MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "RootFlags"), - ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))), }; return MDNode::get(Ctx, Operands); } @@ -130,7 +121,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { Metadata *Operands[] = { MDString::get(Ctx, "RootConstants"), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Constants.Visibility))), + Builder.getInt32(to_underlying(Constants.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), @@ -140,18 +131,17 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - std::optional<StringRef> ResName = getResourceName( - dxil::ResourceClass(llvm::to_underlying(Descriptor.Type))); - assert(ResName && "Provided an invalid Resource Class"); - llvm::SmallString<7> Name({"Root", *ResName}); + StringRef ResName = dxil::getResourceClassName(Descriptor.Type); + assert(!ResName.empty() && "Provided an invalid Resource Class"); + SmallString<7> Name({"Root", ResName}); Metadata *Operands[] = { MDString::get(Ctx, Name), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), + Builder.getInt32(to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), + Builder.getInt32(to_underlying(Descriptor.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -162,7 +152,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { // Set the mandatory arguments TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); TableOperands.push_back(ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Table.Visibility)))); + Builder.getInt32(to_underlying(Table.Visibility)))); // Remaining operands are references to the table's clauses. The in-memory // representation of the Root Elements created from parsing will ensure that @@ -181,17 +171,15 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); - std::optional<StringRef> ResName = - getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type))); - assert(ResName && "Provided an invalid Resource Class"); + StringRef ResName = dxil::getResourceClassName(Clause.Type); + assert(!ResName.empty() && "Provided an invalid Resource Class"); Metadata *Operands[] = { - MDString::get(Ctx, *ResName), + MDString::get(Ctx, ResName), ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -200,151 +188,139 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + Builder.getInt32(to_underlying(Sampler.AddressU))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + Builder.getInt32(to_underlying(Sampler.AddressV))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + Builder.getInt32(to_underlying(Sampler.AddressW))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), - ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), - Sampler.MipLODBias)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)), ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + Builder.getInt32(to_underlying(Sampler.CompFunc))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + Builder.getInt32(to_underlying(Sampler.BorderColor))), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + Builder.getInt32(to_underlying(Sampler.Visibility))), }; return MDNode::get(Ctx, Operands); } -bool MetadataParser::parseRootFlags(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootFlagNode) { - +Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode) { if (RootFlagNode->getNumOperands() != 2) - return reportError(Ctx, "Invalid format for RootFlag Element"); + return make_error<InvalidRSMetadataFormat>("RootFlag Element"); if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return reportError(Ctx, "Invalid value for RootFlag"); + return make_error<InvalidRSMetadataValue>("RootFlag"); - return false; + return Error::success(); } -bool MetadataParser::parseRootConstants(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootConstantNode) { - +Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode) { if (RootConstantNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for RootConstants Element"); + return make_error<InvalidRSMetadataFormat>("RootConstants Element"); - dxbc::RTS0::v1::RootParameterHeader Header; - // The parameter offset doesn't matter here - we recalculate it during - // serialization Header.ParameterOffset = 0; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); - - if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(RootConstantNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); - dxbc::RTS0::v1::RootConstants Constants; + mcdxbc::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return reportError(Ctx, "Invalid value for Num32BitValues"); + return make_error<InvalidRSMetadataValue>("Num32BitValues"); - RSD.ParametersContainer.addParameter(Header, Constants); + RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit, + *Visibility, Constants); - return false; + return Error::success(); } -bool MetadataParser::parseRootDescriptors( - LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) { - assert(ElementKind == RootSignatureElementKind::SRV || - ElementKind == RootSignatureElementKind::UAV || - ElementKind == RootSignatureElementKind::CBV && - "parseRootDescriptors should only be called with RootDescriptor " - "element kind."); +Error MetadataParser::parseRootDescriptors( + mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, + RootSignatureElementKind ElementKind) { + assert((ElementKind == RootSignatureElementKind::SRV || + ElementKind == RootSignatureElementKind::UAV || + ElementKind == RootSignatureElementKind::CBV) && + "parseRootDescriptors should only be called with RootDescriptor " + "element kind."); if (RootDescriptorNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for Root Descriptor Element"); + return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); - dxbc::RTS0::v1::RootParameterHeader Header; + dxbc::RootParameterType Type; switch (ElementKind) { case RootSignatureElementKind::SRV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); + Type = dxbc::RootParameterType::SRV; break; case RootSignatureElementKind::UAV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); + Type = dxbc::RootParameterType::UAV; break; case RootSignatureElementKind::CBV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); + Type = dxbc::RootParameterType::CBV; break; default: llvm_unreachable("invalid Root Descriptor kind"); break; } - if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(RootDescriptorNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); - dxbc::RTS0::v2::RootDescriptor Descriptor; + mcdxbc::RootDescriptor Descriptor; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (RSD.Version == 1) { - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); + return Error::success(); } assert(RSD.Version > 1); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return reportError(Ctx, "Invalid value for Root Descriptor Flags"); + return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); + return Error::success(); } -bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, - mcdxbc::DescriptorTable &Table, - MDNode *RangeDescriptorNode) { - +Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode) { if (RangeDescriptorNode->getNumOperands() != 6) - return reportError(Ctx, "Invalid format for Descriptor Range"); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); dxbc::RTS0::v2::DescriptorRange Range; @@ -352,162 +328,159 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Descriptor Range, first element is not a string."); + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); Range.RangeType = StringSwitch<uint32_t>(*ElementText) - .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", - llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) + .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) + .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) + .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) + .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) .Default(~0U); if (Range.RangeType == ~0U) - return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); + return make_error<GenericRSMetadataError>("Number of Descriptor in Range", + RangeDescriptorNode); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for BaseShaderRegister"); + return make_error<InvalidRSMetadataValue>("BaseShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportError(Ctx, - "Invalid value for OffsetInDescriptorsFromTableStart"); + return make_error<InvalidRSMetadataValue>( + "OffsetInDescriptorsFromTableStart"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportError(Ctx, "Invalid value for Descriptor Range Flags"); + return make_error<InvalidRSMetadataValue>("Descriptor Range Flags"); Table.Ranges.push_back(Range); - return false; + return Error::success(); } -bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *DescriptorTableNode) { +Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode) { const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); if (NumOperands < 2) - return reportError(Ctx, "Invalid format for Descriptor Table"); + return make_error<InvalidRSMetadataFormat>("Descriptor Table"); - dxbc::RTS0::v1::RootParameterHeader Header; - if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(DescriptorTableNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); mcdxbc::DescriptorTable Table; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", DescriptorTableNode); - if (parseDescriptorRange(Ctx, Table, Element)) - return true; + if (auto Err = parseDescriptorRange(Table, Element)) + return Err; } - RSD.ParametersContainer.addParameter(Header, Table); - return false; + RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable, + *Visibility, Table); + return Error::success(); } -bool MetadataParser::parseStaticSampler(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *StaticSamplerNode) { +Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode) { if (StaticSamplerNode->getNumOperands() != 14) - return reportError(Ctx, "Invalid format for Static Sampler"); + return make_error<InvalidRSMetadataFormat>("Static Sampler"); dxbc::RTS0::v1::StaticSampler Sampler; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) Sampler.Filter = *Val; else - return reportError(Ctx, "Invalid value for Filter"); + return make_error<InvalidRSMetadataValue>("Filter"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) Sampler.AddressU = *Val; else - return reportError(Ctx, "Invalid value for AddressU"); + return make_error<InvalidRSMetadataValue>("AddressU"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) Sampler.AddressV = *Val; else - return reportError(Ctx, "Invalid value for AddressV"); + return make_error<InvalidRSMetadataValue>("AddressV"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) Sampler.AddressW = *Val; else - return reportError(Ctx, "Invalid value for AddressW"); + return make_error<InvalidRSMetadataValue>("AddressW"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; else - return reportError(Ctx, "Invalid value for MipLODBias"); + return make_error<InvalidRSMetadataValue>("MipLODBias"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return reportError(Ctx, "Invalid value for MaxAnisotropy"); + return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) Sampler.ComparisonFunc = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) Sampler.BorderColor = *Val; else - return reportError(Ctx, "Invalid value for ComparisonFunc "); + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; else - return reportError(Ctx, "Invalid value for MinLOD"); + return make_error<InvalidRSMetadataValue>("MinLOD"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = *Val; else - return reportError(Ctx, "Invalid value for MaxLOD"); + return make_error<InvalidRSMetadataValue>("MaxLOD"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return reportError(Ctx, "Invalid value for ShaderRegister"); + return make_error<InvalidRSMetadataValue>("ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return reportError(Ctx, "Invalid value for RegisterSpace"); + return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) Sampler.ShaderVisibility = *Val; else - return reportError(Ctx, "Invalid value for ShaderVisibility"); + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); RSD.StaticSamplers.push_back(Sampler); - return false; + return Error::success(); } -bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *Element) { +Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, + MDNode *Element) { std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); if (!ElementText.has_value()) - return reportError(Ctx, "Invalid format for Root Element"); + return make_error<InvalidRSMetadataFormat>("Root Element"); RootSignatureElementKind ElementKind = StringSwitch<RootSignatureElementKind>(*ElementText) @@ -523,79 +496,103 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, switch (ElementKind) { case RootSignatureElementKind::RootFlags: - return parseRootFlags(Ctx, RSD, Element); + return parseRootFlags(RSD, Element); case RootSignatureElementKind::RootConstants: - return parseRootConstants(Ctx, RSD, Element); + return parseRootConstants(RSD, Element); case RootSignatureElementKind::CBV: case RootSignatureElementKind::SRV: case RootSignatureElementKind::UAV: - return parseRootDescriptors(Ctx, RSD, Element, ElementKind); + return parseRootDescriptors(RSD, Element, ElementKind); case RootSignatureElementKind::DescriptorTable: - return parseDescriptorTable(Ctx, RSD, Element); + return parseDescriptorTable(RSD, Element); case RootSignatureElementKind::StaticSamplers: - return parseStaticSampler(Ctx, RSD, Element); + return parseStaticSampler(RSD, Element); case RootSignatureElementKind::Error: - return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); + return make_error<GenericRSMetadataError>("Invalid Root Signature Element", + Element); } llvm_unreachable("Unhandled RootSignatureElementKind enum."); } -bool MetadataParser::validateRootSignature( - LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) { - if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { - return reportValueError(Ctx, "Version", RSD.Version); +Error MetadataParser::validateRootSignature( + const mcdxbc::RootSignatureDesc &RSD) { + Error DeferredErrs = Error::success(); + if (!hlsl::rootsig::verifyVersion(RSD.Version)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Version", RSD.Version)); } - if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { - return reportValueError(Ctx, "RootFlags", RSD.Flags); + if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootFlags", RSD.Flags)); } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { - if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Info.Header.ShaderVisibility); - - assert(dxbc::isValidParameterType(Info.Header.ParameterType) && - "Invalid value for ParameterType"); - switch (Info.Header.ParameterType) { + switch (Info.Type) { + case dxbc::RootParameterType::Constants32Bit: + break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { - const dxbc::RTS0::v2::RootDescriptor &Descriptor = + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::UAV: + case dxbc::RootParameterType::SRV: { + const mcdxbc::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); - if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", - Descriptor.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); + if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Descriptor.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) - return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); + if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, + Descriptor.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootDescriptorFlag", Descriptor.Flags)); } break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) - return reportValueError(Ctx, "RangeType", Range.RangeType); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); - - if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) - return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); - - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RangeType", Range.RangeType)); + + if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Range.RegisterSpace)); + + if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "NumDescriptors", Range.NumDescriptors)); + + if (!hlsl::rootsig::verifyDescriptorRangeFlag( RSD.Version, Range.RangeType, Range.Flags)) - return reportValueError(Ctx, "DescriptorFlag", Range.Flags); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "DescriptorFlag", Range.Flags)); } break; } @@ -603,65 +600,108 @@ bool MetadataParser::validateRootSignature( } for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - return reportValueError(Ctx, "Filter", Sampler.Filter); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) - return reportValueError(Ctx, "AddressU", Sampler.AddressU); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) - return reportValueError(Ctx, "AddressV", Sampler.AddressV); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) - return reportValueError(Ctx, "AddressW", Sampler.AddressW); - - if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) - return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); - - if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) - return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); - - if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); - - if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) - return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) - return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); - - if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); + if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Filter", Sampler.Filter)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressU", Sampler.AddressU)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressV", Sampler.AddressV)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressW", Sampler.AddressW)); + + if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MipLODBias", Sampler.MipLODBias)); + + if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "MaxAnisotropy", Sampler.MaxAnisotropy)); + + if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ComparisonFunc", Sampler.ComparisonFunc)); + + if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "BorderColor", Sampler.BorderColor)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MinLOD", Sampler.MinLOD)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MaxLOD", Sampler.MaxLOD)); + + if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Sampler.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Sampler.RegisterSpace)); if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Sampler.ShaderVisibility); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Sampler.ShaderVisibility)); } - return false; + return DeferredErrs; } -bool MetadataParser::ParseRootSignature(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD) { - bool HasError = false; - - // Loop through the Root Elements of the root signature. +Expected<mcdxbc::RootSignatureDesc> +MetadataParser::ParseRootSignature(uint32_t Version) { + Error DeferredErrs = Error::success(); + mcdxbc::RootSignatureDesc RSD; + RSD.Version = Version; for (const auto &Operand : Root->operands()) { MDNode *Element = dyn_cast<MDNode>(Operand); if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); + return joinErrors(std::move(DeferredErrs), + make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", nullptr)); - HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) || - validateRootSignature(Ctx, RSD); + if (auto Err = parseRootSignatureElement(RSD, Element)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); } - return HasError; + if (auto Err = validateRootSignature(RSD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + + if (DeferredErrs) + return std::move(DeferredErrs); + + return std::move(RSD); } } // namespace rootsig } // namespace hlsl diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index f11c7d2..72308a3d 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -29,7 +29,7 @@ bool verifyRegisterValue(uint32_t RegisterValue) { // This Range is reserverved, therefore invalid, according to the spec // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal bool verifyRegisterSpace(uint32_t RegisterSpace) { - return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF); + return !(RegisterSpace >= 0xFFFFFFF0); } bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { @@ -180,140 +180,6 @@ bool verifyBorderColor(uint32_t BorderColor) { bool verifyLOD(float LOD) { return !std::isnan(LOD); } -std::optional<const RangeInfo *> -ResourceRange::getOverlapping(const RangeInfo &Info) const { - MapT::const_iterator Interval = Intervals.find(Info.LowerBound); - if (!Interval.valid() || Info.UpperBound < Interval.start()) - return std::nullopt; - return Interval.value(); -} - -const RangeInfo *ResourceRange::lookup(uint32_t X) const { - return Intervals.lookup(X, nullptr); -} - -void ResourceRange::clear() { return Intervals.clear(); } - -std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) { - uint32_t LowerBound = Info.LowerBound; - uint32_t UpperBound = Info.UpperBound; - - std::optional<const RangeInfo *> Res = std::nullopt; - MapT::iterator Interval = Intervals.begin(); - - while (true) { - if (UpperBound < LowerBound) - break; - - Interval.advanceTo(LowerBound); - if (!Interval.valid()) // No interval found - break; - - // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that - // a <= y implicitly from Intervals.find(LowerBound) - if (UpperBound < Interval.start()) - break; // found interval does not overlap with inserted one - - if (!Res.has_value()) // Update to be the first found intersection - Res = Interval.value(); - - if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) { - // x <= a <= b <= y implies that [a;b] is covered by [x;y] - // -> so we don't need to insert this, report an overlap - return Res; - } else if (LowerBound <= Interval.start() && - Interval.stop() <= UpperBound) { - // a <= x <= y <= b implies that [x;y] is covered by [a;b] - // -> so remove the existing interval that we will cover with the - // overwrite - Interval.erase(); - } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) { - // a < x <= b <= y implies that [a; x] is not covered but [x;b] is - // -> so set b = x - 1 such that [a;x-1] is now the interval to insert - UpperBound = Interval.start() - 1; - } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) { - // a < x <= b <= y implies that [y; b] is not covered but [a;y] is - // -> so set a = y + 1 such that [y+1;b] is now the interval to insert - LowerBound = Interval.stop() + 1; - } - } - - assert(LowerBound <= UpperBound && "Attempting to insert an empty interval"); - Intervals.insert(LowerBound, UpperBound, &Info); - return Res; -} - -llvm::SmallVector<OverlappingRanges> -findOverlappingRanges(ArrayRef<RangeInfo> Infos) { - // It is expected that Infos is filled with valid RangeInfos and that - // they are sorted with respect to the RangeInfo <operator - assert(llvm::is_sorted(Infos) && "Ranges must be sorted"); - - llvm::SmallVector<OverlappingRanges> Overlaps; - using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>; - - // First we will init our state to track: - if (Infos.size() == 0) - return Overlaps; // No ranges to overlap - GroupT CurGroup = {Infos[0].Class, Infos[0].Space}; - - // Create a ResourceRange for each Visibility - ResourceRange::MapT::Allocator Allocator; - std::array<ResourceRange, 8> Ranges = { - ResourceRange(Allocator), // All - ResourceRange(Allocator), // Vertex - ResourceRange(Allocator), // Hull - ResourceRange(Allocator), // Domain - ResourceRange(Allocator), // Geometry - ResourceRange(Allocator), // Pixel - ResourceRange(Allocator), // Amplification - ResourceRange(Allocator), // Mesh - }; - - // Reset the ResourceRanges for when we iterate through a new group - auto ClearRanges = [&Ranges]() { - for (ResourceRange &Range : Ranges) - Range.clear(); - }; - - // Iterate through collected RangeInfos - for (const RangeInfo &Info : Infos) { - GroupT InfoGroup = {Info.Class, Info.Space}; - // Reset our ResourceRanges when we enter a new group - if (CurGroup != InfoGroup) { - ClearRanges(); - CurGroup = InfoGroup; - } - - // Insert range info into corresponding Visibility ResourceRange - ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)]; - if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info)) - Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); - - // Check for overlap in all overlapping Visibility ResourceRanges - // - // If the range that we are inserting has ShaderVisiblity::All it needs to - // check for an overlap in all other visibility types as well. - // Otherwise, the range that is inserted needs to check that it does not - // overlap with ShaderVisibility::All. - // - // OverlapRanges will be an ArrayRef to all non-all visibility - // ResourceRanges in the former case and it will be an ArrayRef to just the - // all visiblity ResourceRange in the latter case. - ArrayRef<ResourceRange> OverlapRanges = - Info.Visibility == llvm::dxbc::ShaderVisibility::All - ? ArrayRef<ResourceRange>{Ranges}.drop_front() - : ArrayRef<ResourceRange>{Ranges}.take_front(); - - for (const ResourceRange &Range : OverlapRanges) - if (std::optional<const RangeInfo *> Overlapping = - Range.getOverlapping(Info)) - Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); - } - - return Overlaps; -} - } // namespace rootsig } // namespace hlsl } // namespace llvm |