aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlexa Bilaniuk <obilaniu@gmail.com>2021-03-04 12:51:08 -0500
committerXavier Claessens <xclaesse@gmail.com>2021-03-05 06:57:40 -0500
commitc4e436348307fc6d75f712d166423df60ddcdc16 (patch)
treef17a9f4fddcddd97bd07fd96d0f6a3b3390c492f
parent504ae2dee8f0481455cd72ed4a7803ced608b875 (diff)
downloadmeson-c4e436348307fc6d75f712d166423df60ddcdc16.zip
meson-c4e436348307fc6d75f712d166423df60ddcdc16.tar.gz
meson-c4e436348307fc6d75f712d166423df60ddcdc16.tar.bz2
Port CUDA module to new API.
-rw-r--r--mesonbuild/modules/unstable_cuda.py34
1 files changed, 23 insertions, 11 deletions
diff --git a/mesonbuild/modules/unstable_cuda.py b/mesonbuild/modules/unstable_cuda.py
index 33df0bd..e0510b1 100644
--- a/mesonbuild/modules/unstable_cuda.py
+++ b/mesonbuild/modules/unstable_cuda.py
@@ -12,27 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import typing as T
import re
from ..mesonlib import version_compare
from ..interpreter import CompilerHolder
from ..compilers import CudaCompiler
-from . import ExtensionModule, ModuleReturnValue
+from . import ModuleObject
from ..interpreterbase import (
flatten, permittedKwargs, noKwargs,
InvalidArguments, FeatureNew
)
-class CudaModule(ExtensionModule):
+class CudaModule(ModuleObject):
@FeatureNew('CUDA module', '0.50.0')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.methods.update({
+ "min_driver_version": self.min_driver_version,
+ "nvcc_arch_flags": self.nvcc_arch_flags,
+ "nvcc_arch_readable": self.nvcc_arch_readable,
+ })
@noKwargs
- def min_driver_version(self, state, args, kwargs):
+ def min_driver_version(self, state: 'ModuleState',
+ args: T.Tuple[str],
+ kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
'the CUDA Toolkit\'s components (including NVCC) are versioned ' +
@@ -69,19 +77,23 @@ class CudaModule(ExtensionModule):
driver_version = d.get(state.host_machine.system, d['linux'])
break
- return ModuleReturnValue(driver_version, [driver_version])
+ return driver_version
@permittedKwargs(['detected'])
- def nvcc_arch_flags(self, state, args, kwargs):
- nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs)
+ def nvcc_arch_flags(self, state: 'ModuleState',
+ args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]],
+ kwargs: T.Dict[str, T.Any]) -> T.List[str]:
+ nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
- return ModuleReturnValue(ret, [ret])
+ return ret
@permittedKwargs(['detected'])
- def nvcc_arch_readable(self, state, args, kwargs):
- nvcc_arch_args = self._validate_nvcc_arch_args(state, args, kwargs)
+ def nvcc_arch_readable(self, state: 'ModuleState',
+ args: T.Tuple[T.Union[CompilerHolder, CudaCompiler, str]],
+ kwargs: T.Dict[str, T.Any]) -> T.List[str]:
+ nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
- return ModuleReturnValue(ret, [ret])
+ return ret
@staticmethod
def _break_arch_string(s):
@@ -107,7 +119,7 @@ class CudaModule(ExtensionModule):
return c
return 'unknown'
- def _validate_nvcc_arch_args(self, state, args, kwargs):
+ def _validate_nvcc_arch_args(self, args, kwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
if len(args) < 1: