diff options
-rw-r--r-- | mesonbuild/interpreterbase.py | 55 | ||||
-rwxr-xr-x | run_unittests.py | 34 |
2 files changed, 89 insertions, 0 deletions
diff --git a/mesonbuild/interpreterbase.py b/mesonbuild/interpreterbase.py index f17dfba..4fd1ae9 100644 --- a/mesonbuild/interpreterbase.py +++ b/mesonbuild/interpreterbase.py @@ -228,6 +228,61 @@ class permittedKwargs: return f(*wrapped_args, **wrapped_kwargs) return T.cast(TV_func, wrapped) + +def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]]) -> T.Callable[..., T.Any]: + """Decorator that types type checking of positional arguments. + + allows replacing this: + ```python + def func(self, node, args, kwargs): + if len(args) != 2: + raise Exception('... takes exactly 2 arguments) + foo: str = args[0] + if not isinstance(foo, str): + raise ... + bar: int = args[1] + if not isinstance(bar, int): + raise ... + + # actual useful stuff + ``` + with: + ```python + @typed_pos_args('func_name', str, int) + def func(self, node, args: T.Tuple[str, int], kwargs): + foo, bar = args + + # actual useful stuff + ``` + """ + def inner(f: TV_func) -> TV_func: + + @wraps(f) + def wrapper(*wrapped_args: T.Any, **wrapped_kwargs: T.Any) -> T.Any: + args = _get_callee_args(wrapped_args)[2] + assert isinstance(args, list), args + if len(args) != len(types): + raise InvalidArguments(f'{name} takes exactly {len(types)} arguments, but got {len(args)}.') + for i, (arg, type_) in enumerate(zip(args, types), start=1): + if not isinstance(arg, type_): + if isinstance(type_, tuple): + shouldbe = 'one of: {}'.format(", ".join(f'"{t.__name__}"' for t in type_)) + else: + shouldbe = f'"{type_.__name__}"' + raise InvalidArguments(f'{name} argument {i} was of type "{type(arg).__name__}" but should have been {shouldbe}') + + # Ensure that we're actually passing a tuple. + # Depending on what kind of function we're calling the length of + # wrapped_args can vary. + nargs = list(wrapped_args) + i = nargs.index(args) + nargs[i] = tuple(args) + return f(*nargs, **wrapped_kwargs) + + return T.cast(TV_func, wrapper) + return inner + + class FeatureCheckBase(metaclass=abc.ABCMeta): "Base class for feature version checks" diff --git a/run_unittests.py b/run_unittests.py index 9cc36fb..06386c8 100755 --- a/run_unittests.py +++ b/run_unittests.py @@ -51,6 +51,7 @@ import mesonbuild.mesonlib import mesonbuild.coredata import mesonbuild.modules.gnome from mesonbuild.interpreter import Interpreter, ObjectHolder +from mesonbuild.interpreterbase import typed_pos_args, InvalidArguments from mesonbuild.ast import AstInterpreter from mesonbuild.mesonlib import ( BuildDirLock, LibType, MachineChoice, PerMachine, Version, is_windows, @@ -1293,6 +1294,39 @@ class InternalTests(unittest.TestCase): self.assertFalse(errors) + def test_typed_pos_args_types(self) -> None: + @typed_pos_args('foo', str, int, bool) + def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None: + self.assertIsInstance(args, tuple) + self.assertIsInstance(args[0], str) + self.assertIsInstance(args[1], int) + self.assertIsInstance(args[2], bool) + + _(None, mock.Mock(), ['string', 1, False], None) + + def test_typed_pos_args_types_invalid(self) -> None: + @typed_pos_args('foo', str, int, bool) + def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 1.0, False], None) + self.assertEqual(str(cm.exception), 'foo argument 2 was of type "float" but should have been "int"') + + def test_typed_pos_args_types_wrong_number(self) -> None: + @typed_pos_args('foo', str, int, bool) + def _(obj, node, args: T.Tuple[str, int, bool], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 1], None) + self.assertEqual(str(cm.exception), 'foo takes exactly 3 arguments, but got 2.') + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 1, True, True], None) + self.assertEqual(str(cm.exception), 'foo takes exactly 3 arguments, but got 4.') + + @unittest.skipIf(is_tarball(), 'Skipping because this is a tarball release') class DataTests(unittest.TestCase): |