diff options
author | Valentin Clement <clementval@gmail.com> | 2024-04-30 15:42:08 -0700 |
---|---|---|
committer | Valentin Clement <clementval@gmail.com> | 2024-04-30 21:14:12 -0700 |
commit | 3e930864eb39a81598fa03e539552e1664cdb989 (patch) | |
tree | bd2978d70acb47cf28916a8fd953e56ae121150e | |
parent | 240592a772a40b4ffa75921f7b555d2a969b3383 (diff) | |
download | llvm-3e930864eb39a81598fa03e539552e1664cdb989.zip llvm-3e930864eb39a81598fa03e539552e1664cdb989.tar.gz llvm-3e930864eb39a81598fa03e539552e1664cdb989.tar.bz2 |
Reland [flang][cuda] Update attribute compatibily check for unified matching rule
-rw-r--r-- | flang/include/flang/Common/Fortran.h | 4 | ||||
-rw-r--r-- | flang/lib/Common/Fortran.cpp | 24 | ||||
-rw-r--r-- | flang/lib/Evaluate/characteristics.cpp | 10 | ||||
-rw-r--r-- | flang/lib/Semantics/check-call.cpp | 5 | ||||
-rw-r--r-- | flang/test/Semantics/cuf13.cuf | 9 |
5 files changed, 43 insertions, 9 deletions
diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h index 2a53452..3b965fe 100644 --- a/flang/include/flang/Common/Fortran.h +++ b/flang/include/flang/Common/Fortran.h @@ -114,8 +114,8 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind, IgnoreTKR::Rank, IgnoreTKR::Device, IgnoreTKR::Managed}; std::string AsFortran(IgnoreTKRSet); -bool AreCompatibleCUDADataAttrs( - std::optional<CUDADataAttr>, std::optional<CUDADataAttr>, IgnoreTKRSet); +bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>, + std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule); static constexpr char blankCommonObjectName[] = "__BLNK__"; diff --git a/flang/lib/Common/Fortran.cpp b/flang/lib/Common/Fortran.cpp index 8ada8fe..170ce8c 100644 --- a/flang/lib/Common/Fortran.cpp +++ b/flang/lib/Common/Fortran.cpp @@ -97,8 +97,12 @@ std::string AsFortran(IgnoreTKRSet tkr) { return result; } +/// Check compatibilty of CUDA attribute. +/// When `allowUnifiedMatchingRule` is enabled, argument `x` represents the +/// dummy argument attribute while `y` represents the actual argument attribute. bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x, - std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR) { + std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR, + bool allowUnifiedMatchingRule) { if (!x && !y) { return true; } else if (x && y && *x == *y) { @@ -114,6 +118,24 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x, x.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed && y.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed) { return true; + } else if (allowUnifiedMatchingRule) { + if (!x) { // Dummy argument has no attribute -> host + if (y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) { + return true; + } + } else { + if (*x == CUDADataAttr::Device && y && + (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) { + return true; + } else if (*x == CUDADataAttr::Managed && y && + *y == CUDADataAttr::Unified) { + return true; + } else if (*x == CUDADataAttr::Unified && y && + *y == CUDADataAttr::Managed) { + return true; + } + } + return false; } else { return false; } diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp index 20f7476..ab03ca5 100644 --- a/flang/lib/Evaluate/characteristics.cpp +++ b/flang/lib/Evaluate/characteristics.cpp @@ -362,8 +362,9 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual, } } if (!attrs.test(Attr::Value) && - !common::AreCompatibleCUDADataAttrs( - cudaDataAttr, actual.cudaDataAttr, ignoreTKR)) { + !common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr, + ignoreTKR, + /*allowUnifiedMatchingRule=*/false)) { if (whyNot) { *whyNot = "incompatible CUDA data attributes"; } @@ -1754,8 +1755,9 @@ bool DistinguishUtils::Distinguishable( } else if (y.attrs.test(Attr::Allocatable) && x.attrs.test(Attr::Pointer) && x.intent != common::Intent::In) { return true; - } else if (!common::AreCompatibleCUDADataAttrs( - x.cudaDataAttr, y.cudaDataAttr, x.ignoreTKR | y.ignoreTKR)) { + } else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr, + x.ignoreTKR | y.ignoreTKR, + /*allowUnifiedMatchingRule=*/false)) { return true; } else if (features_.IsEnabled( common::LanguageFeature::DistinguishableSpecifics) && diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp index db0949e..f0da779 100644 --- a/flang/lib/Semantics/check-call.cpp +++ b/flang/lib/Semantics/check-call.cpp @@ -897,8 +897,9 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, actualDataAttr = common::CUDADataAttr::Device; } } - if (!common::AreCompatibleCUDADataAttrs( - dummyDataAttr, actualDataAttr, dummy.ignoreTKR)) { + if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr, + dummy.ignoreTKR, + /*allowUnifiedMatchingRule=*/true)) { auto toStr{[](std::optional<common::CUDADataAttr> x) { return x ? "ATTRIBUTES("s + parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf index 7c6673e..6db8290 100644 --- a/flang/test/Semantics/cuf13.cuf +++ b/flang/test/Semantics/cuf13.cuf @@ -6,6 +6,10 @@ module matching module procedure sub_device end interface + interface subman + module procedure sub_host + end interface + contains subroutine sub_host(a) integer :: a(:) @@ -21,8 +25,13 @@ program m use matching integer, pinned, allocatable :: a(:) + integer, managed, allocatable :: b(:) logical :: plog allocate(a(100), pinned = plog) + allocate(b(200)) call sub(a) + + call subman(b) + end |