diff options
author | Olexa Bilaniuk <obilaniu@gmail.com> | 2021-03-04 12:51:08 -0500 |
---|---|---|
committer | Xavier Claessens <xclaesse@gmail.com> | 2021-03-05 06:57:40 -0500 |
commit | c4e436348307fc6d75f712d166423df60ddcdc16 (patch) | |
tree | f17a9f4fddcddd97bd07fd96d0f6a3b3390c492f /mesonbuild/modules | |
parent | 504ae2dee8f0481455cd72ed4a7803ced608b875 (diff) | |
download | meson-c4e436348307fc6d75f712d166423df60ddcdc16.zip meson-c4e436348307fc6d75f712d166423df60ddcdc16.tar.gz meson-c4e436348307fc6d75f712d166423df60ddcdc16.tar.bz2 |
Port CUDA module to new API.
Diffstat (limited to 'mesonbuild/modules')
-rw-r--r-- | mesonbuild/modules/unstable_cuda.py | 34 |
1 files changed, 23 insertions, 11 deletions
diff --git a/mesonbuild/modules/unstable_cuda.py b/mesonbuild/modules/unstable_cuda.py index 33df0bd..e0510b1 100644 --- a/mesonbuild/modules/unstable_cuda.py +++ b/mesonbuild/modules/unstable_cuda.py @@ -12,27 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing as T import re from ..mesonlib import version_compare from ..interpreter import CompilerHolder from ..compilers import CudaCompiler -from . import ExtensionModule, ModuleReturnValue +from . import ModuleObject from ..interpreterbase import ( flatten, permittedKwargs, noKwargs, InvalidArguments, FeatureNew ) -class CudaModule(ExtensionModule): +class CudaModule(ModuleObject): @FeatureNew('CUDA module', '0.50.0') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.methods.update({ + "min_driver_version": self.min_driver_version, + "nvcc_arch_flags": self.nvcc_arch_flags, + "nvcc_arch_readable": self.nvcc_arch_readable, + }) @noKwargs - def min_driver_version(self, state, args, kwargs): + def min_driver_version(self, state: 'ModuleState', + args: T.Tuple[str], + kwargs: T.Dict[str, T.Any]) -> str: argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + 'the CUDA Toolkit\'s components (including NVCC) are versioned ' + @@ -69,19 +77,23 @@ class CudaModule(ExtensionModule): driver_version = d.get(state.host_machine.system, d['linux']) break - return ModuleReturnValue(driver_version, [driver_version]) + return driver_version @permittedKwargs(['detected']) - def nvcc_arch_flags(self, state, args, kwargs): - nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) + def nvcc_arch_flags(self, state: 'ModuleState', + args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], + kwargs: T.Dict[str, T.Any]) -> T.List[str]: + nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] - return ModuleReturnValue(ret, [ret]) + return ret @permittedKwargs(['detected']) - def nvcc_arch_readable(self, state, args, kwargs): - nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) + def nvcc_arch_readable(self, state: 'ModuleState', + args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], + kwargs: T.Dict[str, T.Any]) -> T.List[str]: + nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] - return ModuleReturnValue(ret, [ret]) + return ret @staticmethod def _break_arch_string(s): @@ -107,7 +119,7 @@ class CudaModule(ExtensionModule): return c return 'unknown' - def _validate_nvcc_arch_args(self, state, args, kwargs): + def _validate_nvcc_arch_args(self, args, kwargs): argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') if len(args) < 1: |