diff options
Diffstat (limited to 'llvm/lib/Frontend')
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 292 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp | 13 |
2 files changed, 167 insertions, 138 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 7a0cf40..63189f4 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -24,15 +24,7 @@ namespace llvm { namespace hlsl { namespace rootsig { -char GenericRSMetadataError::ID; -char InvalidRSMetadataFormat::ID; -char InvalidRSMetadataValue::ID; -char TableSamplerMixinError::ID; -char ShaderRegisterOverflowError::ID; -char OffsetOverflowError::ID; -char OffsetAppendAfterOverflow::ID; - -template <typename T> char RootSignatureValidationError<T>::ID; +char RootSignatureValidationError::ID; static std::optional<uint32_t> extractMdIntValue(MDNode *Node, unsigned int OpId) { @@ -57,20 +49,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } -template <typename T, typename = std::enable_if_t< - std::is_enum_v<T> && - std::is_same_v<std::underlying_type_t<T>, uint32_t>>> -static Expected<T> -extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, - llvm::function_ref<bool(uint32_t)> VerifyFn) { - if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { - if (!VerifyFn(*Val)) - return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val); - return static_cast<T>(*Val); - } - return make_error<InvalidRSMetadataValue>("ShaderVisibility"); -} - namespace { // We use the OverloadVisit with std::visit to ensure the compiler catches if a @@ -81,8 +59,52 @@ template <class... Ts> struct OverloadedVisit : Ts... { }; template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>; +struct FmtRange { + dxil::ResourceClass Type; + uint32_t Register; + uint32_t Space; + + FmtRange(const mcdxbc::DescriptorRange &Range) + : Type(Range.RangeType), Register(Range.BaseShaderRegister), + Space(Range.RegisterSpace) {} +}; + +raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) { + OS << getResourceClassName(Range.Type) << "(register=" << Range.Register + << ", space=" << Range.Space << ")"; + return OS; +} + +struct FmtMDNode { + const MDNode *Node; + + FmtMDNode(const MDNode *Node) : Node(Node) {} +}; + +raw_ostream &operator<<(llvm::raw_ostream &OS, FmtMDNode Fmt) { + Fmt.Node->printTree(OS); + return OS; +} + +static Error makeRSError(const Twine &Msg) { + return make_error<RootSignatureValidationError>(Msg); +} } // namespace +template <typename T, typename = std::enable_if_t< + std::is_enum_v<T> && + std::is_same_v<std::underlying_type_t<T>, uint32_t>>> +static Expected<T> +extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, + llvm::function_ref<bool(uint32_t)> VerifyFn) { + if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { + if (!VerifyFn(*Val)) + return makeRSError(formatv("Invalid value for {0}: {1}", ErrText, Val)); + return static_cast<T>(*Val); + } + return makeRSError(formatv("Invalid value for {0}:", ErrText)); +} + MDNode *MetadataBuilder::BuildRootSignature() { const auto Visitor = OverloadedVisit{ [this](const dxbc::RootFlags &Flags) -> MDNode * { @@ -226,12 +248,12 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, MDNode *RootFlagNode) { if (RootFlagNode->getNumOperands() != 2) - return make_error<InvalidRSMetadataFormat>("RootFlag Element"); + return makeRSError("Invalid format for RootFlags Element"); if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return make_error<InvalidRSMetadataValue>("RootFlag"); + return makeRSError("Invalid value for RootFlag"); return Error::success(); } @@ -239,7 +261,7 @@ Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, MDNode *RootConstantNode) { if (RootConstantNode->getNumOperands() != 5) - return make_error<InvalidRSMetadataFormat>("RootConstants Element"); + return makeRSError("Invalid format for RootConstants Element"); Expected<dxbc::ShaderVisibility> Visibility = extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1, @@ -252,17 +274,17 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return make_error<InvalidRSMetadataValue>("ShaderRegister"); + return makeRSError("Invalid value for ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return make_error<InvalidRSMetadataValue>("RegisterSpace"); + return makeRSError("Invalid value for RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return make_error<InvalidRSMetadataValue>("Num32BitValues"); + return makeRSError("Invalid value for Num32BitValues"); RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit, *Visibility, Constants); @@ -279,7 +301,7 @@ Error MetadataParser::parseRootDescriptors( "parseRootDescriptors should only be called with RootDescriptor " "element kind."); if (RootDescriptorNode->getNumOperands() != 5) - return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); + return makeRSError("Invalid format for Root Descriptor Element"); dxbc::RootParameterType Type; switch (ElementKind) { @@ -308,12 +330,12 @@ Error MetadataParser::parseRootDescriptors( if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return make_error<InvalidRSMetadataValue>("ShaderRegister"); + return makeRSError("Invalid value for ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return make_error<InvalidRSMetadataValue>("RegisterSpace"); + return makeRSError("Invalid value for RegisterSpace"); if (RSD.Version == 1) { RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); @@ -324,7 +346,7 @@ Error MetadataParser::parseRootDescriptors( if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); + return makeRSError("Invalid value for Root Descriptor Flags"); RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); return Error::success(); @@ -333,7 +355,7 @@ Error MetadataParser::parseRootDescriptors( Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, MDNode *RangeDescriptorNode) { if (RangeDescriptorNode->getNumOperands() != 6) - return make_error<InvalidRSMetadataFormat>("Descriptor Range"); + return makeRSError("Invalid format for Descriptor Range"); mcdxbc::DescriptorRange Range; @@ -341,7 +363,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return make_error<InvalidRSMetadataFormat>("Descriptor Range"); + return makeRSError("Invalid format for Descriptor Range"); if (*ElementText == "CBV") Range.RangeType = dxil::ResourceClass::CBuffer; @@ -352,35 +374,34 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, else if (*ElementText == "Sampler") Range.RangeType = dxil::ResourceClass::Sampler; else - return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", - RangeDescriptorNode); + return makeRSError(formatv("Invalid Descriptor Range type.\n{0}", + FmtMDNode{RangeDescriptorNode})); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return make_error<GenericRSMetadataError>("Number of Descriptor in Range", - RangeDescriptorNode); + return makeRSError(formatv("Invalid number of Descriptor in Range.\n{0}", + FmtMDNode{RangeDescriptorNode})); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return make_error<InvalidRSMetadataValue>("BaseShaderRegister"); + return makeRSError("Invalid value for BaseShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return make_error<InvalidRSMetadataValue>("RegisterSpace"); + return makeRSError("Invalid value for RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return make_error<InvalidRSMetadataValue>( - "OffsetInDescriptorsFromTableStart"); + return makeRSError("Invalid value for OffsetInDescriptorsFromTableStart"); if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return make_error<InvalidRSMetadataValue>("Descriptor Range Flags"); + return makeRSError("Invalid value for Descriptor Range Flags"); Table.Ranges.push_back(Range); return Error::success(); @@ -390,7 +411,7 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, MDNode *DescriptorTableNode) { const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); if (NumOperands < 2) - return make_error<InvalidRSMetadataFormat>("Descriptor Table"); + return makeRSError("Invalid format for Descriptor Table"); Expected<dxbc::ShaderVisibility> Visibility = extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1, @@ -404,8 +425,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return make_error<GenericRSMetadataError>( - "Missing Root Element Metadata Node.", DescriptorTableNode); + return makeRSError(formatv("Missing Root Element Metadata Node.\n{0}", + FmtMDNode{DescriptorTableNode})); if (auto Err = parseDescriptorRange(Table, Element)) return Err; @@ -419,7 +440,7 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, MDNode *StaticSamplerNode) { if (StaticSamplerNode->getNumOperands() != 15) - return make_error<InvalidRSMetadataFormat>("Static Sampler"); + return makeRSError("Invalid format for Static Sampler"); mcdxbc::StaticSampler Sampler; @@ -453,12 +474,12 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; else - return make_error<InvalidRSMetadataValue>("MipLODBias"); + return makeRSError("Invalid value for MipLODBias"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); + return makeRSError("Invalid value for MaxAnisotropy"); Expected<dxbc::ComparisonFunc> ComparisonFunc = extractEnumValue<dxbc::ComparisonFunc>( @@ -477,22 +498,22 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; else - return make_error<InvalidRSMetadataValue>("MinLOD"); + return makeRSError("Invalid value for MinLOD"); if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = *Val; else - return make_error<InvalidRSMetadataValue>("MaxLOD"); + return makeRSError("Invalid value for MaxLOD"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return make_error<InvalidRSMetadataValue>("ShaderRegister"); + return makeRSError("Invalid value for ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return make_error<InvalidRSMetadataValue>("RegisterSpace"); + return makeRSError("Invalid value for RegisterSpace"); Expected<dxbc::ShaderVisibility> Visibility = extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13, @@ -511,7 +532,7 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 14)) Sampler.Flags = *Val; else - return make_error<InvalidRSMetadataValue>("Static Sampler Flags"); + return makeRSError("Invalid value for Static Sampler Flags"); RSD.StaticSamplers.push_back(Sampler); return Error::success(); @@ -521,7 +542,7 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, MDNode *Element) { std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); if (!ElementText.has_value()) - return make_error<InvalidRSMetadataFormat>("Root Element"); + return makeRSError("Invalid format for Root Element"); RootSignatureElementKind ElementKind = StringSwitch<RootSignatureElementKind>(*ElementText) @@ -549,8 +570,8 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, case RootSignatureElementKind::StaticSamplers: return parseStaticSampler(RSD, Element); case RootSignatureElementKind::Error: - return make_error<GenericRSMetadataError>("Invalid Root Signature Element", - Element); + return makeRSError( + formatv("Invalid Root Signature Element\n{0}", FmtMDNode{Element})); } llvm_unreachable("Unhandled RootSignatureElementKind enum."); @@ -563,7 +584,10 @@ validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table, for (const mcdxbc::DescriptorRange &Range : Table.Ranges) { if (Range.RangeType == dxil::ResourceClass::Sampler && CurrRC != dxil::ResourceClass::Sampler) - return make_error<TableSamplerMixinError>(CurrRC, Location); + return makeRSError( + formatv("Samplers cannot be mixed with other resource types in a " + "descriptor table, {0}(location={1})", + getResourceClassName(CurrRC), Location)); CurrRC = Range.RangeType; } return Error::success(); @@ -583,8 +607,8 @@ validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table, Range.BaseShaderRegister, Range.NumDescriptors); if (!verifyNoOverflowedOffset(RangeBound)) - return make_error<ShaderRegisterOverflowError>( - Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + return makeRSError( + formatv("Overflow for shader register range: {0}", FmtRange{Range})); bool IsAppending = Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend; @@ -592,15 +616,16 @@ validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table, Offset = Range.OffsetInDescriptorsFromTableStart; if (IsPrevUnbound && IsAppending) - return make_error<OffsetAppendAfterOverflow>( - Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + return makeRSError( + formatv("Range {0} cannot be appended after an unbounded range", + FmtRange{Range})); const uint64_t OffsetBound = llvm::hlsl::rootsig::computeRangeBound(Offset, Range.NumDescriptors); if (!verifyNoOverflowedOffset(OffsetBound)) - return make_error<OffsetOverflowError>( - Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + return makeRSError(formatv("Offset overflow for descriptor range: {0}.", + FmtRange{Range})); Offset = OffsetBound + 1; IsPrevUnbound = @@ -614,17 +639,15 @@ Error MetadataParser::validateRootSignature( const mcdxbc::RootSignatureDesc &RSD) { Error DeferredErrs = Error::success(); if (!hlsl::rootsig::verifyVersion(RSD.Version)) { - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "Version", RSD.Version)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for Version: {0}", RSD.Version))); } if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) { - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RootFlags", RSD.Flags)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for RootFlags: {0}", RSD.Flags))); } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { @@ -639,24 +662,27 @@ Error MetadataParser::validateRootSignature( const mcdxbc::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "ShaderRegister", Descriptor.ShaderRegister)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for ShaderRegister: {0}", + Descriptor.ShaderRegister))); if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RegisterSpace", Descriptor.RegisterSpace)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for RegisterSpace: {0}", + Descriptor.RegisterSpace))); if (RSD.Version > 1) { - if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RootDescriptorFlag", Descriptor.Flags)); + bool IsValidFlag = + dxbc::isValidRootDesciptorFlags(Descriptor.Flags) && + hlsl::rootsig::verifyRootDescriptorFlag( + RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags)); + if (!IsValidFlag) + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for RootDescriptorFlag: {0}", + Descriptor.Flags))); } break; } @@ -665,24 +691,26 @@ Error MetadataParser::validateRootSignature( RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const mcdxbc::DescriptorRange &Range : Table) { if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RegisterSpace", Range.RegisterSpace)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for RegisterSpace: {0}", + Range.RegisterSpace))); if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "NumDescriptors", Range.NumDescriptors)); - - if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "DescriptorFlag", Range.Flags)); + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for NumDescriptors: {0}", + Range.NumDescriptors))); + + bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) && + hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags)); + if (!IsValidFlag) + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for DescriptorFlag: {0}", + Range.Flags))); if (Error Err = validateDescriptorTableSamplerMixin(Table, Info.Location)) @@ -700,43 +728,49 @@ Error MetadataParser::validateRootSignature( for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) { if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) - DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<float>>( - "MipLODBias", Sampler.MipLODBias)); + DeferredErrs = + joinErrors(std::move(DeferredErrs), + makeRSError(formatv("Invalid value for MipLODBias: {0:e}", + Sampler.MipLODBias))); if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "MaxAnisotropy", Sampler.MaxAnisotropy)); + makeRSError(formatv("Invalid value for MaxAnisotropy: {0}", + Sampler.MaxAnisotropy))); if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) - DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<float>>( - "MinLOD", Sampler.MinLOD)); - - if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) - DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<float>>( - "MaxLOD", Sampler.MaxLOD)); - - if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "ShaderRegister", Sampler.ShaderRegister)); + makeRSError(formatv("Invalid value for MinLOD: {0}", + Sampler.MinLOD))); - if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RegisterSpace", Sampler.RegisterSpace)); + makeRSError(formatv("Invalid value for MaxLOD: {0}", + Sampler.MaxLOD))); - if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for ShaderRegister: {0}", + Sampler.ShaderRegister))); + + if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) DeferredErrs = joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "Static Sampler Flag", Sampler.Flags)); + makeRSError(formatv("Invalid value for RegisterSpace: {0}", + Sampler.RegisterSpace))); + bool IsValidFlag = + dxbc::isValidStaticSamplerFlags(Sampler.Flags) && + hlsl::rootsig::verifyStaticSamplerFlags( + RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags)); + if (!IsValidFlag) + DeferredErrs = joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Invalid value for Static Sampler Flag: {0}", + Sampler.Flags))); } return DeferredErrs; @@ -750,9 +784,9 @@ MetadataParser::ParseRootSignature(uint32_t Version) { for (const auto &Operand : Root->operands()) { MDNode *Element = dyn_cast<MDNode>(Operand); if (Element == nullptr) - return joinErrors(std::move(DeferredErrs), - make_error<GenericRSMetadataError>( - "Missing Root Element Metadata Node.", nullptr)); + return joinErrors( + std::move(DeferredErrs), + makeRSError(formatv("Missing Root Element Metadata Node."))); if (auto Err = parseRootSignatureElement(RSD, Element)) DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 8a2b03d..30408df 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { return !(RegisterSpace >= 0xFFFFFFF0); } -bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { +bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; FlagT Flags = FlagT(FlagsVal); if (Version == 1) @@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags Flags) { using FlagT = dxbc::DescriptorRangeFlags; - const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { @@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, return (Flags & ~Mask) == FlagT::None; } -bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) { - uint32_t LargestValue = llvm::to_underlying( - dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsNumber >= NextPowerOf2(LargestValue)) - return false; - - dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber); +bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags) { if (Version <= 2) return Flags == dxbc::StaticSamplerFlags::None; |