aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mesonbuild/interpreterbase.py55
-rwxr-xr-xrun_unittests.py34
2 files changed, 89 insertions, 0 deletions
diff --git a/mesonbuild/interpreterbase.py b/mesonbuild/interpreterbase.py
index f17dfba..4fd1ae9 100644
--- a/mesonbuild/interpreterbase.py
+++ b/mesonbuild/interpreterbase.py
@@ -228,6 +228,61 @@ class permittedKwargs:
return f(*wrapped_args, **wrapped_kwargs)
return T.cast(TV_func, wrapped)
+
+def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]]) -> T.Callable[..., T.Any]:
+ """Decorator that types type checking of positional arguments.
+
+ allows replacing this:
+ ```python
+ def func(self, node, args, kwargs):
+ if len(args) != 2:
+ raise Exception('... takes exactly 2 arguments)
+ foo: str = args[0]
+ if not isinstance(foo, str):
+ raise ...
+ bar: int = args[1]
+ if not isinstance(bar, int):
+ raise ...
+
+ # actual useful stuff
+ ```
+ with:
+ ```python
+ @typed_pos_args('func_name', str, int)
+ def func(self, node, args: T.Tuple[str, int], kwargs):
+ foo, bar = args
+
+ # actual useful stuff
+ ```
+ """
+ def inner(f: TV_func) -> TV_func:
+
+ @wraps(f)
+ def wrapper(*wrapped_args: T.Any, **wrapped_kwargs: T.Any) -> T.Any:
+ args = _get_callee_args(wrapped_args)[2]
+ assert isinstance(args, list), args
+ if len(args) != len(types):
+ raise InvalidArguments(f'{name} takes exactly {len(types)} arguments, but got {len(args)}.')
+ for i, (arg, type_) in enumerate(zip(args, types), start=1):
+ if not isinstance(arg, type_):
+ if isinstance(type_, tuple):
+ shouldbe = 'one of: {}'.format(", ".join(f'"{t.__name__}"' for t in type_))
+ else:
+ shouldbe = f'"{type_.__name__}"'
+ raise InvalidArguments(f'{name} argument {i} was of type "{type(arg).__name__}" but should have been {shouldbe}')
+
+ # Ensure that we're actually passing a tuple.
+ # Depending on what kind of function we're calling the length of
+ # wrapped_args can vary.
+ nargs = list(wrapped_args)
+ i = nargs.index(args)
+ nargs[i] = tuple(args)
+ return f(*nargs, **wrapped_kwargs)
+
+ return T.cast(TV_func, wrapper)
+ return inner
+
+
class FeatureCheckBase(metaclass=abc.ABCMeta):
"Base class for feature version checks"
diff --git a/run_unittests.py b/run_unittests.py
index 9cc36fb..06386c8 100755
--- a/run_unittests.py
+++ b/run_unittests.py
@@ -51,6 +51,7 @@ import mesonbuild.mesonlib
import mesonbuild.coredata
import mesonbuild.modules.gnome
from mesonbuild.interpreter import Interpreter, ObjectHolder
+from mesonbuild.interpreterbase import typed_pos_args, InvalidArguments
from mesonbuild.ast import AstInterpreter
from mesonbuild.mesonlib import (
BuildDirLock, LibType, MachineChoice, PerMachine, Version, is_windows,
@@ -1293,6 +1294,39 @@ class InternalTests(unittest.TestCase):
self.assertFalse(errors)
+ def test_typed_pos_args_types(self) -> None:
+ @typed_pos_args('foo', str, int, bool)
+ def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None:
+ self.assertIsInstance(args, tuple)
+ self.assertIsInstance(args[0], str)
+ self.assertIsInstance(args[1], int)
+ self.assertIsInstance(args[2], bool)
+
+ _(None, mock.Mock(), ['string', 1, False], None)
+
+ def test_typed_pos_args_types_invalid(self) -> None:
+ @typed_pos_args('foo', str, int, bool)
+ def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None:
+ self.assertTrue(False) # should not be reachable
+
+ with self.assertRaises(InvalidArguments) as cm:
+ _(None, mock.Mock(), ['string', 1.0, False], None)
+ self.assertEqual(str(cm.exception), 'foo argument 2 was of type "float" but should have been "int"')
+
+ def test_typed_pos_args_types_wrong_number(self) -> None:
+ @typed_pos_args('foo', str, int, bool)
+ def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None:
+ self.assertTrue(False) # should not be reachable
+
+ with self.assertRaises(InvalidArguments) as cm:
+ _(None, mock.Mock(), ['string', 1], None)
+ self.assertEqual(str(cm.exception), 'foo takes exactly 3 arguments, but got 2.')
+
+ with self.assertRaises(InvalidArguments) as cm:
+ _(None, mock.Mock(), ['string', 1, True, True], None)
+ self.assertEqual(str(cm.exception), 'foo takes exactly 3 arguments, but got 4.')
+
+
@unittest.skipIf(is_tarball(), 'Skipping because this is a tarball release')
class DataTests(unittest.TestCase):