diff options
Diffstat (limited to 'clang/lib')
25 files changed, 494 insertions, 4 deletions
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index b5417fc..2cd9023 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -2493,6 +2493,19 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const { return getTypeInfo( cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr()); + case Type::HLSLInlineSpirv: { + const auto *ST = cast<HLSLInlineSpirvType>(T); + // Size is specified in bytes, convert to bits + Width = ST->getSize() * 8; + Align = ST->getAlignment(); + if (Width == 0 && Align == 0) { + // We are defaulting to laying out opaque SPIR-V types as 32-bit ints. + Width = 32; + Align = 32; + } + break; + } + case Type::Atomic: { // Start with the base type information. TypeInfo Info = getTypeInfo(cast<AtomicType>(T)->getValueType()); @@ -3507,6 +3520,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx, return; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("should never get here"); break; case Type::DeducedTemplateSpecialization: @@ -4228,6 +4242,7 @@ QualType ASTContext::getVariableArrayDecayedType(QualType type) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("type should never be variably-modified"); // These types can be variably-modified but should never need to @@ -5486,6 +5501,31 @@ QualType ASTContext::getHLSLAttributedResourceType( return QualType(Ty, 0); } + +QualType ASTContext::getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) { + llvm::FoldingSetNodeID ID; + HLSLInlineSpirvType::Profile(ID, Opcode, Size, Alignment, Operands); + + void *InsertPos = nullptr; + HLSLInlineSpirvType *Ty = + HLSLInlineSpirvTypes.FindNodeOrInsertPos(ID, InsertPos); + if (Ty) + return QualType(Ty, 0); + + void *Mem = Allocate( + HLSLInlineSpirvType::totalSizeToAlloc<SpirvOperand>(Operands.size()), + alignof(HLSLInlineSpirvType)); + + Ty = new (Mem) HLSLInlineSpirvType(Opcode, Size, Alignment, Operands); + + Types.push_back(Ty); + HLSLInlineSpirvTypes.InsertNode(Ty, InsertPos); + + return QualType(Ty, 0); +} + /// Retrieve a substitution-result type. QualType ASTContext::getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, @@ -9457,6 +9497,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S, return; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("unexpected type"); case Type::ArrayParameter: @@ -11904,6 +11945,20 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer, return LHS; return {}; } + case Type::HLSLInlineSpirv: + const HLSLInlineSpirvType *LHSTy = LHS->castAs<HLSLInlineSpirvType>(); + const HLSLInlineSpirvType *RHSTy = RHS->castAs<HLSLInlineSpirvType>(); + + if (LHSTy->getOpcode() == RHSTy->getOpcode() && + LHSTy->getSize() == RHSTy->getSize() && + LHSTy->getAlignment() == RHSTy->getAlignment()) { + for (size_t I = 0; I < LHSTy->getOperands().size(); I++) + if (LHSTy->getOperands()[I] != RHSTy->getOperands()[I]) + return {}; + + return LHS; + } + return {}; } llvm_unreachable("Invalid Type::Class!"); @@ -13922,6 +13977,7 @@ static QualType getCommonNonSugarTypeNode(ASTContext &Ctx, const Type *X, SUGAR_FREE_TYPE(SubstTemplateTypeParmPack) SUGAR_FREE_TYPE(UnresolvedUsing) SUGAR_FREE_TYPE(HLSLAttributedResource) + SUGAR_FREE_TYPE(HLSLInlineSpirv) #undef SUGAR_FREE_TYPE #define NON_UNIQUE_TYPE(Class) UNEXPECTED_TYPE(Class, "non-unique") NON_UNIQUE_TYPE(TypeOfExpr) @@ -14262,6 +14318,7 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X, CANONICAL_TYPE(FunctionProto) CANONICAL_TYPE(IncompleteArray) CANONICAL_TYPE(HLSLAttributedResource) + CANONICAL_TYPE(HLSLInlineSpirv) CANONICAL_TYPE(LValueReference) CANONICAL_TYPE(ObjCInterface) CANONICAL_TYPE(ObjCObject) diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index b481ad5..d275f71 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -1825,6 +1825,43 @@ ExpectedType clang::ASTNodeImporter::VisitHLSLAttributedResourceType( ToWrappedType, ToContainedType, ToAttrs); } +ExpectedType clang::ASTNodeImporter::VisitHLSLInlineSpirvType( + const clang::HLSLInlineSpirvType *T) { + Error Err = Error::success(); + + uint32_t ToOpcode = T->getOpcode(); + uint32_t ToSize = T->getSize(); + uint32_t ToAlignment = T->getAlignment(); + + llvm::SmallVector<SpirvOperand> ToOperands; + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: + ToOperands.push_back(SpirvOperand::createConstant( + importChecked(Err, Operand.getResultType()), Operand.getValue())); + break; + case SpirvOperandKind::Literal: + ToOperands.push_back(SpirvOperand::createLiteral(Operand.getValue())); + break; + case SpirvOperandKind::TypeId: + ToOperands.push_back(SpirvOperand::createType( + importChecked(Err, Operand.getResultType()))); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + } + + if (Err) + return std::move(Err); + } + + return Importer.getToContext().getHLSLInlineSpirvType( + ToOpcode, ToSize, ToAlignment, ToOperands); +} + ExpectedType clang::ASTNodeImporter::VisitConstantMatrixType( const clang::ConstantMatrixType *T) { ExpectedType ToElementTypeOrErr = import(T->getElementType()); diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp index 499854a..47c8812 100644 --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -1157,6 +1157,23 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context, return false; break; + case Type::HLSLInlineSpirv: + if (cast<HLSLInlineSpirvType>(T1)->getOpcode() != + cast<HLSLInlineSpirvType>(T2)->getOpcode() || + cast<HLSLInlineSpirvType>(T1)->getSize() != + cast<HLSLInlineSpirvType>(T2)->getSize() || + cast<HLSLInlineSpirvType>(T1)->getAlignment() != + cast<HLSLInlineSpirvType>(T2)->getAlignment()) + return false; + for (size_t I = 0; I < cast<HLSLInlineSpirvType>(T1)->getOperands().size(); + I++) { + if (cast<HLSLInlineSpirvType>(T1)->getOperands()[I] != + cast<HLSLInlineSpirvType>(T2)->getOperands()[I]) { + return false; + } + } + break; + case Type::Paren: if (!IsStructurallyEquivalent(Context, cast<ParenType>(T1)->getInnerType(), cast<ParenType>(T2)->getInnerType())) diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 39fc714..c7488ea 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -12452,6 +12452,7 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T, case Type::ObjCObjectPointer: case Type::Pipe: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // Classify all other types that don't fit into the regular // classification the same way. return GCCTypeClass::None; diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 33a8728..17c0733 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2461,6 +2461,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty, case Type::Attributed: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: case Type::Auto: case Type::DeducedTemplateSpecialization: case Type::PackExpansion: @@ -4692,6 +4693,44 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) { mangleType(T->getWrappedType()); } +void CXXNameMangler::mangleType(const HLSLInlineSpirvType *T) { + SmallString<20> TypeNameStr; + llvm::raw_svector_ostream TypeNameOS(TypeNameStr); + + TypeNameOS << "spirv_type"; + + TypeNameOS << "_" << T->getOpcode(); + TypeNameOS << "_" << T->getSize(); + TypeNameOS << "_" << T->getAlignment(); + + mangleVendorType(TypeNameStr); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: + mangleVendorQualifier("_Const"); + mangleIntegerLiteral(Operand.getResultType(), + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::Literal: + mangleVendorQualifier("_Lit"); + mangleIntegerLiteral(Context.getASTContext().IntTy, + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::TypeId: + mangleVendorQualifier("_Type"); + mangleType(Operand.getResultType()); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + break; + } + TypeNameOS << Operand.getKind(); + } +} + void CXXNameMangler::mangleIntegerLiteral(QualType T, const llvm::APSInt &Value) { // <expr-primary> ::= L <type> <value number> E # integer literal @@ -4705,7 +4744,6 @@ void CXXNameMangler::mangleIntegerLiteral(QualType T, mangleNumber(Value); } Out << 'E'; - } void CXXNameMangler::mangleMemberExprBase(const Expr *Base, bool IsArrow) { diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index add737b7..290521a 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -3768,6 +3768,11 @@ void MicrosoftCXXNameMangler::mangleType(const HLSLAttributedResourceType *T, llvm_unreachable("HLSL uses Itanium name mangling"); } +void MicrosoftCXXNameMangler::mangleType(const HLSLInlineSpirvType *T, + Qualifiers, SourceRange Range) { + llvm_unreachable("HLSL uses Itanium name mangling"); +} + // <this-adjustment> ::= <no-adjustment> | <static-adjustment> | // <virtual-adjustment> // <no-adjustment> ::= A # private near diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index df084dd..35a5f8e 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4764,6 +4764,8 @@ static CachedProperties computeCachedProperties(const Type *T) { return Cache::get(cast<PipeType>(T)->getElementType()); case Type::HLSLAttributedResource: return Cache::get(cast<HLSLAttributedResourceType>(T)->getWrappedType()); + case Type::HLSLInlineSpirv: + return CachedProperties(Linkage::External, false); } llvm_unreachable("unhandled type class"); @@ -4862,6 +4864,17 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) { return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T) ->getContainedType() ->getCanonicalTypeInternal()); + case Type::HLSLInlineSpirv: + return LinkageInfo::external(); + { + const auto *ST = cast<HLSLInlineSpirvType>(T); + LinkageInfo LV = LinkageInfo::external(); + for (auto &Operand : ST->getOperands()) { + if (Operand.isConstant() || Operand.isType()) + LV.merge(computeTypeLinkageInfo(Operand.getResultType())); + } + return LV; + } } llvm_unreachable("unhandled type class"); @@ -5049,6 +5062,7 @@ bool Type::canHaveNullability(bool ResultIfUnknown) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return false; } llvm_unreachable("bad type kind!"); diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index cba1a2d..4793ef3 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -247,6 +247,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T, case Type::DependentBitInt: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: CanPrefixQualifiers = true; break; @@ -2139,6 +2140,53 @@ void TypePrinter::printHLSLAttributedResourceAfter( } } +void TypePrinter::printHLSLInlineSpirvBefore(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + OS << "__hlsl_spirv_type<" << T->getOpcode(); + + OS << ", " << T->getSize(); + OS << ", " << T->getAlignment(); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + OS << ", "; + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: { + QualType ConstantType = Operand.getResultType(); + OS << "vk::integral_constant<"; + printBefore(ConstantType, OS); + printAfter(ConstantType, OS); + OS << ", "; + OS << Operand.getValue(); + OS << ">"; + break; + } + case SpirvOperandKind::Literal: + OS << "vk::Literal<vk::integral_constant<uint, "; + OS << Operand.getValue(); + OS << ">>"; + break; + case SpirvOperandKind::TypeId: { + QualType Type = Operand.getResultType(); + printBefore(Type, OS); + printAfter(Type, OS); + break; + } + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + break; + } + } + + OS << ">"; +} + +void TypePrinter::printHLSLInlineSpirvAfter(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + // nothing to do +} + void TypePrinter::printObjCInterfaceBefore(const ObjCInterfaceType *T, raw_ostream &OS) { OS << T->getDecl()->getName(); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp index 21896c9..d5662b1 100644 --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -3638,6 +3638,12 @@ llvm::DIType *CGDebugInfo::CreateType(const HLSLAttributedResourceType *Ty, return getOrCreateType(Ty->getWrappedType(), U); } +llvm::DIType *CGDebugInfo::CreateType(const HLSLInlineSpirvType *Ty, + llvm::DIFile *U) { + // Debug information unneeded. + return nullptr; +} + llvm::DIType *CGDebugInfo::CreateEnumType(const EnumType *Ty) { const EnumDecl *ED = Ty->getDecl(); @@ -3991,6 +3997,8 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) { return CreateType(cast<TemplateSpecializationType>(Ty), Unit); case Type::HLSLAttributedResource: return CreateType(cast<HLSLAttributedResourceType>(Ty), Unit); + case Type::HLSLInlineSpirv: + return CreateType(cast<HLSLInlineSpirvType>(Ty), Unit); case Type::CountAttributed: case Type::Auto: diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h index 79d031a..ec27fb0 100644 --- a/clang/lib/CodeGen/CGDebugInfo.h +++ b/clang/lib/CodeGen/CGDebugInfo.h @@ -210,6 +210,7 @@ private: llvm::DIType *CreateType(const FunctionType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const HLSLAttributedResourceType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const HLSLInlineSpirvType *Ty, llvm::DIFile *F); /// Get structure or union type. llvm::DIType *CreateType(const RecordType *Tyg); diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index 0356952..e235756 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -283,6 +283,7 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) { case Type::Pipe: case Type::BitInt: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return TEK_Scalar; // Complexes. @@ -2473,6 +2474,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) { case Type::ObjCInterface: case Type::ObjCObjectPointer: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 843733b..36c5f2b 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -765,6 +765,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { break; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: ResultType = CGM.getHLSLRuntime().convertHLSLSpecificType(Ty); break; } @@ -877,6 +878,10 @@ bool CodeGenTypes::isZeroInitializable(QualType T) { if (const MemberPointerType *MPT = T->getAs<MemberPointerType>()) return getCXXABI().isZeroInitializable(MPT); + // HLSL Inline SPIR-V types are non-zero-initializable. + if (T->getAs<HLSLInlineSpirvType>()) + return false; + // Everything else is okay. return true; } diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 5018a6b..e811474 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3962,6 +3962,7 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty, break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support virtual functions"); } @@ -4237,6 +4238,7 @@ llvm::Constant *ItaniumRTTIBuilder::BuildTypeInfo( break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support RTTI"); } diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp index f35c124..cb190b3 100644 --- a/clang/lib/CodeGen/Targets/SPIR.cpp +++ b/clang/lib/CodeGen/Targets/SPIR.cpp @@ -377,14 +377,99 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getOpenCLType(CodeGenModule &CGM, return nullptr; } +// Gets a spirv.IntegralConstant or spirv.Literal. If IntegralType is present, +// returns an IntegralConstant, otherwise returns a Literal. +static llvm::Type *getInlineSpirvConstant(CodeGenModule &CGM, + llvm::Type *IntegralType, + llvm::APInt Value) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + // Convert the APInt value to an array of uint32_t words + llvm::SmallVector<uint32_t> Words; + + while (Value.ugt(0)) { + uint32_t Word = Value.trunc(32).getZExtValue(); + Value.lshrInPlace(32); + + Words.push_back(Word); + } + if (Words.size() == 0) + Words.push_back(0); + + if (IntegralType) + return llvm::TargetExtType::get(Ctx, "spirv.IntegralConstant", + {IntegralType}, Words); + return llvm::TargetExtType::get(Ctx, "spirv.Literal", {}, Words); +} + +static llvm::Type *getInlineSpirvType(CodeGenModule &CGM, + const HLSLInlineSpirvType *SpirvType) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + llvm::SmallVector<llvm::Type *> Operands; + + for (auto &Operand : SpirvType->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + llvm::Type *Result = nullptr; + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: { + llvm::Type *IntegralType = + CGM.getTypes().ConvertType(Operand.getResultType()); + llvm::APInt Value = Operand.getValue(); + + Result = getInlineSpirvConstant(CGM, IntegralType, Value); + break; + } + case SpirvOperandKind::Literal: { + llvm::APInt Value = Operand.getValue(); + Result = getInlineSpirvConstant(CGM, nullptr, Value); + break; + } + case SpirvOperandKind::TypeId: { + QualType TypeOperand = Operand.getResultType(); + if (auto *RT = TypeOperand->getAs<RecordType>()) { + auto *RD = RT->getDecl(); + assert(RD->isCompleteDefinition() && + "Type completion should have been required in Sema"); + + const FieldDecl *HandleField = RD->findFirstNamedDataMember(); + if (HandleField) { + QualType ResourceType = HandleField->getType(); + if (ResourceType->getAs<HLSLAttributedResourceType>()) { + TypeOperand = ResourceType; + } + } + } + Result = CGM.getTypes().ConvertType(TypeOperand); + break; + } + default: + llvm_unreachable("HLSLInlineSpirvType had invalid operand!"); + break; + } + + assert(Result); + Operands.push_back(Result); + } + + return llvm::TargetExtType::get(Ctx, "spirv.Type", Operands, + {SpirvType->getOpcode(), SpirvType->getSize(), + SpirvType->getAlignment()}); +} + llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType( CodeGenModule &CGM, const Type *Ty, const SmallVector<int32_t> *Packoffsets) const { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + if (auto *SpirvType = dyn_cast<HLSLInlineSpirvType>(Ty)) + return getInlineSpirvType(CGM, SpirvType); + auto *ResType = dyn_cast<HLSLAttributedResourceType>(Ty); if (!ResType) return nullptr; - llvm::LLVMContext &Ctx = CGM.getLLVMContext(); const HLSLAttributedResourceType::Attributes &ResAttrs = ResType->getAttrs(); switch (ResAttrs.ResourceClass) { case llvm::dxil::ResourceClass::UAV: diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt index 449feb01..53219dc 100644 --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -91,6 +91,7 @@ set(hlsl_subdir_files hlsl/hlsl_intrinsic_helpers.h hlsl/hlsl_intrinsics.h hlsl/hlsl_detail.h + hlsl/hlsl_spirv.h ) set(hlsl_files ${hlsl_h} diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h index b494b4d..684d29d 100644 --- a/clang/lib/Headers/hlsl.h +++ b/clang/lib/Headers/hlsl.h @@ -27,6 +27,10 @@ #endif #include "hlsl/hlsl_intrinsics.h" +#ifdef __spirv__ +#include "hlsl/hlsl_spirv.h" +#endif // __spirv__ + #if defined(__clang__) #pragma clang diagnostic pop #endif diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h new file mode 100644 index 0000000..711da2f --- /dev/null +++ b/clang/lib/Headers/hlsl/hlsl_spirv.h @@ -0,0 +1,28 @@ +//===----- hlsl_spirv.h - HLSL definitions for SPIR-V target --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _HLSL_HLSL_SPIRV_H_ +#define _HLSL_HLSL_SPIRV_H_ + +namespace hlsl { +namespace vk { +template <typename T, T v> struct integral_constant { + static constexpr T value = v; +}; + +template <typename T> struct Literal {}; + +template <uint Opcode, uint Size, uint Alignment, typename... Operands> +using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>; + +template <uint Opcode, typename... Operands> +using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>; +} // namespace vk +} // namespace hlsl + +#endif // _HLSL_HLSL_SPIRV_H_
\ No newline at end of file diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index 5aa3393..198b7da 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4562,6 +4562,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T, case Type::ObjCTypeParam: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: T = cast<ElaboratedType>(Ty)->getNamedType(); diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp index 55428af..eef134b 100644 --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -927,13 +927,25 @@ bool Sema::LookupBuiltin(LookupResult &R) { NameKind == Sema::LookupRedeclarationWithLinkage) { IdentifierInfo *II = R.getLookupName().getAsIdentifierInfo(); if (II) { - if (getLangOpts().CPlusPlus && NameKind == Sema::LookupOrdinaryName) { -#define BuiltinTemplate(BIName) \ + if (NameKind == Sema::LookupOrdinaryName) { + if (getLangOpts().CPlusPlus) { +#define BuiltinTemplate(BIName) +#define CPlusPlusBuiltinTemplate(BIName) \ if (II == getASTContext().get##BIName##Name()) { \ R.addDecl(getASTContext().get##BIName##Decl()); \ return true; \ } #include "clang/Basic/BuiltinTemplates.inc" + } + if (getLangOpts().HLSL) { +#define BuiltinTemplate(BIName) +#define HLSLBuiltinTemplate(BIName) \ + if (II == getASTContext().get##BIName##Name()) { \ + R.addDecl(getASTContext().get##BIName##Decl()); \ + return true; \ + } +#include "clang/Basic/BuiltinTemplates.inc" + } } // Check if this is an OpenCL Builtin, and if so, insert its overloads. @@ -3270,6 +3282,11 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) { case Type::HLSLAttributedResource: T = cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr(); + break; + + // Inline SPIR-V types are treated as fundamental types. + case Type::HLSLInlineSpirv: + break; } if (Queue.empty()) diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp index 41c3f81..10e7823 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -3234,6 +3234,59 @@ static QualType builtinCommonTypeImpl(Sema &S, TemplateName BaseTemplate, } } +static bool isInVkNamespace(const RecordType *RT) { + DeclContext *DC = RT->getDecl()->getDeclContext(); + if (!DC) + return false; + + NamespaceDecl *ND = dyn_cast<NamespaceDecl>(DC); + if (!ND) + return false; + + return ND->getQualifiedNameAsString() == "hlsl::vk"; +} + +static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef, + QualType OperandArg, + SourceLocation Loc) { + if (auto *RT = OperandArg->getAs<RecordType>()) { + bool Literal = false; + SourceLocation LiteralLoc; + if (isInVkNamespace(RT) && RT->getDecl()->getName() == "Literal") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &LiteralArgs = SpecDecl->getTemplateArgs(); + QualType ConstantType = LiteralArgs[0].getAsType(); + RT = ConstantType->getAs<RecordType>(); + Literal = true; + LiteralLoc = SpecDecl->getSourceRange().getBegin(); + } + + if (RT && isInVkNamespace(RT) && + RT->getDecl()->getName() == "integral_constant") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &ConstantArgs = SpecDecl->getTemplateArgs(); + + QualType ConstantType = ConstantArgs[0].getAsType(); + llvm::APInt Value = ConstantArgs[1].getAsIntegral(); + + if (Literal) + return SpirvOperand::createLiteral(Value); + return SpirvOperand::createConstant(ConstantType, Value); + } else if (Literal) { + SemaRef.Diag(LiteralLoc, diag::err_hlsl_vk_literal_must_contain_constant); + return SpirvOperand(); + } + } + if (SemaRef.RequireCompleteType(Loc, OperandArg, + diag::err_call_incomplete_argument)) + return SpirvOperand(); + return SpirvOperand::createType(OperandArg); +} + static QualType checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, ArrayRef<TemplateArgument> Converted, @@ -3334,6 +3387,36 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, QualType HasNoTypeMember = Converted[2].getAsType(); return HasNoTypeMember; } + + case BTK__hlsl_spirv_type: { + assert(Converted.size() == 4); + + if (!Context.getTargetInfo().getTriple().isSPIRV()) { + SemaRef.Diag(TemplateLoc, diag::err_hlsl_spirv_only) << BTD; + } + + if (llvm::any_of(Converted, [](auto &C) { return C.isDependent(); })) + return QualType(); + + uint64_t Opcode = Converted[0].getAsIntegral().getZExtValue(); + uint64_t Size = Converted[1].getAsIntegral().getZExtValue(); + uint64_t Alignment = Converted[2].getAsIntegral().getZExtValue(); + + ArrayRef<TemplateArgument> OperandArgs = Converted[3].getPackAsArray(); + + llvm::SmallVector<SpirvOperand> Operands; + + for (auto &OperandTA : OperandArgs) { + QualType OperandArg = OperandTA.getAsType(); + auto Operand = checkHLSLSpirvTypeOperand(SemaRef, OperandArg, + TemplateArgs[3].getLocation()); + if (!Operand.isValid()) + return QualType(); + Operands.push_back(Operand); + } + + return Context.getHLSLInlineSpirvType(Opcode, Size, Alignment, Operands); + } } llvm_unreachable("unexpected BuiltinTemplateDecl!"); } @@ -6251,6 +6334,15 @@ bool UnnamedLocalNoLinkageFinder::VisitHLSLAttributedResourceType( return Visit(T->getWrappedType()); } +bool UnnamedLocalNoLinkageFinder::VisitHLSLInlineSpirvType( + const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) + if (Operand.isConstant() && Operand.isLiteral()) + if (Visit(Operand.getResultType())) + return true; + return false; +} + bool Sema::CheckTemplateArgument(TypeSourceInfo *ArgInfo) { assert(ArgInfo && "invalid TypeSourceInfo"); QualType Arg = ArgInfo->getType(); diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp index 75ae04b..51bfca9d 100644 --- a/clang/lib/Sema/SemaTemplateDeduction.cpp +++ b/clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2474,6 +2474,7 @@ static TemplateDeductionResult DeduceTemplateArgumentsByTypeMatch( case Type::Pipe: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // No template argument deduction for these types return TemplateDeductionResult::Success; @@ -6993,6 +6994,7 @@ MarkUsedTemplateParameters(ASTContext &Ctx, QualType T, case Type::UnresolvedUsing: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: #define TYPE(Class, Base) #define ABSTRACT_TYPE(Class, Base) #define DEPENDENT_TYPE(Class, Base) diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp index 49d10f5..338b81f 100644 --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -5888,6 +5888,7 @@ namespace { Visit(TL.getWrappedLoc()); fillHLSLAttributedResourceTypeLoc(TL, State); } + void VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {} void VisitMacroQualifiedTypeLoc(MacroQualifiedTypeLoc TL) { Visit(TL.getInnerLoc()); TL.setExpansionLoc( diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 335e21d..7629e84 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -7683,6 +7683,13 @@ QualType TreeTransform<Derived>::TransformHLSLAttributedResourceType( return Result; } +template <typename Derived> +QualType TreeTransform<Derived>::TransformHLSLInlineSpirvType( + TypeLocBuilder &TLB, HLSLInlineSpirvTypeLoc TL) { + // No transformations needed. + return TL.getType(); +} + template<typename Derived> QualType TreeTransform<Derived>::TransformParenType(TypeLocBuilder &TLB, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index c113fd7..1ecfb92 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -7349,6 +7349,10 @@ void TypeLocReader::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. } +void TypeLocReader::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. +} + void TypeLocReader::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { TL.setNameLoc(readSourceLocation()); } @@ -9821,6 +9825,15 @@ TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() { return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool()); } +SpirvOperand ASTRecordReader::readHLSLSpirvOperand() { + auto Kind = readInt(); + auto ResultType = readQualType(); + auto Value = readAPInt(); + SpirvOperand Op(SpirvOperand::SpirvOperandKind(Kind), ResultType, Value); + assert(Op.isValid()); + return Op; +} + void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) { Info.QualifierLoc = readNestedNameSpecifierLoc(); unsigned NumTPLists = readInt(); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index 1b3d3c2..cc9916a 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -604,6 +604,10 @@ void TypeLocWriter::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. } +void TypeLocWriter::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. +} + void TypeLocWriter::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { addSourceLocation(TL.getNameLoc()); } |