diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 6040401a..42e02b9f 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -49,11 +49,13 @@ class Dependent(Generic[R]): self, *, call: Callable[..., Any], + pre_checkers: Optional[List[Param]] = None, params: Optional[List[ModelField]] = None, parameterless: Optional[List[Param]] = None, allow_types: Optional[List[Type[Param]]] = None, ) -> None: self.call = call + self.pre_checkers = pre_checkers or [] self.params = params or [] self.parameterless = parameterless or [] self.allow_types = allow_types or [] @@ -116,11 +118,6 @@ class Dependent(Generic[R]): allow_types=allow_types, ) - parameterless_params = [ - dependent.parse_parameterless(param) for param in (parameterless or []) - ] - dependent.parameterless.extend(parameterless_params) - for param_name, param in params.items(): default_value = Required if param.default != param.empty: @@ -152,6 +149,11 @@ class Dependent(Generic[R]): ) ) + parameterless_params = [ + dependent.parse_parameterless(param) for param in (parameterless or []) + ] + dependent.parameterless.extend(parameterless_params) + logger.trace( f"Parsed dependent with call={call}, " f"params={[param.field_info for param in dependent.params]}, " @@ -166,6 +168,9 @@ class Dependent(Generic[R]): ) -> Dict[str, Any]: values: Dict[str, Any] = {} + for checker in self.pre_checkers: + await checker._solve(**params) + for param in self.parameterless: await param._solve(**params) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index af7cd8d2..43e9b253 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -484,7 +484,6 @@ class Matcher(metaclass=MatcherMeta): """ async def _key_getter(event: Event, matcher: "Matcher"): - print(key, matcher.state) matcher.set_target(ARG_KEY.format(key=key)) if matcher.get_target() == ARG_KEY.format(key=key): matcher.set_arg(key, event.get_message()) diff --git a/nonebot/params.py b/nonebot/params.py index 1e62d039..61548e13 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -4,10 +4,12 @@ from typing_extensions import Literal from typing import Any, Dict, List, Tuple, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager -from pydantic.fields import Required, Undefined +from pydantic.fields import Required, Undefined, ModelField +from nonebot.log import logger +from nonebot.exception import TypeMisMatch from nonebot.adapters import Bot, Event, Message -from nonebot.dependencies import Param, Dependent +from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.consts import ( CMD_KEY, @@ -94,11 +96,15 @@ class DependParam(Param): dependency = param.annotation else: dependency = param.default.dependency - dependent = Dependent[Any].parse( + sub_dependent = Dependent[Any].parse( call=dependency, allow_types=dependent.allow_types, ) - return cls(Required, use_cache=param.default.use_cache, dependent=dependent) + dependent.pre_checkers.extend(sub_dependent.pre_checkers) + sub_dependent.pre_checkers.clear() + return cls( + Required, use_cache=param.default.use_cache, dependent=sub_dependent + ) @classmethod def _check_parameterless( @@ -158,31 +164,81 @@ class DependParam(Param): return solved +class _BotChecker(Param): + async def _solve(self, bot: Bot, **kwargs: Any) -> Any: + field: ModelField = self.extra["field"] + _, errs_ = field.validate(bot, {}, loc=("bot",)) + if errs_: + logger.debug( + f"Bot type {type(bot)} not match " + f"annotation {field._type_display()}, ignored" + ) + raise TypeMisMatch(field, bot) + + class BotParam(Param): @classmethod def _check_param( cls, dependent: Dependent, name: str, param: inspect.Parameter ) -> Optional["BotParam"]: - if param.default == param.empty and ( - generic_check_issubclass(param.annotation, Bot) - or (param.annotation == param.empty and name == "bot") - ): - return cls(Required) + if param.default == param.empty: + if generic_check_issubclass(param.annotation, Bot): + dependent.pre_checkers.append( + _BotChecker( + Required, + field=ModelField( + name="", + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, + ), + ) + ) + return cls(Required) + elif param.annotation == param.empty and name == "bot": + return cls(Required) async def _solve(self, bot: Bot, **kwargs: Any) -> Any: return bot +class _EventChecker(Param): + async def _solve(self, event: Event, **kwargs: Any) -> Any: + field: ModelField = self.extra["field"] + _, errs_ = field.validate(event, {}, loc=("event",)) + if errs_: + logger.debug( + f"Event type {type(event)} not match " + f"annotation {field._type_display()}, ignored" + ) + raise TypeMisMatch(field, event) + + class EventParam(Param): @classmethod def _check_param( cls, dependent: Dependent, name: str, param: inspect.Parameter ) -> Optional["EventParam"]: - if param.default == param.empty and ( - generic_check_issubclass(param.annotation, Event) - or (param.annotation == param.empty and name == "event") - ): - return cls(Required) + if param.default == param.empty: + if generic_check_issubclass(param.annotation, Event): + dependent.pre_checkers.append( + _EventChecker( + Required, + field=ModelField( + name="", + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, + ), + ) + ) + return cls(Required) + elif param.annotation == param.empty and name == "event": + return cls(Required) async def _solve(self, event: Event, **kwargs: Any) -> Any: return event diff --git a/poetry.lock b/poetry.lock index e2611985..ea3d58c7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -134,17 +134,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "attrs" -version = "21.2.0" +version = "21.4.0" description = "Classes Without Boilerplate" category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] [[package]] name = "babel" @@ -555,7 +555,7 @@ pytest-asyncio = "^0.16.0" type = "git" url = "https://github.com/nonebot/nonebug.git" reference = "master" -resolved_reference = "afc3c3fe2cf6300cdf1b5e1a897867a03a17e278" +resolved_reference = "e198b56be8f9ccf53c0d6de38e40fb9c0831c890" [[package]] name = "packaging" @@ -578,11 +578,11 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" [[package]] name = "platformdirs" -version = "2.4.0" +version = "2.4.1" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.extras] docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"] @@ -670,7 +670,7 @@ dev = ["black", "coverage", "docformatter", "flake8", "flake8-black", "flake8-bu [[package]] name = "pygments" -version = "2.10.0" +version = "2.11.0" description = "Pygments is a syntax highlighting package written in Python." category = "dev" optional = false @@ -1143,7 +1143,7 @@ h11 = ">=0.9.0,<1" [[package]] name = "yapf" -version = "0.31.0" +version = "0.32.0" description = "A formatter for Python code." category = "dev" optional = false @@ -1164,15 +1164,15 @@ typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} [[package]] name = "zipp" -version = "3.6.0" +version = "3.7.0" description = "Backport of pathlib-compatible object wrapper for zip files" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.extras] docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] -testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] [extras] aiohttp = ["aiohttp"] @@ -1301,8 +1301,8 @@ atomicwrites = [ {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, ] attrs = [ - {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, - {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, + {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, + {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] babel = [ {file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"}, @@ -1834,8 +1834,8 @@ pathspec = [ {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, ] platformdirs = [ - {file = "platformdirs-2.4.0-py3-none-any.whl", hash = "sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d"}, - {file = "platformdirs-2.4.0.tar.gz", hash = "sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2"}, + {file = "platformdirs-2.4.1-py3-none-any.whl", hash = "sha256:1d7385c7db91728b83efd0ca99a5afb296cab9d0ed8313a45ed8ba17967ecfca"}, + {file = "platformdirs-2.4.1.tar.gz", hash = "sha256:440633ddfebcc36264232365d7840a970e75e1018d15b4327d11f91909045fda"}, ] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, @@ -1915,8 +1915,8 @@ pydash = [ {file = "pydash-5.1.0.tar.gz", hash = "sha256:1b2b050ac1bae049cd07f5920b14fabbe52638f485d9ada1eb115a9eebff6835"}, ] pygments = [ - {file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"}, - {file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"}, + {file = "Pygments-2.11.0-py3-none-any.whl", hash = "sha256:ac8098bfc40b8e1091ad7c13490c7f4797e401d0972e8fcfadde90ffb3ed4ea9"}, + {file = "Pygments-2.11.0.tar.gz", hash = "sha256:51130f778a028f2d19c143fce00ced6f8b10f726e17599d7e91b290f6cbcda0c"}, ] pygtrie = [ {file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"}, @@ -2169,8 +2169,8 @@ wsproto = [ {file = "wsproto-1.0.0.tar.gz", hash = "sha256:868776f8456997ad0d9720f7322b746bbe9193751b5b290b7f924659377c8c38"}, ] yapf = [ - {file = "yapf-0.31.0-py2.py3-none-any.whl", hash = "sha256:e3a234ba8455fe201eaa649cdac872d590089a18b661e39bbac7020978dd9c2e"}, - {file = "yapf-0.31.0.tar.gz", hash = "sha256:408fb9a2b254c302f49db83c59f9aa0b4b0fd0ec25be3a5c51181327922ff63d"}, + {file = "yapf-0.32.0-py2.py3-none-any.whl", hash = "sha256:8fea849025584e486fd06d6ba2bed717f396080fd3cc236ba10cb97c4c51cf32"}, + {file = "yapf-0.32.0.tar.gz", hash = "sha256:a3f5085d37ef7e3e004c4ba9f9b3e40c54ff1901cd111f05145ae313a7c67d1b"}, ] yarl = [ {file = "yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95"}, @@ -2247,6 +2247,6 @@ yarl = [ {file = "yarl-1.7.2.tar.gz", hash = "sha256:45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd"}, ] zipp = [ - {file = "zipp-3.6.0-py3-none-any.whl", hash = "sha256:9fe5ea21568a0a70e50f273397638d39b03353731e6cbbb3fd8502a33fec40bc"}, - {file = "zipp-3.6.0.tar.gz", hash = "sha256:71c644c5369f4a6e07636f0aa966270449561fcea2e3d6747b8d23efaa9d7832"}, + {file = "zipp-3.7.0-py3-none-any.whl", hash = "sha256:b47250dd24f92b7dd6a0a8fc5244da14608f3ca90a5efcd37a3b1642fac9a375"}, + {file = "zipp-3.7.0.tar.gz", hash = "sha256:9f50f446828eb9d45b267433fd3e9da8d801f614129124863f9c51ebceafb87d"}, ] diff --git a/tests/plugins/matcher.py b/tests/plugins/matcher.py index 3621b850..016e2238 100644 --- a/tests/plugins/matcher.py +++ b/tests/plugins/matcher.py @@ -66,8 +66,27 @@ async def preset(matcher: Matcher, message: Message = EventMessage()): @test_preset.got("a") -async def reject_preset(a: str = ArgStr()): +@test_preset.got("b") +async def reject_preset(a: str = ArgStr(), b: str = ArgStr()): if a == "text": await test_preset.reject_arg("a") assert a == "text_next" + assert b == "text" + + +test_overload = on_message() + + +class FakeEvent(Event): + ... + + +@test_overload.got("a") +async def overload(event: FakeEvent): + await test_overload.reject_arg("a") + + +@test_overload.handle() +async def finish(): + await test_overload.finish() diff --git a/tests/test_matcher.py b/tests/test_matcher.py index 403aeab1..61f8c73c 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -12,6 +12,7 @@ async def test_matcher(app: App, load_plugin): test_preset, test_combine, test_receive, + test_overload, ) message = make_fake_message()("text") @@ -64,5 +65,12 @@ async def test_matcher(app: App, load_plugin): async with app.test_matcher(test_preset) as ctx: bot = ctx.create_bot() ctx.receive_event(bot, event) + ctx.receive_event(bot, event) ctx.should_rejected() ctx.receive_event(bot, event_next) + + assert len(test_overload.handlers) == 2 + async with app.test_matcher(test_overload) as ctx: + bot = ctx.create_bot() + ctx.receive_event(bot, event) + ctx.should_finished()