aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEli Schwartz <eschwartz93@gmail.com>2023-10-15 21:26:58 -0400
committerEli Schwartz <eschwartz93@gmail.com>2024-02-12 23:51:35 -0500
commit65ee397f341688282291b0ef529a7c6aa4c2f9f8 (patch)
tree0328e4232ccb82a2d31c44f471a1b878dc4d2d2b
parent5899daf25b406737b436f2256dbcaf273fae6ee3 (diff)
downloadmeson-65ee397f341688282291b0ef529a7c6aa4c2f9f8.zip
meson-65ee397f341688282291b0ef529a7c6aa4c2f9f8.tar.gz
meson-65ee397f341688282291b0ef529a7c6aa4c2f9f8.tar.bz2
cuda module: fully type annotate
Special notes: - _nvcc_arch_flags is always called with exact arguments, no need for default values - min_driver_version has its args annotation loosened because it has to fit the constraints of the module interface?
-rw-r--r--mesonbuild/modules/cuda.py48
-rwxr-xr-xrun_mypy.py1
2 files changed, 26 insertions, 23 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py
index 6900538..eb73a57 100644
--- a/mesonbuild/modules/cuda.py
+++ b/mesonbuild/modules/cuda.py
@@ -3,27 +3,31 @@
from __future__ import annotations
-import typing as T
import re
+import typing as T
-from ..mesonlib import version_compare
+from ..mesonlib import listify, version_compare
from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
- ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
+ ContainerTypeInfo, InvalidArguments, KwargInfo, noKwargs, typed_kwargs, typed_pos_args,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
+ from ..interpreter import Interpreter
+ from ..interpreterbase import TYPE_var
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
+ AutoArch = T.Union[str, T.List[str]]
+
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
@@ -31,7 +35,7 @@ class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
- def __init__(self, *args, **kwargs):
+ def __init__(self, interp: Interpreter):
super().__init__()
self.methods.update({
"min_driver_version": self.min_driver_version,
@@ -41,7 +45,7 @@ class CudaModule(NewExtensionModule):
@noKwargs
def min_driver_version(self, state: 'ModuleState',
- args: T.Tuple[str],
+ args: T.List[TYPE_var],
kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
@@ -113,18 +117,18 @@ class CudaModule(NewExtensionModule):
return ret
@staticmethod
- def _break_arch_string(s):
+ def _break_arch_string(s: str) -> T.List[str]:
s = re.sub('[ \t\r\n,;]+', ';', s)
- s = s.strip(';').split(';')
- return s
+ return s.strip(';').split(';')
@staticmethod
- def _detected_cc_from_compiler(c) -> T.List[str]:
+ def _detected_cc_from_compiler(c: T.Union[str, CudaCompiler]) -> T.List[str]:
if isinstance(c, CudaCompiler):
return [c.detected_cc]
return []
- def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
+ def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]],
+ kwargs: ArchFlagsKwargs) -> T.Tuple[str, AutoArch, T.List[str]]:
compiler = args[0]
if isinstance(compiler, CudaCompiler):
@@ -132,22 +136,20 @@ class CudaModule(NewExtensionModule):
else:
cuda_version = compiler
- arch_list = args[1]
- arch_list = [self._break_arch_string(a) for a in arch_list]
- arch_list = flatten(arch_list)
+ arch_list: AutoArch = args[1]
+ arch_list = listify([self._break_arch_string(a) for a in arch_list])
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
- detected = [self._break_arch_string(a) for a in detected]
- detected = flatten(detected)
+ detected = [x for a in detected for x in self._break_arch_string(a)]
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
return cuda_version, arch_list, detected
- def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
+ def _filter_cuda_arch_list(self, cuda_arch_list: T.List[str], lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
"""
Filter CUDA arch list (no codenames) for >= low and < hi architecture
bounds, and deduplicate.
@@ -165,7 +167,7 @@ class CudaModule(NewExtensionModule):
filtered_cuda_arch_list.append(arch)
return filtered_cuda_arch_list
- def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):
+ def _nvcc_arch_flags(self, cuda_version: str, cuda_arch_list: AutoArch, detected: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]:
"""
Using the CUDA Toolkit version and the target architectures, compute
the NVCC architecture flags.
@@ -288,11 +290,11 @@ class CudaModule(NewExtensionModule):
cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x)
- cuda_arch_bin = []
- cuda_arch_ptx = []
+ cuda_arch_bin: T.List[str] = []
+ cuda_arch_ptx: T.List[str] = []
for arch_name in cuda_arch_list:
- arch_bin = []
- arch_ptx = []
+ arch_bin: T.Optional[T.List[str]]
+ arch_ptx: T.Optional[T.List[str]]
add_ptx = arch_name.endswith('+PTX')
if add_ptx:
arch_name = arch_name[:-len('+PTX')]
@@ -371,5 +373,5 @@ class CudaModule(NewExtensionModule):
return nvcc_flags, nvcc_archs_readable
-def initialize(*args, **kwargs):
- return CudaModule(*args, **kwargs)
+def initialize(interp: Interpreter) -> CudaModule:
+ return CudaModule(interp)
diff --git a/run_mypy.py b/run_mypy.py
index a9b52d9..c57a75c 100755
--- a/run_mypy.py
+++ b/run_mypy.py
@@ -51,6 +51,7 @@ modules = [
'mesonbuild/mlog.py',
'mesonbuild/msubprojects.py',
'mesonbuild/modules/__init__.py',
+ 'mesonbuild/modules/cuda.py',
'mesonbuild/modules/external_project.py',
'mesonbuild/modules/fs.py',
'mesonbuild/modules/gnome.py',