diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/BinaryFormat/DXContainer.cpp | 21 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 22 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp | 13 |
3 files changed, 40 insertions, 16 deletions
diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp index b334f86..22f5180 100644 --- a/llvm/lib/BinaryFormat/DXContainer.cpp +++ b/llvm/lib/BinaryFormat/DXContainer.cpp @@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) { return false; } +bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) { + using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) { + using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) { + using FlagT = dxbc::StaticSamplerFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + dxbc::PartType dxbc::parsePartType(StringRef S) { #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName) return StringSwitch<dxbc::PartType>(S) diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 7a0cf40..707f0c3 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature( "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) + bool IsValidFlag = + dxbc::isValidRootDesciptorFlags(Descriptor.Flags) && + hlsl::rootsig::verifyRootDescriptorFlag( + RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( @@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature( make_error<RootSignatureValidationError<uint32_t>>( "NumDescriptors", Range.NumDescriptors)); - if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) + bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) && + hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( @@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + bool IsValidFlag = + dxbc::isValidStaticSamplerFlags(Sampler.Flags) && + hlsl::rootsig::verifyStaticSamplerFlags( + RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 8a2b03d..30408df 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { return !(RegisterSpace >= 0xFFFFFFF0); } -bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { +bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; FlagT Flags = FlagT(FlagsVal); if (Version == 1) @@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags Flags) { using FlagT = dxbc::DescriptorRangeFlags; - const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { @@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, return (Flags & ~Mask) == FlagT::None; } -bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) { - uint32_t LargestValue = llvm::to_underlying( - dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsNumber >= NextPowerOf2(LargestValue)) - return false; - - dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber); +bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags) { if (Version <= 2) return Flags == dxbc::StaticSamplerFlags::None; |