diff options
Diffstat (limited to 'clang/lib/CodeGen/CGHLSLRuntime.cpp')
| -rw-r--r-- | clang/lib/CodeGen/CGHLSLRuntime.cpp | 100 |
1 files changed, 71 insertions, 29 deletions
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index ecab933..945f9e2 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -562,17 +562,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M, return B.CreateLoad(Ty, GV); } -llvm::Value * -CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, - const clang::DeclaratorDecl *Decl, - SemanticInfo &ActiveSemantic) { - if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) { +llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad( + IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl, + Attr *Semantic, std::optional<unsigned> Index) { + if (isa<HLSLSV_GroupIndexAttr>(Semantic)) { llvm::Function *GroupIndex = CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic()); return B.CreateCall(FunctionCallee(GroupIndex)); } - if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) { + if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) { llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic(); llvm::Function *ThreadIDIntrinsic = llvm::Intrinsic::isOverloaded(IntrinID) @@ -581,7 +580,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, return buildVectorInput(B, ThreadIDIntrinsic, Type); } - if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) { + if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) { llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic(); llvm::Function *GroupThreadIDIntrinsic = llvm::Intrinsic::isOverloaded(IntrinID) @@ -590,7 +589,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, return buildVectorInput(B, GroupThreadIDIntrinsic, Type); } - if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) { + if (isa<HLSLSV_GroupIDAttr>(Semantic)) { llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic(); llvm::Function *GroupIDIntrinsic = llvm::Intrinsic::isOverloaded(IntrinID) @@ -599,8 +598,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, return buildVectorInput(B, GroupIDIntrinsic, Type); } - if (HLSLSV_PositionAttr *S = - dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) { + if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) { if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel) return createSPIRVBuiltinLoad(B, CGM.getModule(), Type, S->getAttrName()->getName(), @@ -611,29 +609,56 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, } llvm::Value * -CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type, - const clang::DeclaratorDecl *Decl, - SemanticInfo &ActiveSemantic) { - - if (!ActiveSemantic.Semantic) { - ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>(); - if (!ActiveSemantic.Semantic) { - CGM.getDiags().Report(Decl->getInnerLocStart(), - diag::err_hlsl_semantic_missing); - return nullptr; +CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, + llvm::Type *Type, + const clang::DeclaratorDecl *Decl) { + + HLSLSemanticAttr *Semantic = nullptr; + for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) { + if (Item->getTargetDecl() == Decl) { + Semantic = Item; + break; } - ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex(); } + // Sema must create one attribute per scalar field. + assert(Semantic); - return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic); + std::optional<unsigned> Index = std::nullopt; + if (Semantic->isSemanticIndexExplicit()) + Index = Semantic->getSemanticIndex(); + return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index); } llvm::Value * -CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type, - const clang::DeclaratorDecl *Decl, - SemanticInfo &ActiveSemantic) { - assert(!Type->isStructTy()); - return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic); +CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, + llvm::Type *Type, + const clang::DeclaratorDecl *Decl) { + const llvm::StructType *ST = cast<StructType>(Type); + const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl(); + + assert(std::distance(RD->field_begin(), RD->field_end()) == + ST->getNumElements()); + + llvm::Value *Aggregate = llvm::PoisonValue::get(Type); + auto FieldDecl = RD->field_begin(); + for (unsigned I = 0; I < ST->getNumElements(); ++I) { + llvm::Value *ChildValue = + handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl); + assert(ChildValue); + Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I); + ++FieldDecl; + } + + return Aggregate; +} + +llvm::Value * +CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, + llvm::Type *Type, + const clang::DeclaratorDecl *Decl) { + if (Type->isStructTy()) + return handleStructSemanticLoad(B, FD, Type, Decl); + return handleScalarSemanticLoad(B, FD, Type, Decl); } void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, @@ -680,8 +705,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, } const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); - SemanticInfo ActiveSemantic = {nullptr, 0}; - Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic)); + llvm::Value *SemanticValue = nullptr; + if ([[maybe_unused]] HLSLParamModifierAttr *MA = + PD->getAttr<HLSLParamModifierAttr>()) { + llvm_unreachable("Not handled yet"); + } else { + llvm::Type *ParamType = + Param.hasByValAttr() ? Param.getParamByValType() : Param.getType(); + SemanticValue = handleSemanticLoad(B, FD, ParamType, PD); + if (!SemanticValue) + return; + if (Param.hasByValAttr()) { + llvm::Value *Var = B.CreateAlloca(Param.getParamByValType()); + B.CreateStore(SemanticValue, Var); + SemanticValue = Var; + } + } + + assert(SemanticValue); + Args.push_back(SemanticValue); } CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB); |
