diff options
Diffstat (limited to 'clang/lib/CodeGen/CGHLSLRuntime.cpp')
| -rw-r--r-- | clang/lib/CodeGen/CGHLSLRuntime.cpp | 100 | 
1 files changed, 71 insertions, 29 deletions
| diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index ecab933..945f9e2 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -562,17 +562,16 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,    return B.CreateLoad(Ty, GV);  } -llvm::Value * -CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type, -                                      const clang::DeclaratorDecl *Decl, -                                      SemanticInfo &ActiveSemantic) { -  if (isa<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) { +llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad( +    IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl, +    Attr *Semantic, std::optional<unsigned> Index) { +  if (isa<HLSLSV_GroupIndexAttr>(Semantic)) {      llvm::Function *GroupIndex =          CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());      return B.CreateCall(FunctionCallee(GroupIndex));    } -  if (isa<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) { +  if (isa<HLSLSV_DispatchThreadIDAttr>(Semantic)) {      llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();      llvm::Function *ThreadIDIntrinsic =          llvm::Intrinsic::isOverloaded(IntrinID) @@ -581,7 +580,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,      return buildVectorInput(B, ThreadIDIntrinsic, Type);    } -  if (isa<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) { +  if (isa<HLSLSV_GroupThreadIDAttr>(Semantic)) {      llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();      llvm::Function *GroupThreadIDIntrinsic =          llvm::Intrinsic::isOverloaded(IntrinID) @@ -590,7 +589,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,      return buildVectorInput(B, GroupThreadIDIntrinsic, Type);    } -  if (isa<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) { +  if (isa<HLSLSV_GroupIDAttr>(Semantic)) {      llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();      llvm::Function *GroupIDIntrinsic =          llvm::Intrinsic::isOverloaded(IntrinID) @@ -599,8 +598,7 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,      return buildVectorInput(B, GroupIDIntrinsic, Type);    } -  if (HLSLSV_PositionAttr *S = -          dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) { +  if (HLSLSV_PositionAttr *S = dyn_cast<HLSLSV_PositionAttr>(Semantic)) {      if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)        return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,                                      S->getAttrName()->getName(), @@ -611,29 +609,56 @@ CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,  }  llvm::Value * -CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type, -                                        const clang::DeclaratorDecl *Decl, -                                        SemanticInfo &ActiveSemantic) { - -  if (!ActiveSemantic.Semantic) { -    ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>(); -    if (!ActiveSemantic.Semantic) { -      CGM.getDiags().Report(Decl->getInnerLocStart(), -                            diag::err_hlsl_semantic_missing); -      return nullptr; +CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, +                                        llvm::Type *Type, +                                        const clang::DeclaratorDecl *Decl) { + +  HLSLSemanticAttr *Semantic = nullptr; +  for (HLSLSemanticAttr *Item : FD->specific_attrs<HLSLSemanticAttr>()) { +    if (Item->getTargetDecl() == Decl) { +      Semantic = Item; +      break;      } -    ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();    } +  // Sema must create one attribute per scalar field. +  assert(Semantic); -  return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic); +  std::optional<unsigned> Index = std::nullopt; +  if (Semantic->isSemanticIndexExplicit()) +    Index = Semantic->getSemanticIndex(); +  return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);  }  llvm::Value * -CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type, -                                  const clang::DeclaratorDecl *Decl, -                                  SemanticInfo &ActiveSemantic) { -  assert(!Type->isStructTy()); -  return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic); +CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, +                                        llvm::Type *Type, +                                        const clang::DeclaratorDecl *Decl) { +  const llvm::StructType *ST = cast<StructType>(Type); +  const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl(); + +  assert(std::distance(RD->field_begin(), RD->field_end()) == +         ST->getNumElements()); + +  llvm::Value *Aggregate = llvm::PoisonValue::get(Type); +  auto FieldDecl = RD->field_begin(); +  for (unsigned I = 0; I < ST->getNumElements(); ++I) { +    llvm::Value *ChildValue = +        handleSemanticLoad(B, FD, ST->getElementType(I), *FieldDecl); +    assert(ChildValue); +    Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I); +    ++FieldDecl; +  } + +  return Aggregate; +} + +llvm::Value * +CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, const FunctionDecl *FD, +                                  llvm::Type *Type, +                                  const clang::DeclaratorDecl *Decl) { +  if (Type->isStructTy()) +    return handleStructSemanticLoad(B, FD, Type, Decl); +  return handleScalarSemanticLoad(B, FD, Type, Decl);  }  void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, @@ -680,8 +705,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,      }      const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); -    SemanticInfo ActiveSemantic = {nullptr, 0}; -    Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic)); +    llvm::Value *SemanticValue = nullptr; +    if ([[maybe_unused]] HLSLParamModifierAttr *MA = +            PD->getAttr<HLSLParamModifierAttr>()) { +      llvm_unreachable("Not handled yet"); +    } else { +      llvm::Type *ParamType = +          Param.hasByValAttr() ? Param.getParamByValType() : Param.getType(); +      SemanticValue = handleSemanticLoad(B, FD, ParamType, PD); +      if (!SemanticValue) +        return; +      if (Param.hasByValAttr()) { +        llvm::Value *Var = B.CreateAlloca(Param.getParamByValType()); +        B.CreateStore(SemanticValue, Var); +        SemanticValue = Var; +      } +    } + +    assert(SemanticValue); +    Args.push_back(SemanticValue);    }    CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB); | 
