diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py
index b96e31ae..fba3a134 100644
--- a/nonebot/adapters/_bot.py
+++ b/nonebot/adapters/_bot.py
@@ -55,6 +55,16 @@ class Bot(abc.ABC):
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
+ @classmethod
+ def __get_validators__(cls):
+ yield cls.validate
+
+ @classmethod
+ def validate(cls, v):
+ if not isinstance(v, cls):
+ raise TypeError(f"{v} is not an instance of {cls}")
+ return v
+
@property
@abc.abstractmethod
def type(self) -> str:
diff --git a/nonebot/message.py b/nonebot/message.py
index f03a7c6d..a74003fe 100644
--- a/nonebot/message.py
+++ b/nonebot/message.py
@@ -7,7 +7,8 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
import asyncio
from datetime import datetime
-from typing import TYPE_CHECKING, Set, Type, Optional
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Set, Type
from nonebot.log import logger
from nonebot.rule import TrieRule
@@ -204,58 +205,63 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.opt(colors=True).success(log_msg)
state = {}
- coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
- if coros:
- try:
- if show_log:
- logger.debug("Running PreProcessors...")
- await asyncio.gather(*coros)
- except IgnoredException as e:
- logger.opt(colors=True).info(
- f"Event {escape_tag(event.get_event_name())} is ignored")
- return
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running EventPreProcessors. "
- "Event ignored!")
- return
- # Trie Match
- _, _ = TrieRule.get_value(bot, event, state)
-
- break_flag = False
- for priority in sorted(matchers.keys()):
- if break_flag:
- break
-
- if show_log:
- logger.debug(f"Checking for matchers in priority {priority}...")
-
- pending_tasks = [
- _check_matcher(priority, matcher, bot, event, state.copy())
- for matcher in matchers[priority]
- ]
-
- results = await asyncio.gather(*pending_tasks, return_exceptions=True)
-
- for result in results:
- if not isinstance(result, Exception):
- continue
- if isinstance(result, StopPropagation):
- break_flag = True
- logger.debug("Stop event propagation")
- else:
- logger.opt(colors=True, exception=result).error(
- "Error when checking Matcher."
+ # TODO
+ async with AsyncExitStack() as stack:
+ coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
+ if coros:
+ try:
+ if show_log:
+ logger.debug("Running PreProcessors...")
+ await asyncio.gather(*coros)
+ except IgnoredException as e:
+ logger.opt(colors=True).info(
+ f"Event {escape_tag(event.get_event_name())} is ignored"
)
+ return
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running EventPreProcessors. "
+ "Event ignored!")
+ return
+
+ # Trie Match
+ _, _ = TrieRule.get_value(bot, event, state)
+
+ break_flag = False
+ for priority in sorted(matchers.keys()):
+ if break_flag:
+ break
- coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
- if coros:
- try:
if show_log:
- logger.debug("Running PostProcessors...")
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running EventPostProcessors"
- )
+ logger.debug(f"Checking for matchers in priority {priority}...")
+
+ pending_tasks = [
+ _check_matcher(priority, matcher, bot, event, state.copy())
+ for matcher in matchers[priority]
+ ]
+
+ results = await asyncio.gather(*pending_tasks,
+ return_exceptions=True)
+
+ for result in results:
+ if not isinstance(result, Exception):
+ continue
+ if isinstance(result, StopPropagation):
+ break_flag = True
+ logger.debug("Stop event propagation")
+ else:
+ logger.opt(colors=True, exception=result).error(
+ "Error when checking Matcher."
+ )
+
+ coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
+ if coros:
+ try:
+ if show_log:
+ logger.debug("Running PostProcessors...")
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running EventPostProcessors"
+ )
diff --git a/nonebot/processor/__init__.py b/nonebot/processor/__init__.py
index 8b344b01..9860fc8d 100644
--- a/nonebot/processor/__init__.py
+++ b/nonebot/processor/__init__.py
@@ -1,15 +1,18 @@
import inspect
from itertools import chain
-from typing import Any, Dict, List, Tuple, Callable, Optional, cast
+from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
+from pydantic import BaseConfig
+from pydantic.fields import Required, ModelField
+from pydantic.schema import get_annotation_from_field_info
+
from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State
+from .utils import get_typed_signature
from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass
-from .utils import (generic_get_types, get_typed_signature,
- generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
@@ -27,33 +30,42 @@ def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
)
-def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent:
+def get_parameterless_sub_dependant(
+ *,
+ depends: DependsClass,
+ allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
- return get_sub_dependant(depends=depends, dependency=depends.dependency)
+ return get_sub_dependant(depends=depends,
+ dependency=depends.dependency,
+ allow_types=allow_types)
def get_sub_dependant(
- *,
- depends: DependsClass,
- dependency: Callable[..., Any],
- name: Optional[str] = None,
-) -> Dependent:
- sub_dependant = get_dependent(
- func=dependency,
- name=name,
- use_cache=depends.use_cache,
- )
+ *,
+ depends: DependsClass,
+ dependency: Callable[..., Any],
+ name: Optional[str] = None,
+ allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
+ sub_dependant = get_dependent(func=dependency,
+ name=name,
+ use_cache=depends.use_cache,
+ allow_types=allow_types)
return sub_dependant
-def get_dependent(*,
- func: Callable[..., Any],
- name: Optional[str] = None,
- use_cache: bool = True) -> Dependent:
+def get_dependent(
+ *,
+ func: Callable[..., Any],
+ name: Optional[str] = None,
+ use_cache: bool = True,
+ allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
signature = get_typed_signature(func)
params = signature.parameters
+ allow_types = allow_types or [
+ ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE
+ ]
dependent = Dependent(func=func, name=name, use_cache=use_cache)
for param_name, param in params.items():
if isinstance(param.default, DependsClass):
@@ -61,33 +73,29 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent)
continue
- if generic_check_issubclass(param.annotation, Bot):
- if dependent.bot_param_name is not None:
- raise ValueError(f"{func} has more than one Bot parameter: "
- f"{dependent.bot_param_name} / {param_name}")
- dependent.bot_param_name = param_name
- dependent.bot_param_type = generic_get_types(param.annotation)
- elif generic_check_issubclass(param.annotation, Event):
- if dependent.event_param_name is not None:
- raise ValueError(f"{func} has more than one Event parameter: "
- f"{dependent.event_param_name} / {param_name}")
- dependent.event_param_name = param_name
- dependent.event_param_type = generic_get_types(param.annotation)
- elif generic_check_issubclass(param.annotation, Dict):
- if dependent.state_param_name is not None:
- raise ValueError(f"{func} has more than one State parameter: "
- f"{dependent.state_param_name} / {param_name}")
- dependent.state_param_name = param_name
- elif generic_check_issubclass(param.annotation, Matcher):
- if dependent.matcher_param_name is not None:
- raise ValueError(
- f"{func} has more than one Matcher parameter: "
- f"{dependent.matcher_param_name} / {param_name}")
- dependent.matcher_param_name = param_name
+ for allow_type in allow_types:
+ field_info_class: Type[Param] = allow_type.value
+ if field_info_class._check(param_name, param):
+ field_info = field_info_class(param.default)
+ break
else:
raise ValueError(
f"Unknown parameter {param_name} with type {param.annotation}")
+ annotation: Any = Any
+ if param.annotation != param.empty:
+ annotation = param.annotation
+ annotation = get_annotation_from_field_info(annotation, field_info,
+ param_name)
+ dependent.params.append(
+ ModelField(name=param_name,
+ type_=annotation,
+ class_validators=None,
+ model_config=BaseConfig,
+ default=Required,
+ required=True,
+ field_info=field_info))
+
return dependent
@@ -97,7 +105,8 @@ async def solve_dependencies(
bot: Bot,
event: Event,
state: T_State,
- matcher: Optional["Matcher"],
+ matcher: Optional["Matcher"] = None,
+ exception: Optional[Exception] = None,
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
@@ -115,20 +124,6 @@ async def solve_dependencies(
sub_dependent.cache_key)
func = sub_dependent.func
- # check bot and event type
- if sub_dependent.bot_param_type and not isinstance(
- bot, sub_dependent.bot_param_type):
- logger.debug(
- f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
- f"annotation {sub_dependent.bot_param_type}, ignored")
- return values, dependency_cache, True
- elif sub_dependent.event_param_type and not isinstance(
- event, sub_dependent.event_param_type):
- logger.debug(
- f"Matcher {matcher} event type {type(event)} not match depends {func} "
- f"annotation {sub_dependent.event_param_type}, ignored")
- return values, dependency_cache, True
-
# dependency overrides
use_sub_dependant = sub_dependent
if (dependency_overrides_provider and
@@ -183,14 +178,28 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved
# usual dependency
- if dependent.bot_param_name is not None:
- values[dependent.bot_param_name] = bot
- if dependent.event_param_name is not None:
- values[dependent.event_param_name] = event
- if dependent.state_param_name is not None:
- values[dependent.state_param_name] = state
- if dependent.matcher_param_name is not None:
- values[dependent.matcher_param_name] = matcher
+ for field in dependent.params:
+ field_info = field.field_info
+ assert isinstance(field_info,
+ Param), "Params must be subclasses of Param"
+ value = field_info._solve(bot=bot,
+ event=event,
+ state=state,
+ matcher=matcher,
+ exception=exception)
+ _, errs_ = field.validate(value,
+ values,
+ loc=(ParamTypes(type(field_info)).name,
+ field.alias))
+ if errs_:
+ logger.debug(
+ f"Matcher {matcher} {ParamTypes(type(field_info)).name} "
+ f"type {type(value)} not match depends {dependent.func} "
+ f"annotation {field._type_display()}, ignored")
+ return values, dependency_cache, True
+ else:
+ values[field.name] = value
+
return values, dependency_cache, False
@@ -200,6 +209,8 @@ def Depends(dependency: Optional[Callable[..., Any]] = None,
return DependsClass(dependency=dependency, use_cache=use_cache)
+from .params import Param
from .handler import Handler as Handler
from .matcher import Matcher as Matcher
from .matcher import matchers as matchers
+from .params import ParamTypes as ParamTypes
diff --git a/nonebot/processor/handler.py b/nonebot/processor/handler.py
index 57e38af5..fed8d92c 100644
--- a/nonebot/processor/handler.py
+++ b/nonebot/processor/handler.py
@@ -9,7 +9,6 @@ import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
-from nonebot.log import logger
from .models import Depends, Dependent
from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler
@@ -17,6 +16,7 @@ from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant
if TYPE_CHECKING:
from .matcher import Matcher
+ from .params import ParamTypes
from nonebot.adapters import Bot, Event
@@ -28,6 +28,7 @@ class Handler:
*,
name: Optional[str] = None,
dependencies: Optional[List[Depends]] = None,
+ allow_types: Optional[List["ParamTypes"]] = None,
dependency_overrides_provider: Optional[Any] = None):
"""装饰事件处理函数以便根据动态参数运行"""
self.func: T_Handler = func
@@ -36,6 +37,7 @@ class Handler:
:说明: 事件处理函数
"""
self.name = get_name(func) if name is None else name
+ self.allow_types = allow_types
self.dependencies = dependencies or []
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
@@ -45,18 +47,16 @@ class Handler:
raise ValueError(f"{depends} has no dependency")
if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies")
- sub_dependant = get_parameterless_sub_dependant(depends=depends)
+ sub_dependant = get_parameterless_sub_dependant(
+ depends=depends, allow_types=self.allow_types)
self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider
- self.dependent = get_dependent(func=func)
+ self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str:
return (
- f"")
+ f""
+ )
def __str__(self) -> str:
return repr(self)
@@ -88,19 +88,6 @@ class Handler:
if ignored:
return
- # check bot and event type
- if self.dependent.bot_param_type and not isinstance(
- bot, self.dependent.bot_param_type):
- logger.debug(f"Matcher {matcher} bot type {type(bot)} not match "
- f"annotation {self.dependent.bot_param_type}, ignored")
- return
- elif self.dependent.event_param_type and not isinstance(
- event, self.dependent.event_param_type):
- logger.debug(
- f"Matcher {matcher} event type {type(event)} not match "
- f"annotation {self.dependent.event_param_type}, ignored")
- return
-
if asyncio.iscoroutinefunction(self.func):
await self.func(**values)
else:
@@ -111,7 +98,8 @@ class Handler:
raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies")
- sub_dependant = get_parameterless_sub_dependant(depends=dependency)
+ sub_dependant = get_parameterless_sub_dependant(
+ depends=dependency, allow_types=self.allow_types)
self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends):
diff --git a/nonebot/processor/matcher.py b/nonebot/processor/matcher.py
index 9fbceaa9..4ae3ca21 100644
--- a/nonebot/processor/matcher.py
+++ b/nonebot/processor/matcher.py
@@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
from .models import Depends
from .handler import Handler
from nonebot.rule import Rule
+from .params import ParamTypes
from nonebot import get_driver
from nonebot.log import logger
from nonebot.permission import USER, Permission
@@ -153,6 +154,10 @@ class Matcher(metaclass=MatcherMeta):
:说明: 事件响应器权限更新函数
"""
+ HANDLER_PARAM_TYPES = [
+ ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER
+ ]
+
def __init__(self):
"""实例化 Matcher 以便运行"""
self.handlers = self.handlers.copy()
@@ -230,7 +235,9 @@ class Matcher(metaclass=MatcherMeta):
permission or Permission(),
"handlers": [
handler if isinstance(handler, Handler) else Handler(
- handler, dependency_overrides_provider=get_driver())
+ handler,
+ dependency_overrides_provider=get_driver(),
+ allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers
] if handlers else [],
"temp":
@@ -348,7 +355,8 @@ class Matcher(metaclass=MatcherMeta):
dependencies: Optional[List[Depends]] = None) -> Handler:
handler_ = Handler(handler,
dependencies=dependencies,
- dependency_overrides_provider=get_driver())
+ dependency_overrides_provider=get_driver(),
+ allow_types=cls.HANDLER_PARAM_TYPES)
cls.handlers.append(handler_)
return handler_
diff --git a/nonebot/processor/models.py b/nonebot/processor/models.py
index 9413fb8a..06d11890 100644
--- a/nonebot/processor/models.py
+++ b/nonebot/processor/models.py
@@ -1,10 +1,9 @@
-from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
+from typing import Any, List, Callable, Optional
+
+from pydantic.fields import ModelField
from nonebot.utils import get_name
-if TYPE_CHECKING:
- from nonebot.adapters import Bot, Event
-
class Depends:
@@ -27,22 +26,12 @@ class Dependent:
*,
func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None,
- bot_param_name: Optional[str] = None,
- bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
- event_param_name: Optional[str] = None,
- event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
- state_param_name: Optional[str] = None,
- matcher_param_name: Optional[str] = None,
+ params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None:
self.func = func
self.name = name
- self.bot_param_name = bot_param_name
- self.bot_param_type = bot_param_type
- self.event_param_name = event_param_name
- self.event_param_type = event_param_type
- self.state_param_name = state_param_name
- self.matcher_param_name = matcher_param_name
+ self.params = params or []
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = self.func
diff --git a/nonebot/processor/params.py b/nonebot/processor/params.py
new file mode 100644
index 00000000..f777889c
--- /dev/null
+++ b/nonebot/processor/params.py
@@ -0,0 +1,91 @@
+import abc
+import inspect
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from pydantic.fields import FieldInfo
+
+from nonebot.typing import T_State
+from nonebot.adapters import Bot, Event
+from .utils import generic_check_issubclass
+
+
+class Param(FieldInfo, abc.ABC):
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}"
+
+ def __str__(self) -> str:
+ return repr(self)
+
+ @classmethod
+ @abc.abstractmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def _solve(self, **kwargs: Any) -> Any:
+ raise NotImplementedError
+
+
+class BotParam(Param):
+
+ @classmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ return generic_check_issubclass(param.annotation, Bot)
+
+ def _solve(self, bot: Bot, **kwargs: Any) -> Any:
+ return bot
+
+
+class EventParam(Param):
+
+ @classmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ return generic_check_issubclass(param.annotation, Event)
+
+ def _solve(self, event: Event, **kwargs: Any) -> Any:
+ return event
+
+
+class StateParam(Param):
+
+ @classmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ return generic_check_issubclass(param.annotation, Dict)
+
+ def _solve(self, state: T_State, **kwargs: Any) -> Any:
+ return state
+
+
+class MatcherParam(Param):
+
+ @classmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ return generic_check_issubclass(param.annotation, Matcher)
+
+ def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
+ return matcher
+
+
+class ExceptionParam(Param):
+
+ @classmethod
+ def _check(cls, name: str, param: inspect.Parameter) -> bool:
+ return generic_check_issubclass(param.annotation, Exception)
+
+ def _solve(self,
+ exception: Optional[Exception] = None,
+ **kwargs: Any) -> Any:
+ return exception
+
+
+class ParamTypes(Enum):
+ BOT = BotParam
+ EVENT = EventParam
+ STATE = StateParam
+ MATCHER = MatcherParam
+ EXCEPTION = ExceptionParam
+
+
+from .matcher import Matcher