diff options
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
| -rw-r--r-- | clang/lib/Sema/SemaHLSL.cpp | 39 |
1 files changed, 37 insertions, 2 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 2a485da..96d5142 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -21,6 +21,7 @@ #include "clang/AST/Expr.h" #include "clang/AST/HLSLResource.h" #include "clang/AST/Type.h" +#include "clang/AST/TypeBase.h" #include "clang/AST/TypeLoc.h" #include "clang/Basic/Builtins.h" #include "clang/Basic/DiagnosticSema.h" @@ -3432,6 +3433,11 @@ static void BuildFlattenedTypeList(QualType BaseTy, List.insert(List.end(), VT->getNumElements(), VT->getElementType()); continue; } + if (const auto *MT = dyn_cast<ConstantMatrixType>(T)) { + List.insert(List.end(), MT->getNumElementsFlattened(), + MT->getElementType()); + continue; + } if (const auto *RD = T->getAsCXXRecordDecl()) { if (RD->isStandardLayout()) RD = RD->getStandardLayoutBaseWithFields(); @@ -4230,6 +4236,32 @@ class InitListTransformer { } return true; } + if (auto *MTy = Ty->getAs<ConstantMatrixType>()) { + unsigned Rows = MTy->getNumRows(); + unsigned Cols = MTy->getNumColumns(); + QualType ElemTy = MTy->getElementType(); + + for (unsigned C = 0; C < Cols; ++C) { + for (unsigned R = 0; R < Rows; ++R) { + // row index literal + Expr *RowIdx = IntegerLiteral::Create( + Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), R), Ctx.IntTy, + E->getBeginLoc()); + // column index literal + Expr *ColIdx = IntegerLiteral::Create( + Ctx, llvm::APInt(Ctx.getIntWidth(Ctx.IntTy), C), Ctx.IntTy, + E->getBeginLoc()); + ExprResult ElExpr = S.CreateBuiltinMatrixSubscriptExpr( + E, RowIdx, ColIdx, E->getEndLoc()); + if (ElExpr.isInvalid()) + return false; + if (!castInitializer(ElExpr.get())) + return false; + ElExpr.get()->setType(ElemTy); + } + } + return true; + } if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) { uint64_t Size = ArrTy->getZExtSize(); @@ -4283,14 +4315,17 @@ class InitListTransformer { return *(ArgIt++); llvm::SmallVector<Expr *> Inits; - assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL"); Ty = Ty.getDesugaredType(Ctx); - if (Ty->isVectorType() || Ty->isConstantArrayType()) { + if (Ty->isVectorType() || Ty->isConstantArrayType() || + Ty->isConstantMatrixType()) { QualType ElTy; uint64_t Size = 0; if (auto *ATy = Ty->getAs<VectorType>()) { ElTy = ATy->getElementType(); Size = ATy->getNumElements(); + } else if (auto *CMTy = Ty->getAs<ConstantMatrixType>()) { + ElTy = CMTy->getElementType(); + Size = CMTy->getNumElementsFlattened(); } else { auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr()); ElTy = VTy->getElementType(); |
