diff options
Diffstat (limited to 'clang')
-rw-r--r-- | clang/include/clang/AST/Decl.h | 6 | ||||
-rw-r--r-- | clang/include/clang/Basic/DiagnosticSemaKinds.td | 3 | ||||
-rw-r--r-- | clang/lib/AST/Decl.cpp | 14 | ||||
-rw-r--r-- | clang/lib/Sema/SemaARM.cpp | 14 | ||||
-rw-r--r-- | clang/lib/Sema/SemaStmt.cpp | 21 | ||||
-rw-r--r-- | clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c | 81 |
6 files changed, 125 insertions, 14 deletions
diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h index 1640377..9593bab 100644 --- a/clang/include/clang/AST/Decl.h +++ b/clang/include/clang/AST/Decl.h @@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() { bool IsArmStreamingFunction(const FunctionDecl *FD, bool IncludeLocallyStreaming); +/// Returns whether the given FunctionDecl has Arm ZA state. +bool hasArmZAState(const FunctionDecl *FD); + +/// Returns whether the given FunctionDecl has Arm ZT0 state. +bool hasArmZT0State(const FunctionDecl *FD); + } // namespace clang #endif // LLVM_CLANG_AST_DECL_H diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 9487e3b..a09fe03 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -3870,6 +3870,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error< "function using ZA state requires 'sme'">; def err_sme_definition_using_zt0_in_non_sme2_target : Error< "function using ZT0 state requires 'sme2'">; +def err_sme_openmp_captured_region : Error< + "OpenMP captured regions are not yet supported in " + "%select{streaming functions|functions with ZA state|functions with ZT0 state}0">; def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning< "%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different" " streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">, diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp index beb5fca..0bd4d64 100644 --- a/clang/lib/AST/Decl.cpp +++ b/clang/lib/AST/Decl.cpp @@ -5845,3 +5845,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD, return false; } + +bool clang::hasArmZAState(const FunctionDecl *FD) { + const auto *T = FD->getType()->getAs<FunctionProtoType>(); + return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) != + FunctionType::ARM_None) || + (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA()); +} + +bool clang::hasArmZT0State(const FunctionDecl *FD) { + const auto *T = FD->getType()->getAs<FunctionProtoType>(); + return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) != + FunctionType::ARM_None) || + (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0()); +} diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp index 2620bbc..9fbe835 100644 --- a/clang/lib/Sema/SemaARM.cpp +++ b/clang/lib/Sema/SemaARM.cpp @@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall, return true; } -static bool hasArmZAState(const FunctionDecl *FD) { - const auto *T = FD->getType()->getAs<FunctionProtoType>(); - return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) != - FunctionType::ARM_None) || - (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA()); -} - -static bool hasArmZT0State(const FunctionDecl *FD) { - const auto *T = FD->getType()->getAs<FunctionProtoType>(); - return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) != - FunctionType::ARM_None) || - (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0()); -} - static ArmSMEState getSMEState(unsigned BuiltinID) { switch (BuiltinID) { default: diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index 25a07d0..947651d 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -4568,9 +4568,27 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI, return false; } +static std::optional<int> +isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) { + if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP) + return {}; + if (const FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true)) { + if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true)) + return /* in streaming functions */ 0; + if (hasArmZAState(FD)) + return /* in functions with ZA state */ 1; + if (hasArmZT0State(FD)) + return /* in fuctions with ZT0 state */ 2; + } + return {}; +} + void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope, CapturedRegionKind Kind, unsigned NumParams) { + if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind)) + Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex; + CapturedDecl *CD = nullptr; RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams); @@ -4602,6 +4620,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope, CapturedRegionKind Kind, ArrayRef<CapturedParamNameType> Params, unsigned OpenMPCaptureLevel) { + if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind)) + Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex; + CapturedDecl *CD = nullptr; RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size()); diff --git a/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c new file mode 100644 index 0000000..6fb7c60 --- /dev/null +++ b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c @@ -0,0 +1,81 @@ +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s + +int compute(int); + +void streaming_openmp_captured_region(int * out) __arm_streaming { + // expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +__arm_locally_streaming void locally_streaming_openmp_captured_region(int * out) { + // expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +void za_state_captured_region(int * out) __arm_inout("za") { + // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +__arm_new("za") void new_za_state_captured_region(int * out) { + // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +void zt0_state_openmp_captured_region(int * out) __arm_inout("zt0") { + // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +__arm_new("zt0") void new_zt0_state_openmp_captured_region(int * out) { + // expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}} + // expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}} + #pragma omp parallel for num_threads(32) + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +/// OpenMP directives that don't create a captured region are okay: + +void streaming_function_openmp(int * out) __arm_streaming __arm_inout("za", "zt0") { + #pragma omp unroll full + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +__arm_locally_streaming void locally_streaming_openmp(int * out) __arm_inout("za", "zt0") { + #pragma omp unroll full + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} + +__arm_new("za", "zt0") void arm_new_openmp(int * out) { + #pragma omp unroll full + for (int ci = 0; ci < 8; ci++) { + out[ci] = compute(ci); + } +} |