aboutsummaryrefslogtreecommitdiff
path: root/clang/lib
diff options
context:
space:
mode:
authorBenjamin Maxwell <benjamin.maxwell@arm.com>2025-01-28 12:08:54 +0000
committerGitHub <noreply@github.com>2025-01-28 12:08:54 +0000
commita7f4044bd01919df2bf2204d203ee0378e2e9fb2 (patch)
treefc29f52c48aed658bd67581caace2d2997547e5c /clang/lib
parent431024506c6f5597fe476e1283a08c9f8fa72ad7 (diff)
downloadllvm-a7f4044bd01919df2bf2204d203ee0378e2e9fb2.zip
llvm-a7f4044bd01919df2bf2204d203ee0378e2e9fb2.tar.gz
llvm-a7f4044bd01919df2bf2204d203ee0378e2e9fb2.tar.bz2
[clang][SME] Emit error for OpenMP captured regions in SME functions (#124590)
Currently, these generate incorrect code, as streaming/SME attributes are not propagated to the outlined function. As we've yet to work on mixing OpenMP and streaming functions (and determine how they should interact with OpenMP's runtime), we think it is best to disallow this for now.
Diffstat (limited to 'clang/lib')
-rw-r--r--clang/lib/AST/Decl.cpp14
-rw-r--r--clang/lib/Sema/SemaARM.cpp14
-rw-r--r--clang/lib/Sema/SemaStmt.cpp21
3 files changed, 35 insertions, 14 deletions
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index beb5fca..0bd4d64 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -5845,3 +5845,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,
return false;
}
+
+bool clang::hasArmZAState(const FunctionDecl *FD) {
+ const auto *T = FD->getType()->getAs<FunctionProtoType>();
+ return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
+ FunctionType::ARM_None) ||
+ (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
+}
+
+bool clang::hasArmZT0State(const FunctionDecl *FD) {
+ const auto *T = FD->getType()->getAs<FunctionProtoType>();
+ return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
+ FunctionType::ARM_None) ||
+ (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
+}
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index 2620bbc..9fbe835 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
return true;
}
-static bool hasArmZAState(const FunctionDecl *FD) {
- const auto *T = FD->getType()->getAs<FunctionProtoType>();
- return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
- FunctionType::ARM_None) ||
- (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
-}
-
-static bool hasArmZT0State(const FunctionDecl *FD) {
- const auto *T = FD->getType()->getAs<FunctionProtoType>();
- return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
- FunctionType::ARM_None) ||
- (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
-}
-
static ArmSMEState getSMEState(unsigned BuiltinID) {
switch (BuiltinID) {
default:
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 25a07d0..947651d 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -4568,9 +4568,27 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
return false;
}
+static std::optional<int>
+isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
+ if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
+ return {};
+ if (const FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
+ if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
+ return /* in streaming functions */ 0;
+ if (hasArmZAState(FD))
+ return /* in functions with ZA state */ 1;
+ if (hasArmZT0State(FD))
+ return /* in fuctions with ZT0 state */ 2;
+ }
+ return {};
+}
+
void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
unsigned NumParams) {
+ if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+ Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);
@@ -4602,6 +4620,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
ArrayRef<CapturedParamNameType> Params,
unsigned OpenMPCaptureLevel) {
+ if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+ Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());