aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/CodeGen/CGExprAgg.cpp
diff options
context:
space:
mode:
authorSarah Spall <sarahspall@microsoft.com>2025-02-14 09:25:24 -0800
committerGitHub <noreply@github.com>2025-02-14 09:25:24 -0800
commit4d2d0afceeb732a5238c2167ab7a6b88cc66d976 (patch)
treebc09547d67b988f6252626a6ab8f4644ba59a0a0 /clang/lib/CodeGen/CGExprAgg.cpp
parentb41b86a907f653f79bab10d4c80b3a41d146c71b (diff)
downloadllvm-4d2d0afceeb732a5238c2167ab7a6b88cc66d976.zip
llvm-4d2d0afceeb732a5238c2167ab7a6b88cc66d976.tar.gz
llvm-4d2d0afceeb732a5238c2167ab7a6b88cc66d976.tar.bz2
[HLSL] Implement HLSL Aggregate splatting (#118992)
Implement HLSL Aggregate Splat casting that handles splatting for arrays and structs, and vectors if splatting from a vec1. Closes #100609 and Closes #100619 Depends on #118842
Diffstat (limited to 'clang/lib/CodeGen/CGExprAgg.cpp')
-rw-r--r--clang/lib/CodeGen/CGExprAgg.cpp39
1 files changed, 39 insertions, 0 deletions
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index c574827..d25d0f2 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
return false;
}
+static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
+ QualType DestTy, llvm::Value *SrcVal,
+ QualType SrcTy, SourceLocation Loc) {
+ // Flatten our destination
+ SmallVector<QualType> DestTypes; // Flattened type
+ SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+ // ^^ Flattened accesses to DestVal we want to store into
+ CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
+
+ assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
+ for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
+ llvm::Value *Cast =
+ CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
+
+ // store back
+ llvm::Value *Idx = StoreGEPList[I].second;
+ if (Idx) {
+ llvm::Value *V =
+ CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
+ Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
+ }
+ CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+ }
+}
+
// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
@@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;
+ case CK_HLSLAggregateSplatCast: {
+ Expr *Src = E->getSubExpr();
+ QualType SrcTy = Src->getType();
+ RValue RV = CGF.EmitAnyExpr(Src);
+ QualType DestTy = E->getType();
+ Address DestVal = Dest.getAddress();
+ SourceLocation Loc = E->getExprLoc();
+
+ assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+ llvm::Value *SrcVal = RV.getScalarVal();
+ EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+ break;
+ }
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
@@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
+ case CK_HLSLAggregateSplatCast:
return true;
case CK_BaseToDerivedMemberPointer: