aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEli Schwartz <eschwartz93@gmail.com>2024-01-03 23:51:49 -0500
committerEli Schwartz <eschwartz93@gmail.com>2024-02-12 23:35:39 -0500
commit5899daf25b406737b436f2256dbcaf273fae6ee3 (patch)
tree7fe895bab8784b5b219c7f29f78ccd201a0a1617
parent1b15176168d90c913153c0da598e130e9214b3ab (diff)
downloadmeson-5899daf25b406737b436f2256dbcaf273fae6ee3.zip
meson-5899daf25b406737b436f2256dbcaf273fae6ee3.tar.gz
meson-5899daf25b406737b436f2256dbcaf273fae6ee3.tar.bz2
cuda module: use typed_pos_args for most methods
The min_driver_version function has an extensive, informative custom error message, so leave that in place. The other two functions didn't have much information there, and it's fairly evident that the cuda compiler itself is the best thing to have here. Moreover, there was some fairly gnarly code to validate the allowed values, which we can greatly simplify by uplifting the typechecking parts to the dedicated decorators that are both really good at it, and have nicely formatted error messages complete with reference to the problematic functions.
-rw-r--r--mesonbuild/modules/cuda.py27
1 files changed, 11 insertions, 16 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py
index b52288a..6900538 100644
--- a/mesonbuild/modules/cuda.py
+++ b/mesonbuild/modules/cuda.py
@@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
- ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
+ ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
- from ..compilers import Compiler
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
@@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule):
return driver_version
+ @typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
- args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
+ args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret
+ @typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
- args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
+ args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
@@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule):
return [c.detected_cc]
return []
- def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
- argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
+ def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
- if len(args) < 1:
- raise argerror
+ compiler = args[0]
+ if isinstance(compiler, CudaCompiler):
+ cuda_version = compiler.version
else:
- compiler = args[0]
- if isinstance(compiler, CudaCompiler):
- cuda_version = compiler.version
- elif isinstance(compiler, str):
- cuda_version = compiler
- else:
- raise argerror
+ cuda_version = compiler
- arch_list = [] if len(args) <= 1 else flatten(args[1:])
+ arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):