aboutsummaryrefslogtreecommitdiff
path: root/clang
diff options
context:
space:
mode:
Diffstat (limited to 'clang')
-rw-r--r--clang/include/clang/AST/Decl.h6
-rw-r--r--clang/include/clang/Basic/DiagnosticSemaKinds.td3
-rw-r--r--clang/lib/AST/Decl.cpp14
-rw-r--r--clang/lib/Sema/SemaARM.cpp14
-rw-r--r--clang/lib/Sema/SemaStmt.cpp21
-rw-r--r--clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c81
6 files changed, 125 insertions, 14 deletions
diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index 1640377..9593bab 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() {
bool IsArmStreamingFunction(const FunctionDecl *FD,
bool IncludeLocallyStreaming);
+/// Returns whether the given FunctionDecl has Arm ZA state.
+bool hasArmZAState(const FunctionDecl *FD);
+
+/// Returns whether the given FunctionDecl has Arm ZT0 state.
+bool hasArmZT0State(const FunctionDecl *FD);
+
} // namespace clang
#endif // LLVM_CLANG_AST_DECL_H
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 9487e3b..a09fe03 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3870,6 +3870,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error<
"function using ZA state requires 'sme'">;
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
"function using ZT0 state requires 'sme2'">;
+def err_sme_openmp_captured_region : Error<
+ "OpenMP captured regions are not yet supported in "
+ "%select{streaming functions|functions with ZA state|functions with ZT0 state}0">;
def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning<
"%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different"
" streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">,
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index beb5fca..0bd4d64 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -5845,3 +5845,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,
return false;
}
+
+bool clang::hasArmZAState(const FunctionDecl *FD) {
+ const auto *T = FD->getType()->getAs<FunctionProtoType>();
+ return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
+ FunctionType::ARM_None) ||
+ (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
+}
+
+bool clang::hasArmZT0State(const FunctionDecl *FD) {
+ const auto *T = FD->getType()->getAs<FunctionProtoType>();
+ return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
+ FunctionType::ARM_None) ||
+ (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
+}
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index 2620bbc..9fbe835 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
return true;
}
-static bool hasArmZAState(const FunctionDecl *FD) {
- const auto *T = FD->getType()->getAs<FunctionProtoType>();
- return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
- FunctionType::ARM_None) ||
- (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
-}
-
-static bool hasArmZT0State(const FunctionDecl *FD) {
- const auto *T = FD->getType()->getAs<FunctionProtoType>();
- return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
- FunctionType::ARM_None) ||
- (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
-}
-
static ArmSMEState getSMEState(unsigned BuiltinID) {
switch (BuiltinID) {
default:
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 25a07d0..947651d 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -4568,9 +4568,27 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
return false;
}
+static std::optional<int>
+isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
+ if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
+ return {};
+ if (const FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
+ if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
+ return /* in streaming functions */ 0;
+ if (hasArmZAState(FD))
+ return /* in functions with ZA state */ 1;
+ if (hasArmZT0State(FD))
+ return /* in fuctions with ZT0 state */ 2;
+ }
+ return {};
+}
+
void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
unsigned NumParams) {
+ if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+ Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);
@@ -4602,6 +4620,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
ArrayRef<CapturedParamNameType> Params,
unsigned OpenMPCaptureLevel) {
+ if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+ Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());
diff --git a/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
new file mode 100644
index 0000000..6fb7c60
--- /dev/null
+++ b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
@@ -0,0 +1,81 @@
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s
+
+int compute(int);
+
+void streaming_openmp_captured_region(int * out) __arm_streaming {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+__arm_locally_streaming void locally_streaming_openmp_captured_region(int * out) {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+void za_state_captured_region(int * out) __arm_inout("za") {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+__arm_new("za") void new_za_state_captured_region(int * out) {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+void zt0_state_openmp_captured_region(int * out) __arm_inout("zt0") {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+__arm_new("zt0") void new_zt0_state_openmp_captured_region(int * out) {
+ // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+ // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+ #pragma omp parallel for num_threads(32)
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+/// OpenMP directives that don't create a captured region are okay:
+
+void streaming_function_openmp(int * out) __arm_streaming __arm_inout("za", "zt0") {
+ #pragma omp unroll full
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+__arm_locally_streaming void locally_streaming_openmp(int * out) __arm_inout("za", "zt0") {
+ #pragma omp unroll full
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}
+
+__arm_new("za", "zt0") void arm_new_openmp(int * out) {
+ #pragma omp unroll full
+ for (int ci = 0; ci < 8; ci++) {
+ out[ci] = compute(ci);
+ }
+}