diff --git a/tests/mock_callable_testslide.py b/tests/mock_callable_testslide.py index 8343137..cb6dcec 100644 --- a/tests/mock_callable_testslide.py +++ b/tests/mock_callable_testslide.py @@ -183,6 +183,36 @@ def passes_with_invalid_argument_type(self): } self.callable_target(*call_args, **call_kwargs) + @context.example + def passes_with_valid_str_types(self): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).for_call(*args, **kwargs).to_return_value("hello") + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + + @context.example + def raises_TypeCheckError_for_invalid_str_types(self): + args = (1234, 1234, 1234) + kwargs = {"kwarg1": "str val"} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).for_call(*args, **kwargs).to_return_value("hello") + with self.assertRaisesRegex( + TypeCheckError, + r"(?ms)type of arg1 must be str.*type of arg3 must be a dict.*", + ): + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + if has_return_value: @context.sub_context @@ -205,6 +235,43 @@ def raises_TypeCheckError(self): *self.call_args, **self.call_kwargs ) + @context.example + def passes_with_valid_str_return_types(self): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).to_return_value("hello") + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + + @context.example + def raises_TypeCheckError_for_invalid_str_return_types( + self, + ): + args = ( + "str val", + 1234, + {"key1": "string", "key2": 4321}, + ) + kwargs = {"kwarg1": 1234} + self.mock_callable( + sample_module, "instance_method_with_str_types" + ).to_return_value(1234) + with self.assertRaisesRegex( + TypeCheckError, + r"(?ms)type of return must be one of \(str, NoneType\); " + "got int instead: 1234.*", + ): + sample_module.instance_method_with_str_types( + *args, **kwargs + ) + @context.sub_context(".for_call(*args, **kwargs)") def for_call_args_kwargs(context): @context.sub_context diff --git a/tests/sample_module.py b/tests/sample_module.py index f4fd38e..df5ae6b 100644 --- a/tests/sample_module.py +++ b/tests/sample_module.py @@ -136,6 +136,12 @@ def test_function_returns_coroutine( return async_test_function(arg1, arg2, kwarg1, kwarg2) +def instance_method_with_str_types( + arg1: "str", arg2: "Any", arg3: "UnionArgType", kwarg1: "int" +) -> "Optional[str]": + return "original response" + + UnionArgType = Dict[str, Union[str, int]] diff --git a/testslide/lib.py b/testslide/lib.py index 2370879..3ac09b7 100644 --- a/testslide/lib.py +++ b/testslide/lib.py @@ -11,7 +11,17 @@ from functools import wraps from inspect import Traceback from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + get_type_hints, +) from unittest.mock import Mock import typeguard @@ -244,6 +254,7 @@ def _validate_callable_arg_types( kwargs: Dict[str, Any], ) -> None: argspec = inspect.getfullargspec(callable_template) + type_hints = get_type_hints(callable_template) idx_offset = 1 if skip_first_arg else 0 type_errors = [] for idx in range(0, len(args)): @@ -254,7 +265,7 @@ def _validate_callable_arg_types( raise TypeError("Extra argument given: ", repr(args[idx])) argname = argspec.args[idx + idx_offset] try: - expected_type = argspec.annotations.get(argname) + expected_type = type_hints.get(argname) if not expected_type: continue @@ -264,7 +275,7 @@ def _validate_callable_arg_types( for argname, value in kwargs.items(): try: - expected_type = argspec.annotations.get(argname) + expected_type = type_hints.get(argname) if not expected_type: continue @@ -357,10 +368,10 @@ def _validate_return_type( unwrap_template_awaitable: bool = False, ) -> None: try: - argspec = inspect.getfullargspec(template) + type_hints = get_type_hints(template) + expected_type = type_hints.get("return") except TypeError: return - expected_type = argspec.annotations.get("return") if expected_type: if unwrap_template_awaitable: type_origin = get_origin(expected_type)