aboutsummaryrefslogtreecommitdiff
path: root/clang/lib
diff options
context:
space:
mode:
authorFlorian Hahn <flo@fhahn.com>2020-06-18 10:49:56 +0100
committerFlorian Hahn <flo@fhahn.com>2020-06-18 11:39:02 +0100
commitb5e082e7289197bf82c9a28c6336b51d7999b419 (patch)
tree471c306267f095f0154f2771aa493a7f72ca6e7b /clang/lib
parent4ea8e27a642c6f97ca69cd39bbe44f7366870f6c (diff)
downloadllvm-b5e082e7289197bf82c9a28c6336b51d7999b419.zip
llvm-b5e082e7289197bf82c9a28c6336b51d7999b419.tar.gz
llvm-b5e082e7289197bf82c9a28c6336b51d7999b419.tar.bz2
[Matrix] Add __builtin_matrix_column_store to Clang.
This patch add __builtin_matrix_column_major_store to Clang, as described in clang/docs/MatrixTypes.rst. In the initial version, the stride is not optional yet. Reviewers: rjmccall, jfb, rsmith, Bigcheese Reviewed By: rjmccall Differential Revision: https://reviews.llvm.org/D72782
Diffstat (limited to 'clang/lib')
-rw-r--r--clang/lib/CodeGen/CGBuiltin.cpp19
-rw-r--r--clang/lib/Sema/SemaChecking.cpp105
2 files changed, 121 insertions, 3 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7a138c5..2339446 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2406,6 +2406,25 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
return RValue::get(Result);
}
+ case Builtin::BI__builtin_matrix_column_major_store: {
+ MatrixBuilder<CGBuilderTy> MB(Builder);
+ Value *Matrix = EmitScalarExpr(E->getArg(0));
+ Address Dst = EmitPointerWithAlignment(E->getArg(1));
+ Value *Stride = EmitScalarExpr(E->getArg(2));
+
+ const auto *MatrixTy = E->getArg(0)->getType()->getAs<ConstantMatrixType>();
+ auto *PtrTy = E->getArg(1)->getType()->getAs<PointerType>();
+ assert(PtrTy && "arg1 must be of pointer type");
+ bool IsVolatile = PtrTy->getPointeeType().isVolatileQualified();
+
+ EmitNonNullArgCheck(RValue::get(Dst.getPointer()), E->getArg(1)->getType(),
+ E->getArg(1)->getExprLoc(), FD, 0);
+ Value *Result = MB.CreateColumnMajorStore(
+ Matrix, Dst.getPointer(), Align(Dst.getAlignment().getQuantity()),
+ Stride, IsVolatile, MatrixTy->getNumRows(), MatrixTy->getNumColumns());
+ return RValue::get(Result);
+ }
+
case Builtin::BIfinite:
case Builtin::BI__finite:
case Builtin::BIfinitef:
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 821074d..accce85 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1930,6 +1930,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
case Builtin::BI__builtin_matrix_column_major_load:
return SemaBuiltinMatrixColumnMajorLoad(TheCall, TheCallResult);
+
+ case Builtin::BI__builtin_matrix_column_major_store:
+ return SemaBuiltinMatrixColumnMajorStore(TheCall, TheCallResult);
}
// Since the target specific builtins for each arch overlap, only check those
@@ -15092,7 +15095,7 @@ ExprResult Sema::SemaBuiltinMatrixTranspose(CallExpr *TheCall,
auto *MType = Matrix->getType()->getAs<ConstantMatrixType>();
if (!MType) {
- Diag(Matrix->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+ Diag(Matrix->getBeginLoc(), diag::err_builtin_matrix_arg);
return ExprError();
}
@@ -15161,13 +15164,15 @@ ExprResult Sema::SemaBuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
auto *PtrTy = PtrExpr->getType()->getAs<PointerType>();
QualType ElementTy;
if (!PtrTy) {
- Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+ Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+ << "first";
ArgError = true;
} else {
ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
if (!ConstantMatrixType::isValidElementType(ElementTy)) {
- Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+ Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+ << "first";
ArgError = true;
}
}
@@ -15237,3 +15242,97 @@ ExprResult Sema::SemaBuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
Context.getConstantMatrixType(ElementTy, *MaybeRows, *MaybeColumns));
return CallResult;
}
+
+ExprResult Sema::SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall,
+ ExprResult CallResult) {
+ if (checkArgCount(*this, TheCall, 3))
+ return ExprError();
+
+ Expr *MatrixExpr = TheCall->getArg(0);
+ Expr *PtrExpr = TheCall->getArg(1);
+ Expr *StrideExpr = TheCall->getArg(2);
+
+ bool ArgError = false;
+
+ {
+ ExprResult MatrixConv = DefaultLvalueConversion(MatrixExpr);
+ if (MatrixConv.isInvalid())
+ return MatrixConv;
+ MatrixExpr = MatrixConv.get();
+ TheCall->setArg(0, MatrixExpr);
+ }
+ if (MatrixExpr->isTypeDependent()) {
+ TheCall->setType(Context.DependentTy);
+ return TheCall;
+ }
+
+ auto *MatrixTy = MatrixExpr->getType()->getAs<ConstantMatrixType>();
+ if (!MatrixTy) {
+ Diag(MatrixExpr->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+ ArgError = true;
+ }
+
+ {
+ ExprResult PtrConv = DefaultFunctionArrayLvalueConversion(PtrExpr);
+ if (PtrConv.isInvalid())
+ return PtrConv;
+ PtrExpr = PtrConv.get();
+ TheCall->setArg(1, PtrExpr);
+ if (PtrExpr->isTypeDependent()) {
+ TheCall->setType(Context.DependentTy);
+ return TheCall;
+ }
+ }
+
+ // Check pointer argument.
+ auto *PtrTy = PtrExpr->getType()->getAs<PointerType>();
+ if (!PtrTy) {
+ Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+ << "second";
+ ArgError = true;
+ } else {
+ QualType ElementTy = PtrTy->getPointeeType();
+ if (ElementTy.isConstQualified()) {
+ Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_store_to_const);
+ ArgError = true;
+ }
+ ElementTy = ElementTy.getUnqualifiedType().getCanonicalType();
+ if (MatrixTy &&
+ !Context.hasSameType(ElementTy, MatrixTy->getElementType())) {
+ Diag(PtrExpr->getBeginLoc(),
+ diag::err_builtin_matrix_pointer_arg_mismatch)
+ << ElementTy << MatrixTy->getElementType();
+ ArgError = true;
+ }
+ }
+
+ // Apply default Lvalue conversions and convert the stride expression to
+ // size_t.
+ {
+ ExprResult StrideConv = DefaultLvalueConversion(StrideExpr);
+ if (StrideConv.isInvalid())
+ return StrideConv;
+
+ StrideConv = tryConvertExprToType(StrideConv.get(), Context.getSizeType());
+ if (StrideConv.isInvalid())
+ return StrideConv;
+ StrideExpr = StrideConv.get();
+ TheCall->setArg(2, StrideExpr);
+ }
+
+ // Check stride argument.
+ llvm::APSInt Value(64);
+ if (MatrixTy && StrideExpr->isIntegerConstantExpr(Value, Context)) {
+ uint64_t Stride = Value.getZExtValue();
+ if (Stride < MatrixTy->getNumRows()) {
+ Diag(StrideExpr->getBeginLoc(),
+ diag::err_builtin_matrix_stride_too_small);
+ ArgError = true;
+ }
+ }
+
+ if (ArgError)
+ return ExprError();
+
+ return CallResult;
+}