diff options
author | Florian Hahn <flo@fhahn.com> | 2020-06-18 10:49:56 +0100 |
---|---|---|
committer | Florian Hahn <flo@fhahn.com> | 2020-06-18 11:39:02 +0100 |
commit | b5e082e7289197bf82c9a28c6336b51d7999b419 (patch) | |
tree | 471c306267f095f0154f2771aa493a7f72ca6e7b /clang/lib | |
parent | 4ea8e27a642c6f97ca69cd39bbe44f7366870f6c (diff) | |
download | llvm-b5e082e7289197bf82c9a28c6336b51d7999b419.zip llvm-b5e082e7289197bf82c9a28c6336b51d7999b419.tar.gz llvm-b5e082e7289197bf82c9a28c6336b51d7999b419.tar.bz2 |
[Matrix] Add __builtin_matrix_column_store to Clang.
This patch add __builtin_matrix_column_major_store to Clang,
as described in clang/docs/MatrixTypes.rst. In the initial version,
the stride is not optional yet.
Reviewers: rjmccall, jfb, rsmith, Bigcheese
Reviewed By: rjmccall
Differential Revision: https://reviews.llvm.org/D72782
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/CodeGen/CGBuiltin.cpp | 19 | ||||
-rw-r--r-- | clang/lib/Sema/SemaChecking.cpp | 105 |
2 files changed, 121 insertions, 3 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 7a138c5..2339446 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -2406,6 +2406,25 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, return RValue::get(Result); } + case Builtin::BI__builtin_matrix_column_major_store: { + MatrixBuilder<CGBuilderTy> MB(Builder); + Value *Matrix = EmitScalarExpr(E->getArg(0)); + Address Dst = EmitPointerWithAlignment(E->getArg(1)); + Value *Stride = EmitScalarExpr(E->getArg(2)); + + const auto *MatrixTy = E->getArg(0)->getType()->getAs<ConstantMatrixType>(); + auto *PtrTy = E->getArg(1)->getType()->getAs<PointerType>(); + assert(PtrTy && "arg1 must be of pointer type"); + bool IsVolatile = PtrTy->getPointeeType().isVolatileQualified(); + + EmitNonNullArgCheck(RValue::get(Dst.getPointer()), E->getArg(1)->getType(), + E->getArg(1)->getExprLoc(), FD, 0); + Value *Result = MB.CreateColumnMajorStore( + Matrix, Dst.getPointer(), Align(Dst.getAlignment().getQuantity()), + Stride, IsVolatile, MatrixTy->getNumRows(), MatrixTy->getNumColumns()); + return RValue::get(Result); + } + case Builtin::BIfinite: case Builtin::BI__finite: case Builtin::BIfinitef: diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 821074d..accce85 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -1930,6 +1930,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, case Builtin::BI__builtin_matrix_column_major_load: return SemaBuiltinMatrixColumnMajorLoad(TheCall, TheCallResult); + + case Builtin::BI__builtin_matrix_column_major_store: + return SemaBuiltinMatrixColumnMajorStore(TheCall, TheCallResult); } // Since the target specific builtins for each arch overlap, only check those @@ -15092,7 +15095,7 @@ ExprResult Sema::SemaBuiltinMatrixTranspose(CallExpr *TheCall, auto *MType = Matrix->getType()->getAs<ConstantMatrixType>(); if (!MType) { - Diag(Matrix->getBeginLoc(), diag::err_builtin_matrix_arg) << 0; + Diag(Matrix->getBeginLoc(), diag::err_builtin_matrix_arg); return ExprError(); } @@ -15161,13 +15164,15 @@ ExprResult Sema::SemaBuiltinMatrixColumnMajorLoad(CallExpr *TheCall, auto *PtrTy = PtrExpr->getType()->getAs<PointerType>(); QualType ElementTy; if (!PtrTy) { - Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0; + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) + << "first"; ArgError = true; } else { ElementTy = PtrTy->getPointeeType().getUnqualifiedType(); if (!ConstantMatrixType::isValidElementType(ElementTy)) { - Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0; + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) + << "first"; ArgError = true; } } @@ -15237,3 +15242,97 @@ ExprResult Sema::SemaBuiltinMatrixColumnMajorLoad(CallExpr *TheCall, Context.getConstantMatrixType(ElementTy, *MaybeRows, *MaybeColumns)); return CallResult; } + +ExprResult Sema::SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall, + ExprResult CallResult) { + if (checkArgCount(*this, TheCall, 3)) + return ExprError(); + + Expr *MatrixExpr = TheCall->getArg(0); + Expr *PtrExpr = TheCall->getArg(1); + Expr *StrideExpr = TheCall->getArg(2); + + bool ArgError = false; + + { + ExprResult MatrixConv = DefaultLvalueConversion(MatrixExpr); + if (MatrixConv.isInvalid()) + return MatrixConv; + MatrixExpr = MatrixConv.get(); + TheCall->setArg(0, MatrixExpr); + } + if (MatrixExpr->isTypeDependent()) { + TheCall->setType(Context.DependentTy); + return TheCall; + } + + auto *MatrixTy = MatrixExpr->getType()->getAs<ConstantMatrixType>(); + if (!MatrixTy) { + Diag(MatrixExpr->getBeginLoc(), diag::err_builtin_matrix_arg) << 0; + ArgError = true; + } + + { + ExprResult PtrConv = DefaultFunctionArrayLvalueConversion(PtrExpr); + if (PtrConv.isInvalid()) + return PtrConv; + PtrExpr = PtrConv.get(); + TheCall->setArg(1, PtrExpr); + if (PtrExpr->isTypeDependent()) { + TheCall->setType(Context.DependentTy); + return TheCall; + } + } + + // Check pointer argument. + auto *PtrTy = PtrExpr->getType()->getAs<PointerType>(); + if (!PtrTy) { + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) + << "second"; + ArgError = true; + } else { + QualType ElementTy = PtrTy->getPointeeType(); + if (ElementTy.isConstQualified()) { + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_store_to_const); + ArgError = true; + } + ElementTy = ElementTy.getUnqualifiedType().getCanonicalType(); + if (MatrixTy && + !Context.hasSameType(ElementTy, MatrixTy->getElementType())) { + Diag(PtrExpr->getBeginLoc(), + diag::err_builtin_matrix_pointer_arg_mismatch) + << ElementTy << MatrixTy->getElementType(); + ArgError = true; + } + } + + // Apply default Lvalue conversions and convert the stride expression to + // size_t. + { + ExprResult StrideConv = DefaultLvalueConversion(StrideExpr); + if (StrideConv.isInvalid()) + return StrideConv; + + StrideConv = tryConvertExprToType(StrideConv.get(), Context.getSizeType()); + if (StrideConv.isInvalid()) + return StrideConv; + StrideExpr = StrideConv.get(); + TheCall->setArg(2, StrideExpr); + } + + // Check stride argument. + llvm::APSInt Value(64); + if (MatrixTy && StrideExpr->isIntegerConstantExpr(Value, Context)) { + uint64_t Stride = Value.getZExtValue(); + if (Stride < MatrixTy->getNumRows()) { + Diag(StrideExpr->getBeginLoc(), + diag::err_builtin_matrix_stride_too_small); + ArgError = true; + } + } + + if (ArgError) + return ExprError(); + + return CallResult; +} |