aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend')
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp292
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp13
2 files changed, 167 insertions, 138 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 7a0cf40..63189f4 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -24,15 +24,7 @@ namespace llvm {
namespace hlsl {
namespace rootsig {
-char GenericRSMetadataError::ID;
-char InvalidRSMetadataFormat::ID;
-char InvalidRSMetadataValue::ID;
-char TableSamplerMixinError::ID;
-char ShaderRegisterOverflowError::ID;
-char OffsetOverflowError::ID;
-char OffsetAppendAfterOverflow::ID;
-
-template <typename T> char RootSignatureValidationError<T>::ID;
+char RootSignatureValidationError::ID;
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
@@ -57,20 +49,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}
-template <typename T, typename = std::enable_if_t<
- std::is_enum_v<T> &&
- std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
-static Expected<T>
-extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
- llvm::function_ref<bool(uint32_t)> VerifyFn) {
- if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
- if (!VerifyFn(*Val))
- return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val);
- return static_cast<T>(*Val);
- }
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
-}
-
namespace {
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -81,8 +59,52 @@ template <class... Ts> struct OverloadedVisit : Ts... {
};
template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>;
+struct FmtRange {
+ dxil::ResourceClass Type;
+ uint32_t Register;
+ uint32_t Space;
+
+ FmtRange(const mcdxbc::DescriptorRange &Range)
+ : Type(Range.RangeType), Register(Range.BaseShaderRegister),
+ Space(Range.RegisterSpace) {}
+};
+
+raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) {
+ OS << getResourceClassName(Range.Type) << "(register=" << Range.Register
+ << ", space=" << Range.Space << ")";
+ return OS;
+}
+
+struct FmtMDNode {
+ const MDNode *Node;
+
+ FmtMDNode(const MDNode *Node) : Node(Node) {}
+};
+
+raw_ostream &operator<<(llvm::raw_ostream &OS, FmtMDNode Fmt) {
+ Fmt.Node->printTree(OS);
+ return OS;
+}
+
+static Error makeRSError(const Twine &Msg) {
+ return make_error<RootSignatureValidationError>(Msg);
+}
} // namespace
+template <typename T, typename = std::enable_if_t<
+ std::is_enum_v<T> &&
+ std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
+static Expected<T>
+extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
+ llvm::function_ref<bool(uint32_t)> VerifyFn) {
+ if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
+ if (!VerifyFn(*Val))
+ return makeRSError(formatv("Invalid value for {0}: {1}", ErrText, Val));
+ return static_cast<T>(*Val);
+ }
+ return makeRSError(formatv("Invalid value for {0}:", ErrText));
+}
+
MDNode *MetadataBuilder::BuildRootSignature() {
const auto Visitor = OverloadedVisit{
[this](const dxbc::RootFlags &Flags) -> MDNode * {
@@ -226,12 +248,12 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
- return make_error<InvalidRSMetadataFormat>("RootFlag Element");
+ return makeRSError("Invalid format for RootFlags Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("RootFlag");
+ return makeRSError("Invalid value for RootFlag");
return Error::success();
}
@@ -239,7 +261,7 @@ Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
MDNode *RootConstantNode) {
if (RootConstantNode->getNumOperands() != 5)
- return make_error<InvalidRSMetadataFormat>("RootConstants Element");
+ return makeRSError("Invalid format for RootConstants Element");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1,
@@ -252,17 +274,17 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return makeRSError("Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return make_error<InvalidRSMetadataValue>("Num32BitValues");
+ return makeRSError("Invalid value for Num32BitValues");
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
*Visibility, Constants);
@@ -279,7 +301,7 @@ Error MetadataParser::parseRootDescriptors(
"parseRootDescriptors should only be called with RootDescriptor "
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
- return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
+ return makeRSError("Invalid format for Root Descriptor Element");
dxbc::RootParameterType Type;
switch (ElementKind) {
@@ -308,12 +330,12 @@ Error MetadataParser::parseRootDescriptors(
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return makeRSError("Invalid value for RegisterSpace");
if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
@@ -324,7 +346,7 @@ Error MetadataParser::parseRootDescriptors(
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
+ return makeRSError("Invalid value for Root Descriptor Flags");
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
@@ -333,7 +355,7 @@ Error MetadataParser::parseRootDescriptors(
Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
MDNode *RangeDescriptorNode) {
if (RangeDescriptorNode->getNumOperands() != 6)
- return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+ return makeRSError("Invalid format for Descriptor Range");
mcdxbc::DescriptorRange Range;
@@ -341,7 +363,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+ return makeRSError("Invalid format for Descriptor Range");
if (*ElementText == "CBV")
Range.RangeType = dxil::ResourceClass::CBuffer;
@@ -352,35 +374,34 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
else if (*ElementText == "Sampler")
Range.RangeType = dxil::ResourceClass::Sampler;
else
- return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.",
- RangeDescriptorNode);
+ return makeRSError(formatv("Invalid Descriptor Range type.\n{0}",
+ FmtMDNode{RangeDescriptorNode}));
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return make_error<GenericRSMetadataError>("Number of Descriptor in Range",
- RangeDescriptorNode);
+ return makeRSError(formatv("Invalid number of Descriptor in Range.\n{0}",
+ FmtMDNode{RangeDescriptorNode}));
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
+ return makeRSError("Invalid value for BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return makeRSError("Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return make_error<InvalidRSMetadataValue>(
- "OffsetInDescriptorsFromTableStart");
+ return makeRSError("Invalid value for OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
+ return makeRSError("Invalid value for Descriptor Range Flags");
Table.Ranges.push_back(Range);
return Error::success();
@@ -390,7 +411,7 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
- return make_error<InvalidRSMetadataFormat>("Descriptor Table");
+ return makeRSError("Invalid format for Descriptor Table");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1,
@@ -404,8 +425,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return make_error<GenericRSMetadataError>(
- "Missing Root Element Metadata Node.", DescriptorTableNode);
+ return makeRSError(formatv("Missing Root Element Metadata Node.\n{0}",
+ FmtMDNode{DescriptorTableNode}));
if (auto Err = parseDescriptorRange(Table, Element))
return Err;
@@ -419,7 +440,7 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 15)
- return make_error<InvalidRSMetadataFormat>("Static Sampler");
+ return makeRSError("Invalid format for Static Sampler");
mcdxbc::StaticSampler Sampler;
@@ -453,12 +474,12 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
- return make_error<InvalidRSMetadataValue>("MipLODBias");
+ return makeRSError("Invalid value for MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
+ return makeRSError("Invalid value for MaxAnisotropy");
Expected<dxbc::ComparisonFunc> ComparisonFunc =
extractEnumValue<dxbc::ComparisonFunc>(
@@ -477,22 +498,22 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
- return make_error<InvalidRSMetadataValue>("MinLOD");
+ return makeRSError("Invalid value for MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
- return make_error<InvalidRSMetadataValue>("MaxLOD");
+ return makeRSError("Invalid value for MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return makeRSError("Invalid value for RegisterSpace");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13,
@@ -511,7 +532,7 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 14))
Sampler.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("Static Sampler Flags");
+ return makeRSError("Invalid value for Static Sampler Flags");
RSD.StaticSamplers.push_back(Sampler);
return Error::success();
@@ -521,7 +542,7 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
- return make_error<InvalidRSMetadataFormat>("Root Element");
+ return makeRSError("Invalid format for Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
@@ -549,8 +570,8 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
case RootSignatureElementKind::StaticSamplers:
return parseStaticSampler(RSD, Element);
case RootSignatureElementKind::Error:
- return make_error<GenericRSMetadataError>("Invalid Root Signature Element",
- Element);
+ return makeRSError(
+ formatv("Invalid Root Signature Element\n{0}", FmtMDNode{Element}));
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
@@ -563,7 +584,10 @@ validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table,
for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
if (Range.RangeType == dxil::ResourceClass::Sampler &&
CurrRC != dxil::ResourceClass::Sampler)
- return make_error<TableSamplerMixinError>(CurrRC, Location);
+ return makeRSError(
+ formatv("Samplers cannot be mixed with other resource types in a "
+ "descriptor table, {0}(location={1})",
+ getResourceClassName(CurrRC), Location));
CurrRC = Range.RangeType;
}
return Error::success();
@@ -583,8 +607,8 @@ validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table,
Range.BaseShaderRegister, Range.NumDescriptors);
if (!verifyNoOverflowedOffset(RangeBound))
- return make_error<ShaderRegisterOverflowError>(
- Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+ return makeRSError(
+ formatv("Overflow for shader register range: {0}", FmtRange{Range}));
bool IsAppending =
Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend;
@@ -592,15 +616,16 @@ validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table,
Offset = Range.OffsetInDescriptorsFromTableStart;
if (IsPrevUnbound && IsAppending)
- return make_error<OffsetAppendAfterOverflow>(
- Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+ return makeRSError(
+ formatv("Range {0} cannot be appended after an unbounded range",
+ FmtRange{Range}));
const uint64_t OffsetBound =
llvm::hlsl::rootsig::computeRangeBound(Offset, Range.NumDescriptors);
if (!verifyNoOverflowedOffset(OffsetBound))
- return make_error<OffsetOverflowError>(
- Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+ return makeRSError(formatv("Offset overflow for descriptor range: {0}.",
+ FmtRange{Range}));
Offset = OffsetBound + 1;
IsPrevUnbound =
@@ -614,17 +639,15 @@ Error MetadataParser::validateRootSignature(
const mcdxbc::RootSignatureDesc &RSD) {
Error DeferredErrs = Error::success();
if (!hlsl::rootsig::verifyVersion(RSD.Version)) {
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "Version", RSD.Version));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for Version: {0}", RSD.Version)));
}
if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RootFlags", RSD.Flags));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for RootFlags: {0}", RSD.Flags)));
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
@@ -639,24 +662,27 @@ Error MetadataParser::validateRootSignature(
const mcdxbc::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderRegister", Descriptor.ShaderRegister));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for ShaderRegister: {0}",
+ Descriptor.ShaderRegister)));
if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Descriptor.RegisterSpace));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for RegisterSpace: {0}",
+ Descriptor.RegisterSpace)));
if (RSD.Version > 1) {
- if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RootDescriptorFlag", Descriptor.Flags));
+ bool IsValidFlag =
+ dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
+ hlsl::rootsig::verifyRootDescriptorFlag(
+ RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
+ if (!IsValidFlag)
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for RootDescriptorFlag: {0}",
+ Descriptor.Flags)));
}
break;
}
@@ -665,24 +691,26 @@ Error MetadataParser::validateRootSignature(
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const mcdxbc::DescriptorRange &Range : Table) {
if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Range.RegisterSpace));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for RegisterSpace: {0}",
+ Range.RegisterSpace)));
if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "NumDescriptors", Range.NumDescriptors));
-
- if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType,
- dxbc::DescriptorRangeFlags(Range.Flags)))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "DescriptorFlag", Range.Flags));
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for NumDescriptors: {0}",
+ Range.NumDescriptors)));
+
+ bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
+ hlsl::rootsig::verifyDescriptorRangeFlag(
+ RSD.Version, Range.RangeType,
+ dxbc::DescriptorRangeFlags(Range.Flags));
+ if (!IsValidFlag)
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for DescriptorFlag: {0}",
+ Range.Flags)));
if (Error Err =
validateDescriptorTableSamplerMixin(Table, Info.Location))
@@ -700,43 +728,49 @@ Error MetadataParser::validateRootSignature(
for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- DeferredErrs = joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<float>>(
- "MipLODBias", Sampler.MipLODBias));
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for MipLODBias: {0:e}",
+ Sampler.MipLODBias)));
if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "MaxAnisotropy", Sampler.MaxAnisotropy));
+ makeRSError(formatv("Invalid value for MaxAnisotropy: {0}",
+ Sampler.MaxAnisotropy)));
if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- DeferredErrs = joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<float>>(
- "MinLOD", Sampler.MinLOD));
-
- if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- DeferredErrs = joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<float>>(
- "MaxLOD", Sampler.MaxLOD));
-
- if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderRegister", Sampler.ShaderRegister));
+ makeRSError(formatv("Invalid value for MinLOD: {0}",
+ Sampler.MinLOD)));
- if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
+ if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Sampler.RegisterSpace));
+ makeRSError(formatv("Invalid value for MaxLOD: {0}",
+ Sampler.MaxLOD)));
- if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
+ if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for ShaderRegister: {0}",
+ Sampler.ShaderRegister)));
+
+ if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "Static Sampler Flag", Sampler.Flags));
+ makeRSError(formatv("Invalid value for RegisterSpace: {0}",
+ Sampler.RegisterSpace)));
+ bool IsValidFlag =
+ dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
+ hlsl::rootsig::verifyStaticSamplerFlags(
+ RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
+ if (!IsValidFlag)
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Invalid value for Static Sampler Flag: {0}",
+ Sampler.Flags)));
}
return DeferredErrs;
@@ -750,9 +784,9 @@ MetadataParser::ParseRootSignature(uint32_t Version) {
for (const auto &Operand : Root->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
- return joinErrors(std::move(DeferredErrs),
- make_error<GenericRSMetadataError>(
- "Missing Root Element Metadata Node.", nullptr));
+ return joinErrors(
+ std::move(DeferredErrs),
+ makeRSError(formatv("Missing Root Element Metadata Node.")));
if (auto Err = parseRootSignatureElement(RSD, Element))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 8a2b03d..30408df 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
return !(RegisterSpace >= 0xFFFFFFF0);
}
-bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
+bool verifyRootDescriptorFlag(uint32_t Version,
+ dxbc::RootDescriptorFlags FlagsVal) {
using FlagT = dxbc::RootDescriptorFlags;
FlagT Flags = FlagT(FlagsVal);
if (Version == 1)
@@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
dxbc::DescriptorRangeFlags Flags) {
using FlagT = dxbc::DescriptorRangeFlags;
-
const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
if (Version == 1) {
@@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
return (Flags & ~Mask) == FlagT::None;
}
-bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
- uint32_t LargestValue = llvm::to_underlying(
- dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
- if (FlagsNumber >= NextPowerOf2(LargestValue))
- return false;
-
- dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
+bool verifyStaticSamplerFlags(uint32_t Version,
+ dxbc::StaticSamplerFlags Flags) {
if (Version <= 2)
return Flags == dxbc::StaticSamplerFlags::None;