diff options
author | Eli Schwartz <eschwartz93@gmail.com> | 2024-01-03 23:51:49 -0500 |
---|---|---|
committer | Eli Schwartz <eschwartz93@gmail.com> | 2024-02-12 23:35:39 -0500 |
commit | 5899daf25b406737b436f2256dbcaf273fae6ee3 (patch) | |
tree | 7fe895bab8784b5b219c7f29f78ccd201a0a1617 | |
parent | 1b15176168d90c913153c0da598e130e9214b3ab (diff) | |
download | meson-5899daf25b406737b436f2256dbcaf273fae6ee3.zip meson-5899daf25b406737b436f2256dbcaf273fae6ee3.tar.gz meson-5899daf25b406737b436f2256dbcaf273fae6ee3.tar.bz2 |
cuda module: use typed_pos_args for most methods
The min_driver_version function has an extensive, informative custom
error message, so leave that in place.
The other two functions didn't have much information there, and it's
fairly evident that the cuda compiler itself is the best thing to have
here. Moreover, there was some fairly gnarly code to validate the
allowed values, which we can greatly simplify by uplifting the
typechecking parts to the dedicated decorators that are both really good
at it, and have nicely formatted error messages complete with reference
to the problematic functions.
-rw-r--r-- | mesonbuild/modules/cuda.py | 27 |
1 files changed, 11 insertions, 16 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py index b52288a..6900538 100644 --- a/mesonbuild/modules/cuda.py +++ b/mesonbuild/modules/cuda.py @@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType from . import NewExtensionModule, ModuleInfo from ..interpreterbase import ( - ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, + ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args, ) if T.TYPE_CHECKING: from typing_extensions import TypedDict from . import ModuleState - from ..compilers import Compiler class ArchFlagsKwargs(TypedDict): detected: T.Optional[T.List[str]] @@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule): return driver_version + @typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str) @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW) def nvcc_arch_flags(self, state: 'ModuleState', - args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], + args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]], kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] return ret + @typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str) @typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW) def nvcc_arch_readable(self, state: 'ModuleState', - args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], + args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]], kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] @@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule): return [c.detected_cc] return [] - def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs): - argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') + def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs): - if len(args) < 1: - raise argerror + compiler = args[0] + if isinstance(compiler, CudaCompiler): + cuda_version = compiler.version else: - compiler = args[0] - if isinstance(compiler, CudaCompiler): - cuda_version = compiler.version - elif isinstance(compiler, str): - cuda_version = compiler - else: - raise argerror + cuda_version = compiler - arch_list = [] if len(args) <= 1 else flatten(args[1:]) + arch_list = args[1] arch_list = [self._break_arch_string(a) for a in arch_list] arch_list = flatten(arch_list) if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}): |