diff options
Diffstat (limited to 'llvm/lib/Frontend')
-rw-r--r-- | llvm/lib/Frontend/HLSL/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/HLSLBinding.cpp | 142 | ||||
-rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 574 | ||||
-rw-r--r-- | llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 33 |
4 files changed, 708 insertions, 42 deletions
diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 5343469..3d22577 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp + HLSLBinding.cpp HLSLResource.cpp HLSLRootSignature.cpp RootSignatureMetadata.cpp diff --git a/llvm/lib/Frontend/HLSL/HLSLBinding.cpp b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp new file mode 100644 index 0000000..d581311 --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp @@ -0,0 +1,142 @@ +//===- HLSLBinding.cpp - Representation for resource bindings in HLSL -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLBinding.h" +#include "llvm/ADT/STLExtras.h" + +using namespace llvm; +using namespace hlsl; + +std::optional<uint32_t> +BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, + int32_t Size) { + BindingSpaces &BS = getBindingSpaces(RC); + RegisterSpace &RS = BS.getOrInsertSpace(Space); + return RS.findAvailableBinding(Size); +} + +BindingInfo::RegisterSpace & +BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) { + for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) { + if (It->Space == Space) + return *It; + if (It->Space < Space) + continue; + return *Spaces.insert(It, Space); + } + return Spaces.emplace_back(Space); +} + +std::optional<uint32_t> +BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) { + assert((Size == -1 || Size > 0) && "invalid size"); + + if (FreeRanges.empty()) + return std::nullopt; + + // unbounded array + if (Size == -1) { + BindingRange &Last = FreeRanges.back(); + if (Last.UpperBound != ~0u) + // this space is already occupied by an unbounded array + return std::nullopt; + uint32_t RegSlot = Last.LowerBound; + FreeRanges.pop_back(); + return RegSlot; + } + + // single resource or fixed-size array + for (BindingRange &R : FreeRanges) { + // compare the size as uint64_t to prevent overflow for range (0, ~0u) + if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size) + continue; + uint32_t RegSlot = R.LowerBound; + // This might create a range where (LowerBound == UpperBound + 1). When + // that happens, the next time this function is called the range will + // skipped over by the check above (at this point Size is always > 0). + R.LowerBound += Size; + return RegSlot; + } + + return std::nullopt; +} + +BindingInfo BindingInfoBuilder::calculateBindingInfo( + llvm::function_ref<void(const BindingInfoBuilder &Builder, + const Binding &Overlapping)> + ReportOverlap) { + // sort all the collected bindings + llvm::stable_sort(Bindings); + + // remove duplicates + Binding *NewEnd = llvm::unique(Bindings); + if (NewEnd != Bindings.end()) + Bindings.erase(NewEnd); + + BindingInfo Info; + + // Go over the sorted bindings and build up lists of free register ranges + // for each binding type and used spaces. Bindings are sorted by resource + // class, space, and lower bound register slot. + BindingInfo::BindingSpaces *BS = + &Info.getBindingSpaces(dxil::ResourceClass::SRV); + for (const Binding &B : Bindings) { + if (BS->RC != B.RC) + // move to the next resource class spaces + BS = &Info.getBindingSpaces(B.RC); + + BindingInfo::RegisterSpace *S = BS->Spaces.empty() + ? &BS->Spaces.emplace_back(B.Space) + : &BS->Spaces.back(); + assert(S->Space <= B.Space && "bindings not sorted correctly?"); + if (B.Space != S->Space) + // add new space + S = &BS->Spaces.emplace_back(B.Space); + + // The space is full - there are no free slots left, or the rest of the + // slots are taken by an unbounded array. Report the overlapping to the + // caller. + if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) { + ReportOverlap(*this, B); + continue; + } + // adjust the last free range lower bound, split it in two, or remove it + BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back(); + if (LastFreeRange.LowerBound == B.LowerBound) { + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = B.UpperBound + 1; + else + S->FreeRanges.pop_back(); + } else if (LastFreeRange.LowerBound < B.LowerBound) { + LastFreeRange.UpperBound = B.LowerBound - 1; + if (B.UpperBound < ~0u) + S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u); + } else { + // We don't have room here. Report the overlapping binding to the caller + // and mark any extra space this binding would use as unavailable. + ReportOverlap(*this, B); + if (B.UpperBound < ~0u) + LastFreeRange.LowerBound = + std::max(LastFreeRange.LowerBound, B.UpperBound + 1); + else + S->FreeRanges.pop_back(); + } + } + + return Info; +} + +const BindingInfoBuilder::Binding &BindingInfoBuilder::findOverlapping( + const BindingInfoBuilder::Binding &ReportedBinding) const { + for (const BindingInfoBuilder::Binding &Other : Bindings) + if (ReportedBinding.LowerBound <= Other.UpperBound && + Other.LowerBound <= ReportedBinding.UpperBound) + return Other; + + llvm_unreachable("Searching for overlap for binding that does not overlap"); +} diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index f7669f0..48ff1ca 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -12,14 +12,45 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/ScopedPrinter.h" +using namespace llvm; + namespace llvm { namespace hlsl { namespace rootsig { +char GenericRSMetadataError::ID; +char InvalidRSMetadataFormat::ID; +char InvalidRSMetadataValue::ID; +template <typename T> char RootSignatureValidationError<T>::ID; + +static std::optional<uint32_t> extractMdIntValue(MDNode *Node, + unsigned int OpId) { + if (auto *CI = + mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get())) + return CI->getZExtValue(); + return std::nullopt; +} + +static std::optional<float> extractMdFloatValue(MDNode *Node, + unsigned int OpId) { + if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get())) + return CI->getValueAPF().convertToFloat(); + return std::nullopt; +} + +static std::optional<StringRef> extractMdStringValue(MDNode *Node, + unsigned int OpId) { + MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId)); + if (NodeText == nullptr) + return std::nullopt; + return NodeText->getString(); +} + static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { {"CBV", dxil::ResourceClass::CBuffer}, {"SRV", dxil::ResourceClass::SRV}, @@ -82,7 +113,7 @@ MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "RootFlags"), - ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))), }; return MDNode::get(Ctx, Operands); } @@ -92,7 +123,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { Metadata *Operands[] = { MDString::get(Ctx, "RootConstants"), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Constants.Visibility))), + Builder.getInt32(to_underlying(Constants.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Constants.Space)), ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)), @@ -102,18 +133,18 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - std::optional<StringRef> ResName = getResourceName( - dxil::ResourceClass(llvm::to_underlying(Descriptor.Type))); + std::optional<StringRef> ResName = + getResourceName(dxil::ResourceClass(to_underlying(Descriptor.Type))); assert(ResName && "Provided an invalid Resource Class"); - llvm::SmallString<7> Name({"Root", *ResName}); + SmallString<7> Name({"Root", *ResName}); Metadata *Operands[] = { MDString::get(Ctx, Name), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), + Builder.getInt32(to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Descriptor.Flags))), + Builder.getInt32(to_underlying(Descriptor.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -124,7 +155,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { // Set the mandatory arguments TableOperands.push_back(MDString::get(Ctx, "DescriptorTable")); TableOperands.push_back(ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Table.Visibility)))); + Builder.getInt32(to_underlying(Table.Visibility)))); // Remaining operands are references to the table's clauses. The in-memory // representation of the Root Elements created from parsing will ensure that @@ -144,7 +175,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); std::optional<StringRef> ResName = - getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type))); + getResourceName(dxil::ResourceClass(to_underlying(Clause.Type))); assert(ResName && "Provided an invalid Resource Class"); Metadata *Operands[] = { MDString::get(Ctx, *ResName), @@ -152,8 +183,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))), }; return MDNode::get(Ctx, Operands); } @@ -162,33 +192,533 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { IRBuilder<> Builder(Ctx); Metadata *Operands[] = { MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + Builder.getInt32(to_underlying(Sampler.AddressU))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + Builder.getInt32(to_underlying(Sampler.AddressV))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + Builder.getInt32(to_underlying(Sampler.AddressW))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), - ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), - Sampler.MipLODBias)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)), ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + Builder.getInt32(to_underlying(Sampler.CompFunc))), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + Builder.getInt32(to_underlying(Sampler.BorderColor))), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)), ConstantAsMetadata::get( - llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + Builder.getInt32(to_underlying(Sampler.Visibility))), }; return MDNode::get(Ctx, Operands); } +Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode) { + if (RootFlagNode->getNumOperands() != 2) + return make_error<InvalidRSMetadataFormat>("RootFlag Element"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) + RSD.Flags = *Val; + else + return make_error<InvalidRSMetadataValue>("RootFlag"); + + return Error::success(); +} + +Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode) { + if (RootConstantNode->getNumOperands() != 5) + return make_error<InvalidRSMetadataFormat>("RootConstants Element"); + + dxbc::RTS0::v1::RootParameterHeader Header; + // The parameter offset doesn't matter here - we recalculate it during + // serialization Header.ParameterOffset = 0; + Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) + Header.ShaderVisibility = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + + dxbc::RTS0::v1::RootConstants Constants; + if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) + Constants.ShaderRegister = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderRegister"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) + Constants.RegisterSpace = *Val; + else + return make_error<InvalidRSMetadataValue>("RegisterSpace"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) + Constants.Num32BitValues = *Val; + else + return make_error<InvalidRSMetadataValue>("Num32BitValues"); + + RSD.ParametersContainer.addParameter(Header, Constants); + + return Error::success(); +} + +Error MetadataParser::parseRootDescriptors( + mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, + RootSignatureElementKind ElementKind) { + assert(ElementKind == RootSignatureElementKind::SRV || + ElementKind == RootSignatureElementKind::UAV || + ElementKind == RootSignatureElementKind::CBV && + "parseRootDescriptors should only be called with RootDescriptor " + "element kind."); + if (RootDescriptorNode->getNumOperands() != 5) + return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); + + dxbc::RTS0::v1::RootParameterHeader Header; + switch (ElementKind) { + case RootSignatureElementKind::SRV: + Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV); + break; + case RootSignatureElementKind::UAV: + Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV); + break; + case RootSignatureElementKind::CBV: + Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV); + break; + default: + llvm_unreachable("invalid Root Descriptor kind"); + break; + } + + if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) + Header.ShaderVisibility = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + + dxbc::RTS0::v2::RootDescriptor Descriptor; + if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) + Descriptor.ShaderRegister = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderRegister"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) + Descriptor.RegisterSpace = *Val; + else + return make_error<InvalidRSMetadataValue>("RegisterSpace"); + + if (RSD.Version == 1) { + RSD.ParametersContainer.addParameter(Header, Descriptor); + return Error::success(); + } + assert(RSD.Version > 1); + + if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) + Descriptor.Flags = *Val; + else + return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); + + RSD.ParametersContainer.addParameter(Header, Descriptor); + return Error::success(); +} + +Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode) { + if (RangeDescriptorNode->getNumOperands() != 6) + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); + + dxbc::RTS0::v2::DescriptorRange Range; + + std::optional<StringRef> ElementText = + extractMdStringValue(RangeDescriptorNode, 0); + + if (!ElementText.has_value()) + return make_error<InvalidRSMetadataFormat>("Descriptor Range"); + + Range.RangeType = + StringSwitch<uint32_t>(*ElementText) + .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) + .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) + .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) + .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) + .Default(~0U); + + if (Range.RangeType == ~0U) + return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", + RangeDescriptorNode); + + if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) + Range.NumDescriptors = *Val; + else + return make_error<GenericRSMetadataError>("Number of Descriptor in Range", + RangeDescriptorNode); + + if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) + Range.BaseShaderRegister = *Val; + else + return make_error<InvalidRSMetadataValue>("BaseShaderRegister"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) + Range.RegisterSpace = *Val; + else + return make_error<InvalidRSMetadataValue>("RegisterSpace"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) + Range.OffsetInDescriptorsFromTableStart = *Val; + else + return make_error<InvalidRSMetadataValue>( + "OffsetInDescriptorsFromTableStart"); + + if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) + Range.Flags = *Val; + else + return make_error<InvalidRSMetadataValue>("Descriptor Range Flags"); + + Table.Ranges.push_back(Range); + return Error::success(); +} + +Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode) { + const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); + if (NumOperands < 2) + return make_error<InvalidRSMetadataFormat>("Descriptor Table"); + + dxbc::RTS0::v1::RootParameterHeader Header; + if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) + Header.ShaderVisibility = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + + mcdxbc::DescriptorTable Table; + Header.ParameterType = + to_underlying(dxbc::RootParameterType::DescriptorTable); + + for (unsigned int I = 2; I < NumOperands; I++) { + MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); + if (Element == nullptr) + return make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", DescriptorTableNode); + + if (auto Err = parseDescriptorRange(Table, Element)) + return Err; + } + + RSD.ParametersContainer.addParameter(Header, Table); + return Error::success(); +} + +Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode) { + if (StaticSamplerNode->getNumOperands() != 14) + return make_error<InvalidRSMetadataFormat>("Static Sampler"); + + dxbc::RTS0::v1::StaticSampler Sampler; + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) + Sampler.Filter = *Val; + else + return make_error<InvalidRSMetadataValue>("Filter"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) + Sampler.AddressU = *Val; + else + return make_error<InvalidRSMetadataValue>("AddressU"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) + Sampler.AddressV = *Val; + else + return make_error<InvalidRSMetadataValue>("AddressV"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) + Sampler.AddressW = *Val; + else + return make_error<InvalidRSMetadataValue>("AddressW"); + + if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) + Sampler.MipLODBias = *Val; + else + return make_error<InvalidRSMetadataValue>("MipLODBias"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) + Sampler.MaxAnisotropy = *Val; + else + return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) + Sampler.ComparisonFunc = *Val; + else + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) + Sampler.BorderColor = *Val; + else + return make_error<InvalidRSMetadataValue>("ComparisonFunc"); + + if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) + Sampler.MinLOD = *Val; + else + return make_error<InvalidRSMetadataValue>("MinLOD"); + + if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) + Sampler.MaxLOD = *Val; + else + return make_error<InvalidRSMetadataValue>("MaxLOD"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) + Sampler.ShaderRegister = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderRegister"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) + Sampler.RegisterSpace = *Val; + else + return make_error<InvalidRSMetadataValue>("RegisterSpace"); + + if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) + Sampler.ShaderVisibility = *Val; + else + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + + RSD.StaticSamplers.push_back(Sampler); + return Error::success(); +} + +Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, + MDNode *Element) { + std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); + if (!ElementText.has_value()) + return make_error<InvalidRSMetadataFormat>("Root Element"); + + RootSignatureElementKind ElementKind = + StringSwitch<RootSignatureElementKind>(*ElementText) + .Case("RootFlags", RootSignatureElementKind::RootFlags) + .Case("RootConstants", RootSignatureElementKind::RootConstants) + .Case("RootCBV", RootSignatureElementKind::CBV) + .Case("RootSRV", RootSignatureElementKind::SRV) + .Case("RootUAV", RootSignatureElementKind::UAV) + .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable) + .Case("StaticSampler", RootSignatureElementKind::StaticSamplers) + .Default(RootSignatureElementKind::Error); + + switch (ElementKind) { + + case RootSignatureElementKind::RootFlags: + return parseRootFlags(RSD, Element); + case RootSignatureElementKind::RootConstants: + return parseRootConstants(RSD, Element); + case RootSignatureElementKind::CBV: + case RootSignatureElementKind::SRV: + case RootSignatureElementKind::UAV: + return parseRootDescriptors(RSD, Element, ElementKind); + case RootSignatureElementKind::DescriptorTable: + return parseDescriptorTable(RSD, Element); + case RootSignatureElementKind::StaticSamplers: + return parseStaticSampler(RSD, Element); + case RootSignatureElementKind::Error: + return make_error<GenericRSMetadataError>("Invalid Root Signature Element", + Element); + } + + llvm_unreachable("Unhandled RootSignatureElementKind enum."); +} + +Error MetadataParser::validateRootSignature( + const mcdxbc::RootSignatureDesc &RSD) { + Error DeferredErrs = Error::success(); + if (!hlsl::rootsig::verifyVersion(RSD.Version)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Version", RSD.Version)); + } + + if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) { + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootFlags", RSD.Flags)); + } + + for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { + if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Info.Header.ShaderVisibility)); + + assert(dxbc::isValidParameterType(Info.Header.ParameterType) && + "Invalid value for ParameterType"); + + switch (Info.Header.ParameterType) { + + case to_underlying(dxbc::RootParameterType::CBV): + case to_underlying(dxbc::RootParameterType::UAV): + case to_underlying(dxbc::RootParameterType::SRV): { + const dxbc::RTS0::v2::RootDescriptor &Descriptor = + RSD.ParametersContainer.getRootDescriptor(Info.Location); + if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Descriptor.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Descriptor.RegisterSpace)); + + if (RSD.Version > 1) { + if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, + Descriptor.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RootDescriptorFlag", Descriptor.Flags)); + } + break; + } + case to_underlying(dxbc::RootParameterType::DescriptorTable): { + const mcdxbc::DescriptorTable &Table = + RSD.ParametersContainer.getDescriptorTable(Info.Location); + for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { + if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RangeType", Range.RangeType)); + + if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Range.RegisterSpace)); + + if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "NumDescriptors", Range.NumDescriptors)); + + if (!hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, Range.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "DescriptorFlag", Range.Flags)); + } + break; + } + } + } + + for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { + if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Filter", Sampler.Filter)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressU", Sampler.AddressU)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressV", Sampler.AddressV)); + + if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "AddressW", Sampler.AddressW)); + + if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MipLODBias", Sampler.MipLODBias)); + + if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "MaxAnisotropy", Sampler.MaxAnisotropy)); + + if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ComparisonFunc", Sampler.ComparisonFunc)); + + if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "BorderColor", Sampler.BorderColor)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MinLOD", Sampler.MinLOD)); + + if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<float>>( + "MaxLOD", Sampler.MaxLOD)); + + if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderRegister", Sampler.ShaderRegister)); + + if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "RegisterSpace", Sampler.RegisterSpace)); + + if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", Sampler.ShaderVisibility)); + } + + return DeferredErrs; +} + +Expected<mcdxbc::RootSignatureDesc> +MetadataParser::ParseRootSignature(uint32_t Version) { + Error DeferredErrs = Error::success(); + mcdxbc::RootSignatureDesc RSD; + RSD.Version = Version; + for (const auto &Operand : Root->operands()) { + MDNode *Element = dyn_cast<MDNode>(Operand); + if (Element == nullptr) + return joinErrors(std::move(DeferredErrs), + make_error<GenericRSMetadataError>( + "Missing Root Element Metadata Node.", nullptr)); + + if (auto Err = parseRootSignatureElement(RSD, Element)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + } + + if (auto Err = validateRootSignature(RSD)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + + if (DeferredErrs) + return std::move(DeferredErrs); + + return std::move(RSD); +} } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 840ca83..3aa4f7a 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1161,7 +1161,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel( Builder.restoreIP(AllocaIP); auto *KernelArgsPtr = Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args"); - Builder.restoreIP(Loc.IP); + updateToLocation(Loc); for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) { llvm::Value *Arg = @@ -1189,7 +1189,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch( if (!updateToLocation(Loc)) return Loc.IP; - Builder.restoreIP(Loc.IP); // On top of the arrays that were filled up, the target offloading call // takes as arguments the device id as well as the host pointer. The host // pointer is used by the runtime library to identify the current target @@ -2617,7 +2616,7 @@ void OpenMPIRBuilder::emitReductionListCopy( Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction( const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); + InsertPointTy SavedIP = Builder.saveIP(); LLVMContext &Ctx = M.getContext(); FunctionType *FuncTy = FunctionType::get( Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()}, @@ -2630,7 +2629,6 @@ Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction( WcFunc->addParamAttr(1, Attribute::NoUndef); BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc); Builder.SetInsertPoint(EntryBB); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // ReduceList: thread local Reduce list. // At the stage of the computation when this function is called, partially @@ -2845,6 +2843,7 @@ Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction( } Builder.CreateRetVoid(); + Builder.restoreIP(SavedIP); return WcFunc; } @@ -2853,7 +2852,6 @@ Function *OpenMPIRBuilder::emitShuffleAndReduceFunction( ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn, AttributeList FuncAttrs) { LLVMContext &Ctx = M.getContext(); - IRBuilder<>::InsertPointGuard IPG(Builder); FunctionType *FuncTy = FunctionType::get(Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt16Ty(), @@ -2872,7 +2870,6 @@ Function *OpenMPIRBuilder::emitShuffleAndReduceFunction( SarFunc->addParamAttr(3, Attribute::SExt); BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc); Builder.SetInsertPoint(EntryBB); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Thread local Reduce list used to host the values of data to be reduced. Argument *ReduceListArg = SarFunc->getArg(0); @@ -3019,7 +3016,7 @@ Function *OpenMPIRBuilder::emitShuffleAndReduceFunction( Function *OpenMPIRBuilder::emitListToGlobalCopyFunction( ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); + OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP(); LLVMContext &Ctx = M.getContext(); FunctionType *FuncTy = FunctionType::get( Builder.getVoidTy(), @@ -3035,7 +3032,6 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction( BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc); Builder.SetInsertPoint(EntryBlock); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Buffer: global reduction buffer. Argument *BufferArg = LtGCFunc->getArg(0); @@ -3123,13 +3119,14 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction( } Builder.CreateRetVoid(); + Builder.restoreIP(OldIP); return LtGCFunc; } Function *OpenMPIRBuilder::emitListToGlobalReduceFunction( ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn, Type *ReductionsBufferTy, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); + OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP(); LLVMContext &Ctx = M.getContext(); FunctionType *FuncTy = FunctionType::get( Builder.getVoidTy(), @@ -3145,7 +3142,6 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction( BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc); Builder.SetInsertPoint(EntryBlock); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Buffer: global reduction buffer. Argument *BufferArg = LtGRFunc->getArg(0); @@ -3206,13 +3202,14 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction( Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList}) ->addFnAttr(Attribute::NoUnwind); Builder.CreateRetVoid(); + Builder.restoreIP(OldIP); return LtGRFunc; } Function *OpenMPIRBuilder::emitGlobalToListCopyFunction( ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); + OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP(); LLVMContext &Ctx = M.getContext(); FunctionType *FuncTy = FunctionType::get( Builder.getVoidTy(), @@ -3228,7 +3225,6 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction( BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc); Builder.SetInsertPoint(EntryBlock); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Buffer: global reduction buffer. Argument *BufferArg = LtGCFunc->getArg(0); @@ -3314,13 +3310,14 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction( } Builder.CreateRetVoid(); + Builder.restoreIP(OldIP); return LtGCFunc; } Function *OpenMPIRBuilder::emitGlobalToListReduceFunction( ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn, Type *ReductionsBufferTy, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); + OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP(); LLVMContext &Ctx = M.getContext(); auto *FuncTy = FunctionType::get( Builder.getVoidTy(), @@ -3336,7 +3333,6 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction( BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc); Builder.SetInsertPoint(EntryBlock); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Buffer: global reduction buffer. Argument *BufferArg = LtGRFunc->getArg(0); @@ -3397,6 +3393,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction( Builder.CreateCall(ReduceFn, {ReduceList, ReductionList}) ->addFnAttr(Attribute::NoUnwind); Builder.CreateRetVoid(); + Builder.restoreIP(OldIP); return LtGRFunc; } @@ -3409,7 +3406,6 @@ std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const { Expected<Function *> OpenMPIRBuilder::createReductionFunction( StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos, ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) { - IRBuilder<>::InsertPointGuard IPG(Builder); auto *FuncTy = FunctionType::get(Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, /* IsVarArg */ false); @@ -3422,7 +3418,6 @@ Expected<Function *> OpenMPIRBuilder::createReductionFunction( BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", ReductionFunc); Builder.SetInsertPoint(EntryBB); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); // Need to alloca memory here and deal with the pointers before getting // LHS/RHS pointers out @@ -3750,12 +3745,10 @@ static Error populateReductionFunction( Function *ReductionFunc, ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos, IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) { - IRBuilder<>::InsertPointGuard IPG(Builder); Module *Module = ReductionFunc->getParent(); BasicBlock *ReductionFuncBlock = BasicBlock::Create(Module->getContext(), "", ReductionFunc); Builder.SetInsertPoint(ReductionFuncBlock); - Builder.SetCurrentDebugLocation(llvm::DebugLoc()); Value *LHSArrayPtr = nullptr; Value *RHSArrayPtr = nullptr; if (IsGPU) { @@ -5961,7 +5954,7 @@ OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc, Builder.restoreIP(AllocaIP); AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name); ArgsBase->setAlignment(Align(8)); - Builder.restoreIP(Loc.IP); + updateToLocation(Loc); // Store the index value with offset in depend vector. for (unsigned I = 0; I < NumLoops; ++I) { @@ -8087,7 +8080,7 @@ void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc, ".offload_ptrs"); AllocaInst *ArgSizes = Builder.CreateAlloca( ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes"); - Builder.restoreIP(Loc.IP); + updateToLocation(Loc); MapperAllocas.ArgsBase = ArgsBase; MapperAllocas.Args = Args; MapperAllocas.ArgSizes = ArgSizes; |