From c4e436348307fc6d75f712d166423df60ddcdc16 Mon Sep 17 00:00:00 2001 From: Olexa Bilaniuk Date: Thu, 4 Mar 2021 12:51:08 -0500 Subject: Port CUDA module to new API. --- mesonbuild/modules/unstable_cuda.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/mesonbuild/modules/unstable_cuda.py b/mesonbuild/modules/unstable_cuda.py index 33df0bd..e0510b1 100644 --- a/mesonbuild/modules/unstable_cuda.py +++ b/mesonbuild/modules/unstable_cuda.py @@ -12,27 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing as T import re from ..mesonlib import version_compare from ..interpreter import CompilerHolder from ..compilers import CudaCompiler -from . import ExtensionModule, ModuleReturnValue +from . import ModuleObject from ..interpreterbase import ( flatten, permittedKwargs, noKwargs, InvalidArguments, FeatureNew ) -class CudaModule(ExtensionModule): +class CudaModule(ModuleObject): @FeatureNew('CUDA module', '0.50.0') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.methods.update({ + "min_driver_version": self.min_driver_version, + "nvcc_arch_flags": self.nvcc_arch_flags, + "nvcc_arch_readable": self.nvcc_arch_readable, + }) @noKwargs - def min_driver_version(self, state, args, kwargs): + def min_driver_version(self, state: 'ModuleState', + args: T.Tuple[str], + kwargs: T.Dict[str, T.Any]) -> str: argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' + 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + 'the CUDA Toolkit\'s components (including NVCC) are versioned ' + @@ -69,19 +77,23 @@ class CudaModule(ExtensionModule): driver_version = d.get(state.host_machine.system, d['linux']) break - return ModuleReturnValue(driver_version, [driver_version]) + return driver_version @permittedKwargs(['detected']) - def nvcc_arch_flags(self, state, args, kwargs): - nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) + def nvcc_arch_flags(self, state: 'ModuleState', + args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], + kwargs: T.Dict[str, T.Any]) -> T.List[str]: + nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] - return ModuleReturnValue(ret, [ret]) + return ret @permittedKwargs(['detected']) - def nvcc_arch_readable(self, state, args, kwargs): - nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs) + def nvcc_arch_readable(self, state: 'ModuleState', + args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]], + kwargs: T.Dict[str, T.Any]) -> T.List[str]: + nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] - return ModuleReturnValue(ret, [ret]) + return ret @staticmethod def _break_arch_string(s): @@ -107,7 +119,7 @@ class CudaModule(ExtensionModule): return c return 'unknown' - def _validate_nvcc_arch_args(self, state, args, kwargs): + def _validate_nvcc_arch_args(self, args, kwargs): argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') if len(args) < 1: -- cgit v1.1