From 978eeddab8ab126a340b2423ddf4146d0ca9ccf8 Mon Sep 17 00:00:00 2001 From: Dylan Baker Date: Thu, 14 Jan 2021 10:41:04 -0800 Subject: interpreterbase: Add support for variadic arguments to typed_pos_args This allows functions like `files()` to be decorated. --- mesonbuild/interpreterbase.py | 42 +++++++++++++++---- run_unittests.py | 96 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 7 deletions(-) diff --git a/mesonbuild/interpreterbase.py b/mesonbuild/interpreterbase.py index 4fd1ae9..081a31d 100644 --- a/mesonbuild/interpreterbase.py +++ b/mesonbuild/interpreterbase.py @@ -18,10 +18,11 @@ from . import mparser, mesonlib, mlog from . import environment, dependencies +from functools import wraps import abc -import os, copy, re import collections.abc -from functools import wraps +import itertools +import os, copy, re import typing as T TV_fw_var = T.Union[str, int, float, bool, list, dict, 'InterpreterObject', 'ObjectHolder'] @@ -229,7 +230,9 @@ class permittedKwargs: return T.cast(TV_func, wrapped) -def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]]) -> T.Callable[..., T.Any]: +def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]], + varargs: T.Optional[T.Union[T.Type, T.Tuple[T.Type]]] = None, + min_varargs: int = 0, max_varargs: int = 0) -> T.Callable[..., T.Any]: """Decorator that types type checking of positional arguments. allows replacing this: @@ -260,10 +263,25 @@ def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]]) -> @wraps(f) def wrapper(*wrapped_args: T.Any, **wrapped_kwargs: T.Any) -> T.Any: args = _get_callee_args(wrapped_args)[2] + + # These are implementation programming errors, end users should never see them. assert isinstance(args, list), args - if len(args) != len(types): - raise InvalidArguments(f'{name} takes exactly {len(types)} arguments, but got {len(args)}.') - for i, (arg, type_) in enumerate(zip(args, types), start=1): + assert max_varargs >= 0, 'max_varrags cannot be negative' + assert min_varargs >= 0, 'min_varrags cannot be negative' + + num_args = len(args) + num_types = len(types) + + if varargs: + num_types += min_varargs + if max_varargs == 0 and num_args < num_types: + raise InvalidArguments(f'{name} takes at least {num_types} arguments, but got {num_args}.') + elif max_varargs != 0 and (num_args < num_types or num_args > num_types + max_varargs - min_varargs): + raise InvalidArguments(f'{name} takes between {num_types} and {num_types + max_varargs - min_varargs} arguments, but got {len(args)}.') + elif num_args != num_types: + raise InvalidArguments(f'{name} takes exactly {num_types} arguments, but got {num_args}.') + + for i, (arg, type_) in enumerate(itertools.zip_longest(args, types, fillvalue=varargs), start=1): if not isinstance(arg, type_): if isinstance(type_, tuple): shouldbe = 'one of: {}'.format(", ".join(f'"{t.__name__}"' for t in type_)) @@ -276,7 +294,17 @@ def typed_pos_args(name: str, *types: T.Union[T.Type, T.Tuple[T.Type, ...]]) -> # wrapped_args can vary. nargs = list(wrapped_args) i = nargs.index(args) - nargs[i] = tuple(args) + if varargs: + # if we have varargs we need to split them into a separate + # tuple, as python's typing doesn't understand tuples with + # fixed elements and variadic elements, only one or the other. + # so in that cas we need T.Tuple[int, str, float, T.Tuple[str, ...]] + pos = args[:len(types)] + var = list(args[len(types):]) + pos.append(var) + nargs[i] = tuple(pos) + else: + nargs[i] = tuple(args) return f(*nargs, **wrapped_kwargs) return T.cast(TV_func, wrapper) diff --git a/run_unittests.py b/run_unittests.py index 06386c8..8dfb394 100755 --- a/run_unittests.py +++ b/run_unittests.py @@ -1326,6 +1326,102 @@ class InternalTests(unittest.TestCase): _(None, mock.Mock(), ['string', 1, True, True], None) self.assertEqual(str(cm.exception), 'foo takes exactly 3 arguments, but got 4.') + def test_typed_pos_args_varargs(self) -> None: + @typed_pos_args('foo', str, varargs=str) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertIsInstance(args, tuple) + self.assertIsInstance(args[0], str) + self.assertIsInstance(args[1], list) + self.assertIsInstance(args[1][0], str) + self.assertIsInstance(args[1][1], str) + + _(None, mock.Mock(), ['string', 'var', 'args'], None) + + def test_typed_pos_args_varargs_not_given(self) -> None: + @typed_pos_args('foo', str, varargs=str) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertIsInstance(args, tuple) + self.assertIsInstance(args[0], str) + self.assertIsInstance(args[1], list) + self.assertEqual(args[1], []) + + _(None, mock.Mock(), ['string'], None) + + def test_typed_pos_args_varargs_invalid(self) -> None: + @typed_pos_args('foo', str, varargs=str) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 'var', 'args', 0], None) + self.assertEqual(str(cm.exception), 'foo argument 4 was of type "int" but should have been "str"') + + def test_typed_pos_args_varargs_invalid_mulitple_types(self) -> None: + @typed_pos_args('foo', str, varargs=(str, list)) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 'var', 'args', 0], None) + self.assertEqual(str(cm.exception), 'foo argument 4 was of type "int" but should have been one of: "str", "list"') + + def test_typed_pos_args_max_varargs(self) -> None: + @typed_pos_args('foo', str, varargs=str, max_varargs=5) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertIsInstance(args, tuple) + self.assertIsInstance(args[0], str) + self.assertIsInstance(args[1], list) + self.assertIsInstance(args[1][0], str) + self.assertIsInstance(args[1][1], str) + + _(None, mock.Mock(), ['string', 'var', 'args'], None) + + def test_typed_pos_args_max_varargs_exceeded(self) -> None: + @typed_pos_args('foo', str, varargs=str, max_varargs=1) + def _(obj, node, args: T.Tuple[str, T.Tuple[str, ...]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 'var', 'args'], None) + self.assertEqual(str(cm.exception), 'foo takes between 1 and 2 arguments, but got 3.') + + def test_typed_pos_args_min_varargs(self) -> None: + @typed_pos_args('foo', varargs=str, max_varargs=2, min_varargs=1) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertIsInstance(args, tuple) + self.assertIsInstance(args[0], list) + self.assertIsInstance(args[0][0], str) + self.assertIsInstance(args[0][1], str) + + _(None, mock.Mock(), ['string', 'var'], None) + + def test_typed_pos_args_min_varargs_not_met(self) -> None: + @typed_pos_args('foo', str, varargs=str, min_varargs=1) + def _(obj, node, args: T.Tuple[str, T.List[str]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string'], None) + self.assertEqual(str(cm.exception), 'foo takes at least 2 arguments, but got 1.') + + def test_typed_pos_args_min_and_max_varargs_exceeded(self) -> None: + @typed_pos_args('foo', str, varargs=str, min_varargs=1, max_varargs=2) + def _(obj, node, args: T.Tuple[str, T.Tuple[str, ...]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string', 'var', 'args', 'bar'], None) + self.assertEqual(str(cm.exception), 'foo takes between 2 and 3 arguments, but got 4.') + + def test_typed_pos_args_min_and_max_varargs_not_met(self) -> None: + @typed_pos_args('foo', str, varargs=str, min_varargs=1, max_varargs=2) + def _(obj, node, args: T.Tuple[str, T.Tuple[str, ...]], kwargs) -> None: + self.assertTrue(False) # should not be reachable + + with self.assertRaises(InvalidArguments) as cm: + _(None, mock.Mock(), ['string'], None) + self.assertEqual(str(cm.exception), 'foo takes between 2 and 3 arguments, but got 1.') + @unittest.skipIf(is_tarball(), 'Skipping because this is a tarball release') class DataTests(unittest.TestCase): -- cgit v1.1