🐛 fix matcher got receive type error

This commit is contained in:
yanyongyu 2021-12-31 22:43:29 +08:00
parent c48231454e
commit ec35f292bd
6 changed files with 130 additions and 43 deletions

View File

@ -49,11 +49,13 @@ class Dependent(Generic[R]):
self, self,
*, *,
call: Callable[..., Any], call: Callable[..., Any],
pre_checkers: Optional[List[Param]] = None,
params: Optional[List[ModelField]] = None, params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None, parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None, allow_types: Optional[List[Type[Param]]] = None,
) -> None: ) -> None:
self.call = call self.call = call
self.pre_checkers = pre_checkers or []
self.params = params or [] self.params = params or []
self.parameterless = parameterless or [] self.parameterless = parameterless or []
self.allow_types = allow_types or [] self.allow_types = allow_types or []
@ -116,11 +118,6 @@ class Dependent(Generic[R]):
allow_types=allow_types, allow_types=allow_types,
) )
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
for param_name, param in params.items(): for param_name, param in params.items():
default_value = Required default_value = Required
if param.default != param.empty: if param.default != param.empty:
@ -152,6 +149,11 @@ class Dependent(Generic[R]):
) )
) )
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
logger.trace( logger.trace(
f"Parsed dependent with call={call}, " f"Parsed dependent with call={call}, "
f"params={[param.field_info for param in dependent.params]}, " f"params={[param.field_info for param in dependent.params]}, "
@ -166,6 +168,9 @@ class Dependent(Generic[R]):
) -> Dict[str, Any]: ) -> Dict[str, Any]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
for checker in self.pre_checkers:
await checker._solve(**params)
for param in self.parameterless: for param in self.parameterless:
await param._solve(**params) await param._solve(**params)

View File

@ -484,7 +484,6 @@ class Matcher(metaclass=MatcherMeta):
""" """
async def _key_getter(event: Event, matcher: "Matcher"): async def _key_getter(event: Event, matcher: "Matcher"):
print(key, matcher.state)
matcher.set_target(ARG_KEY.format(key=key)) matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key): if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message()) matcher.set_arg(key, event.get_message())

View File

@ -4,10 +4,12 @@ from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.adapters import Bot, Event, Message from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
@ -94,11 +96,15 @@ class DependParam(Param):
dependency = param.annotation dependency = param.annotation
else: else:
dependency = param.default.dependency dependency = param.default.dependency
dependent = Dependent[Any].parse( sub_dependent = Dependent[Any].parse(
call=dependency, call=dependency,
allow_types=dependent.allow_types, allow_types=dependent.allow_types,
) )
return cls(Required, use_cache=param.default.use_cache, dependent=dependent) dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod @classmethod
def _check_parameterless( def _check_parameterless(
@ -158,31 +164,81 @@ class DependParam(Param):
return solved return solved
class _BotChecker(Param):
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(bot, {}, loc=("bot",))
if errs_:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, bot)
class BotParam(Param): class BotParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["BotParam"]: ) -> Optional["BotParam"]:
if param.default == param.empty and ( if param.default == param.empty:
generic_check_issubclass(param.annotation, Bot) if generic_check_issubclass(param.annotation, Bot):
or (param.annotation == param.empty and name == "bot") dependent.pre_checkers.append(
): _BotChecker(
return cls(Required) Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "bot":
return cls(Required)
async def _solve(self, bot: Bot, **kwargs: Any) -> Any: async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot return bot
class _EventChecker(Param):
async def _solve(self, event: Event, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(event, {}, loc=("event",))
if errs_:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, event)
class EventParam(Param): class EventParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["EventParam"]: ) -> Optional["EventParam"]:
if param.default == param.empty and ( if param.default == param.empty:
generic_check_issubclass(param.annotation, Event) if generic_check_issubclass(param.annotation, Event):
or (param.annotation == param.empty and name == "event") dependent.pre_checkers.append(
): _EventChecker(
return cls(Required) Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "event":
return cls(Required)
async def _solve(self, event: Event, **kwargs: Any) -> Any: async def _solve(self, event: Event, **kwargs: Any) -> Any:
return event return event

44
poetry.lock generated
View File

@ -134,17 +134,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]] [[package]]
name = "attrs" name = "attrs"
version = "21.2.0" version = "21.4.0"
description = "Classes Without Boilerplate" description = "Classes Without Boilerplate"
category = "main" category = "main"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.extras] [package.extras]
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"]
docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
[[package]] [[package]]
name = "babel" name = "babel"
@ -555,7 +555,7 @@ pytest-asyncio = "^0.16.0"
type = "git" type = "git"
url = "https://github.com/nonebot/nonebug.git" url = "https://github.com/nonebot/nonebug.git"
reference = "master" reference = "master"
resolved_reference = "afc3c3fe2cf6300cdf1b5e1a897867a03a17e278" resolved_reference = "e198b56be8f9ccf53c0d6de38e40fb9c0831c890"
[[package]] [[package]]
name = "packaging" name = "packaging"
@ -578,11 +578,11 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]] [[package]]
name = "platformdirs" name = "platformdirs"
version = "2.4.0" version = "2.4.1"
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.7"
[package.extras] [package.extras]
docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"] docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"]
@ -670,7 +670,7 @@ dev = ["black", "coverage", "docformatter", "flake8", "flake8-black", "flake8-bu
[[package]] [[package]]
name = "pygments" name = "pygments"
version = "2.10.0" version = "2.11.0"
description = "Pygments is a syntax highlighting package written in Python." description = "Pygments is a syntax highlighting package written in Python."
category = "dev" category = "dev"
optional = false optional = false
@ -1143,7 +1143,7 @@ h11 = ">=0.9.0,<1"
[[package]] [[package]]
name = "yapf" name = "yapf"
version = "0.31.0" version = "0.32.0"
description = "A formatter for Python code." description = "A formatter for Python code."
category = "dev" category = "dev"
optional = false optional = false
@ -1164,15 +1164,15 @@ typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
[[package]] [[package]]
name = "zipp" name = "zipp"
version = "3.6.0" version = "3.7.0"
description = "Backport of pathlib-compatible object wrapper for zip files" description = "Backport of pathlib-compatible object wrapper for zip files"
category = "main" category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.7"
[package.extras] [package.extras]
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
[extras] [extras]
aiohttp = ["aiohttp"] aiohttp = ["aiohttp"]
@ -1301,8 +1301,8 @@ atomicwrites = [
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
] ]
attrs = [ attrs = [
{file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
{file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
] ]
babel = [ babel = [
{file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"}, {file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"},
@ -1834,8 +1834,8 @@ pathspec = [
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
] ]
platformdirs = [ platformdirs = [
{file = "platformdirs-2.4.0-py3-none-any.whl", hash = "sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d"}, {file = "platformdirs-2.4.1-py3-none-any.whl", hash = "sha256:1d7385c7db91728b83efd0ca99a5afb296cab9d0ed8313a45ed8ba17967ecfca"},
{file = "platformdirs-2.4.0.tar.gz", hash = "sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2"}, {file = "platformdirs-2.4.1.tar.gz", hash = "sha256:440633ddfebcc36264232365d7840a970e75e1018d15b4327d11f91909045fda"},
] ]
pluggy = [ pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
@ -1915,8 +1915,8 @@ pydash = [
{file = "pydash-5.1.0.tar.gz", hash = "sha256:1b2b050ac1bae049cd07f5920b14fabbe52638f485d9ada1eb115a9eebff6835"}, {file = "pydash-5.1.0.tar.gz", hash = "sha256:1b2b050ac1bae049cd07f5920b14fabbe52638f485d9ada1eb115a9eebff6835"},
] ]
pygments = [ pygments = [
{file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"}, {file = "Pygments-2.11.0-py3-none-any.whl", hash = "sha256:ac8098bfc40b8e1091ad7c13490c7f4797e401d0972e8fcfadde90ffb3ed4ea9"},
{file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"}, {file = "Pygments-2.11.0.tar.gz", hash = "sha256:51130f778a028f2d19c143fce00ced6f8b10f726e17599d7e91b290f6cbcda0c"},
] ]
pygtrie = [ pygtrie = [
{file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"}, {file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"},
@ -2169,8 +2169,8 @@ wsproto = [
{file = "wsproto-1.0.0.tar.gz", hash = "sha256:868776f8456997ad0d9720f7322b746bbe9193751b5b290b7f924659377c8c38"}, {file = "wsproto-1.0.0.tar.gz", hash = "sha256:868776f8456997ad0d9720f7322b746bbe9193751b5b290b7f924659377c8c38"},
] ]
yapf = [ yapf = [
{file = "yapf-0.31.0-py2.py3-none-any.whl", hash = "sha256:e3a234ba8455fe201eaa649cdac872d590089a18b661e39bbac7020978dd9c2e"}, {file = "yapf-0.32.0-py2.py3-none-any.whl", hash = "sha256:8fea849025584e486fd06d6ba2bed717f396080fd3cc236ba10cb97c4c51cf32"},
{file = "yapf-0.31.0.tar.gz", hash = "sha256:408fb9a2b254c302f49db83c59f9aa0b4b0fd0ec25be3a5c51181327922ff63d"}, {file = "yapf-0.32.0.tar.gz", hash = "sha256:a3f5085d37ef7e3e004c4ba9f9b3e40c54ff1901cd111f05145ae313a7c67d1b"},
] ]
yarl = [ yarl = [
{file = "yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95"}, {file = "yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95"},
@ -2247,6 +2247,6 @@ yarl = [
{file = "yarl-1.7.2.tar.gz", hash = "sha256:45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd"}, {file = "yarl-1.7.2.tar.gz", hash = "sha256:45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd"},
] ]
zipp = [ zipp = [
{file = "zipp-3.6.0-py3-none-any.whl", hash = "sha256:9fe5ea21568a0a70e50f273397638d39b03353731e6cbbb3fd8502a33fec40bc"}, {file = "zipp-3.7.0-py3-none-any.whl", hash = "sha256:b47250dd24f92b7dd6a0a8fc5244da14608f3ca90a5efcd37a3b1642fac9a375"},
{file = "zipp-3.6.0.tar.gz", hash = "sha256:71c644c5369f4a6e07636f0aa966270449561fcea2e3d6747b8d23efaa9d7832"}, {file = "zipp-3.7.0.tar.gz", hash = "sha256:9f50f446828eb9d45b267433fd3e9da8d801f614129124863f9c51ebceafb87d"},
] ]

View File

@ -66,8 +66,27 @@ async def preset(matcher: Matcher, message: Message = EventMessage()):
@test_preset.got("a") @test_preset.got("a")
async def reject_preset(a: str = ArgStr()): @test_preset.got("b")
async def reject_preset(a: str = ArgStr(), b: str = ArgStr()):
if a == "text": if a == "text":
await test_preset.reject_arg("a") await test_preset.reject_arg("a")
assert a == "text_next" assert a == "text_next"
assert b == "text"
test_overload = on_message()
class FakeEvent(Event):
...
@test_overload.got("a")
async def overload(event: FakeEvent):
await test_overload.reject_arg("a")
@test_overload.handle()
async def finish():
await test_overload.finish()

View File

@ -12,6 +12,7 @@ async def test_matcher(app: App, load_plugin):
test_preset, test_preset,
test_combine, test_combine,
test_receive, test_receive,
test_overload,
) )
message = make_fake_message()("text") message = make_fake_message()("text")
@ -64,5 +65,12 @@ async def test_matcher(app: App, load_plugin):
async with app.test_matcher(test_preset) as ctx: async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
ctx.receive_event(bot, event)
ctx.should_rejected() ctx.should_rejected()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
assert len(test_overload.handlers) == 2
async with app.test_matcher(test_overload) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
ctx.should_finished()