🐛 fix matcher got receive type error

This commit is contained in:
yanyongyu 2021-12-31 22:43:29 +08:00
parent c48231454e
commit ec35f292bd
6 changed files with 130 additions and 43 deletions

View File

@ -49,11 +49,13 @@ class Dependent(Generic[R]):
self,
*,
call: Callable[..., Any],
pre_checkers: Optional[List[Param]] = None,
params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> None:
self.call = call
self.pre_checkers = pre_checkers or []
self.params = params or []
self.parameterless = parameterless or []
self.allow_types = allow_types or []
@ -116,11 +118,6 @@ class Dependent(Generic[R]):
allow_types=allow_types,
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
for param_name, param in params.items():
default_value = Required
if param.default != param.empty:
@ -152,6 +149,11 @@ class Dependent(Generic[R]):
)
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
logger.trace(
f"Parsed dependent with call={call}, "
f"params={[param.field_info for param in dependent.params]}, "
@ -166,6 +168,9 @@ class Dependent(Generic[R]):
) -> Dict[str, Any]:
values: Dict[str, Any] = {}
for checker in self.pre_checkers:
await checker._solve(**params)
for param in self.parameterless:
await param._solve(**params)

View File

@ -484,7 +484,6 @@ class Matcher(metaclass=MatcherMeta):
"""
async def _key_getter(event: Event, matcher: "Matcher"):
print(key, matcher.state)
matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message())

View File

@ -4,10 +4,12 @@ from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
@ -94,11 +96,15 @@ class DependParam(Param):
dependency = param.annotation
else:
dependency = param.default.dependency
dependent = Dependent[Any].parse(
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=dependent.allow_types,
)
return cls(Required, use_cache=param.default.use_cache, dependent=dependent)
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod
def _check_parameterless(
@ -158,31 +164,81 @@ class DependParam(Param):
return solved
class _BotChecker(Param):
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(bot, {}, loc=("bot",))
if errs_:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, bot)
class BotParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["BotParam"]:
if param.default == param.empty and (
generic_check_issubclass(param.annotation, Bot)
or (param.annotation == param.empty and name == "bot")
):
return cls(Required)
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot):
dependent.pre_checkers.append(
_BotChecker(
Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "bot":
return cls(Required)
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot
class _EventChecker(Param):
async def _solve(self, event: Event, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(event, {}, loc=("event",))
if errs_:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, event)
class EventParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["EventParam"]:
if param.default == param.empty and (
generic_check_issubclass(param.annotation, Event)
or (param.annotation == param.empty and name == "event")
):
return cls(Required)
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event):
dependent.pre_checkers.append(
_EventChecker(
Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "event":
return cls(Required)
async def _solve(self, event: Event, **kwargs: Any) -> Any:
return event

44
poetry.lock generated
View File

@ -134,17 +134,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "attrs"
version = "21.2.0"
version = "21.4.0"
description = "Classes Without Boilerplate"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.extras]
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"]
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"]
docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"]
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
[[package]]
name = "babel"
@ -555,7 +555,7 @@ pytest-asyncio = "^0.16.0"
type = "git"
url = "https://github.com/nonebot/nonebug.git"
reference = "master"
resolved_reference = "afc3c3fe2cf6300cdf1b5e1a897867a03a17e278"
resolved_reference = "e198b56be8f9ccf53c0d6de38e40fb9c0831c890"
[[package]]
name = "packaging"
@ -578,11 +578,11 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]]
name = "platformdirs"
version = "2.4.0"
version = "2.4.1"
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "dev"
optional = false
python-versions = ">=3.6"
python-versions = ">=3.7"
[package.extras]
docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"]
@ -670,7 +670,7 @@ dev = ["black", "coverage", "docformatter", "flake8", "flake8-black", "flake8-bu
[[package]]
name = "pygments"
version = "2.10.0"
version = "2.11.0"
description = "Pygments is a syntax highlighting package written in Python."
category = "dev"
optional = false
@ -1143,7 +1143,7 @@ h11 = ">=0.9.0,<1"
[[package]]
name = "yapf"
version = "0.31.0"
version = "0.32.0"
description = "A formatter for Python code."
category = "dev"
optional = false
@ -1164,15 +1164,15 @@ typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
[[package]]
name = "zipp"
version = "3.6.0"
version = "3.7.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
category = "main"
optional = false
python-versions = ">=3.6"
python-versions = ">=3.7"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
[extras]
aiohttp = ["aiohttp"]
@ -1301,8 +1301,8 @@ atomicwrites = [
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
]
attrs = [
{file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"},
{file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"},
{file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
{file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
]
babel = [
{file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"},
@ -1834,8 +1834,8 @@ pathspec = [
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
]
platformdirs = [
{file = "platformdirs-2.4.0-py3-none-any.whl", hash = "sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d"},
{file = "platformdirs-2.4.0.tar.gz", hash = "sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2"},
{file = "platformdirs-2.4.1-py3-none-any.whl", hash = "sha256:1d7385c7db91728b83efd0ca99a5afb296cab9d0ed8313a45ed8ba17967ecfca"},
{file = "platformdirs-2.4.1.tar.gz", hash = "sha256:440633ddfebcc36264232365d7840a970e75e1018d15b4327d11f91909045fda"},
]
pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
@ -1915,8 +1915,8 @@ pydash = [
{file = "pydash-5.1.0.tar.gz", hash = "sha256:1b2b050ac1bae049cd07f5920b14fabbe52638f485d9ada1eb115a9eebff6835"},
]
pygments = [
{file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"},
{file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"},
{file = "Pygments-2.11.0-py3-none-any.whl", hash = "sha256:ac8098bfc40b8e1091ad7c13490c7f4797e401d0972e8fcfadde90ffb3ed4ea9"},
{file = "Pygments-2.11.0.tar.gz", hash = "sha256:51130f778a028f2d19c143fce00ced6f8b10f726e17599d7e91b290f6cbcda0c"},
]
pygtrie = [
{file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"},
@ -2169,8 +2169,8 @@ wsproto = [
{file = "wsproto-1.0.0.tar.gz", hash = "sha256:868776f8456997ad0d9720f7322b746bbe9193751b5b290b7f924659377c8c38"},
]
yapf = [
{file = "yapf-0.31.0-py2.py3-none-any.whl", hash = "sha256:e3a234ba8455fe201eaa649cdac872d590089a18b661e39bbac7020978dd9c2e"},
{file = "yapf-0.31.0.tar.gz", hash = "sha256:408fb9a2b254c302f49db83c59f9aa0b4b0fd0ec25be3a5c51181327922ff63d"},
{file = "yapf-0.32.0-py2.py3-none-any.whl", hash = "sha256:8fea849025584e486fd06d6ba2bed717f396080fd3cc236ba10cb97c4c51cf32"},
{file = "yapf-0.32.0.tar.gz", hash = "sha256:a3f5085d37ef7e3e004c4ba9f9b3e40c54ff1901cd111f05145ae313a7c67d1b"},
]
yarl = [
{file = "yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95"},
@ -2247,6 +2247,6 @@ yarl = [
{file = "yarl-1.7.2.tar.gz", hash = "sha256:45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd"},
]
zipp = [
{file = "zipp-3.6.0-py3-none-any.whl", hash = "sha256:9fe5ea21568a0a70e50f273397638d39b03353731e6cbbb3fd8502a33fec40bc"},
{file = "zipp-3.6.0.tar.gz", hash = "sha256:71c644c5369f4a6e07636f0aa966270449561fcea2e3d6747b8d23efaa9d7832"},
{file = "zipp-3.7.0-py3-none-any.whl", hash = "sha256:b47250dd24f92b7dd6a0a8fc5244da14608f3ca90a5efcd37a3b1642fac9a375"},
{file = "zipp-3.7.0.tar.gz", hash = "sha256:9f50f446828eb9d45b267433fd3e9da8d801f614129124863f9c51ebceafb87d"},
]

View File

@ -66,8 +66,27 @@ async def preset(matcher: Matcher, message: Message = EventMessage()):
@test_preset.got("a")
async def reject_preset(a: str = ArgStr()):
@test_preset.got("b")
async def reject_preset(a: str = ArgStr(), b: str = ArgStr()):
if a == "text":
await test_preset.reject_arg("a")
assert a == "text_next"
assert b == "text"
test_overload = on_message()
class FakeEvent(Event):
...
@test_overload.got("a")
async def overload(event: FakeEvent):
await test_overload.reject_arg("a")
@test_overload.handle()
async def finish():
await test_overload.finish()

View File

@ -12,6 +12,7 @@ async def test_matcher(app: App, load_plugin):
test_preset,
test_combine,
test_receive,
test_overload,
)
message = make_fake_message()("text")
@ -64,5 +65,12 @@ async def test_matcher(app: App, load_plugin):
async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
ctx.receive_event(bot, event)
ctx.should_rejected()
ctx.receive_event(bot, event_next)
assert len(test_overload.handlers) == 2
async with app.test_matcher(test_overload) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
ctx.should_finished()