♻️ allow dynamic param types

This commit is contained in:
yanyongyu 2021-11-15 21:44:24 +08:00
parent cafe5c9af0
commit d1c6eeb6c2
7 changed files with 260 additions and 157 deletions

View File

@ -55,6 +55,16 @@ class Bot(abc.ABC):
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not isinstance(v, cls):
raise TypeError(f"{v} is not an instance of {cls}")
return v
@property
@abc.abstractmethod
def type(self) -> str:

View File

@ -7,7 +7,8 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Set, Type, Optional
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Set, Type
from nonebot.log import logger
from nonebot.rule import TrieRule
@ -204,58 +205,63 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.opt(colors=True).success(log_msg)
state = {}
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
if coros:
try:
if show_log:
logger.debug("Running PreProcessors...")
await asyncio.gather(*coros)
except IgnoredException as e:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>")
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>")
return
# Trie Match
_, _ = TrieRule.get_value(bot, event, state)
break_flag = False
for priority in sorted(matchers.keys()):
if break_flag:
break
if show_log:
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy())
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
# TODO
async with AsyncExitStack() as stack:
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
if coros:
try:
if show_log:
logger.debug("Running PreProcessors...")
await asyncio.gather(*coros)
except IgnoredException as e:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
)
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>")
return
# Trie Match
_, _ = TrieRule.get_value(bot, event, state)
break_flag = False
for priority in sorted(matchers.keys()):
if break_flag:
break
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
if coros:
try:
if show_log:
logger.debug("Running PostProcessors...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy())
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks,
return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
)
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
if coros:
try:
if show_log:
logger.debug("Running PostProcessors...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)

View File

@ -1,15 +1,18 @@
import inspect
from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info
from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State
from .utils import get_typed_signature
from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass
from .utils import (generic_get_types, get_typed_signature,
generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
@ -27,33 +30,42 @@ def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
)
def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent:
def get_parameterless_sub_dependant(
*,
depends: DependsClass,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency)
return get_sub_dependant(depends=depends,
dependency=depends.dependency,
allow_types=allow_types)
def get_sub_dependant(
*,
depends: DependsClass,
dependency: Callable[..., Any],
name: Optional[str] = None,
) -> Dependent:
sub_dependant = get_dependent(
func=dependency,
name=name,
use_cache=depends.use_cache,
)
*,
depends: DependsClass,
dependency: Callable[..., Any],
name: Optional[str] = None,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
sub_dependant = get_dependent(func=dependency,
name=name,
use_cache=depends.use_cache,
allow_types=allow_types)
return sub_dependant
def get_dependent(*,
func: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True) -> Dependent:
def get_dependent(
*,
func: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
signature = get_typed_signature(func)
params = signature.parameters
allow_types = allow_types or [
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE
]
dependent = Dependent(func=func, name=name, use_cache=use_cache)
for param_name, param in params.items():
if isinstance(param.default, DependsClass):
@ -61,33 +73,29 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent)
continue
if generic_check_issubclass(param.annotation, Bot):
if dependent.bot_param_name is not None:
raise ValueError(f"{func} has more than one Bot parameter: "
f"{dependent.bot_param_name} / {param_name}")
dependent.bot_param_name = param_name
dependent.bot_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, Event):
if dependent.event_param_name is not None:
raise ValueError(f"{func} has more than one Event parameter: "
f"{dependent.event_param_name} / {param_name}")
dependent.event_param_name = param_name
dependent.event_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, Dict):
if dependent.state_param_name is not None:
raise ValueError(f"{func} has more than one State parameter: "
f"{dependent.state_param_name} / {param_name}")
dependent.state_param_name = param_name
elif generic_check_issubclass(param.annotation, Matcher):
if dependent.matcher_param_name is not None:
raise ValueError(
f"{func} has more than one Matcher parameter: "
f"{dependent.matcher_param_name} / {param_name}")
dependent.matcher_param_name = param_name
for allow_type in allow_types:
field_info_class: Type[Param] = allow_type.value
if field_info_class._check(param_name, param):
field_info = field_info_class(param.default)
break
else:
raise ValueError(
f"Unknown parameter {param_name} with type {param.annotation}")
annotation: Any = Any
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
param_name)
dependent.params.append(
ModelField(name=param_name,
type_=annotation,
class_validators=None,
model_config=BaseConfig,
default=Required,
required=True,
field_info=field_info))
return dependent
@ -97,7 +105,8 @@ async def solve_dependencies(
bot: Bot,
event: Event,
state: T_State,
matcher: Optional["Matcher"],
matcher: Optional["Matcher"] = None,
exception: Optional[Exception] = None,
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
@ -115,20 +124,6 @@ async def solve_dependencies(
sub_dependent.cache_key)
func = sub_dependent.func
# check bot and event type
if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type):
logger.debug(
f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
f"annotation {sub_dependent.bot_param_type}, ignored")
return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match depends {func} "
f"annotation {sub_dependent.event_param_type}, ignored")
return values, dependency_cache, True
# dependency overrides
use_sub_dependant = sub_dependent
if (dependency_overrides_provider and
@ -183,14 +178,28 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved
# usual dependency
if dependent.bot_param_name is not None:
values[dependent.bot_param_name] = bot
if dependent.event_param_name is not None:
values[dependent.event_param_name] = event
if dependent.state_param_name is not None:
values[dependent.state_param_name] = state
if dependent.matcher_param_name is not None:
values[dependent.matcher_param_name] = matcher
for field in dependent.params:
field_info = field.field_info
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
value = field_info._solve(bot=bot,
event=event,
state=state,
matcher=matcher,
exception=exception)
_, errs_ = field.validate(value,
values,
loc=(ParamTypes(type(field_info)).name,
field.alias))
if errs_:
logger.debug(
f"Matcher {matcher} {ParamTypes(type(field_info)).name} "
f"type {type(value)} not match depends {dependent.func} "
f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True
else:
values[field.name] = value
return values, dependency_cache, False
@ -200,6 +209,8 @@ def Depends(dependency: Optional[Callable[..., Any]] = None,
return DependsClass(dependency=dependency, use_cache=use_cache)
from .params import Param
from .handler import Handler as Handler
from .matcher import Matcher as Matcher
from .matcher import matchers as matchers
from .params import ParamTypes as ParamTypes

View File

@ -9,7 +9,6 @@ import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
from .models import Depends, Dependent
from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler
@ -17,6 +16,7 @@ from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant
if TYPE_CHECKING:
from .matcher import Matcher
from .params import ParamTypes
from nonebot.adapters import Bot, Event
@ -28,6 +28,7 @@ class Handler:
*,
name: Optional[str] = None,
dependencies: Optional[List[Depends]] = None,
allow_types: Optional[List["ParamTypes"]] = None,
dependency_overrides_provider: Optional[Any] = None):
"""装饰事件处理函数以便根据动态参数运行"""
self.func: T_Handler = func
@ -36,6 +37,7 @@ class Handler:
:说明: 事件处理函数
"""
self.name = get_name(func) if name is None else name
self.allow_types = allow_types
self.dependencies = dependencies or []
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
@ -45,18 +47,16 @@ class Handler:
raise ValueError(f"{depends} has no dependency")
if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=depends)
sub_dependant = get_parameterless_sub_dependant(
depends=depends, allow_types=self.allow_types)
self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func)
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str:
return (
f"<Handler {self.func}("
f"[bot {self.dependent.bot_param_name}]: {self.dependent.bot_param_type}, "
f"[event {self.dependent.event_param_name}]: {self.dependent.event_param_type}, "
f"[state {self.dependent.state_param_name}], "
f"[matcher {self.dependent.matcher_param_name}])>")
f"<Handler {self.func}({', '.join(map(str, self.dependent.params))})>"
)
def __str__(self) -> str:
return repr(self)
@ -88,19 +88,6 @@ class Handler:
if ignored:
return
# check bot and event type
if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type):
logger.debug(f"Matcher {matcher} bot type {type(bot)} not match "
f"annotation {self.dependent.bot_param_type}, ignored")
return
elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match "
f"annotation {self.dependent.event_param_type}, ignored")
return
if asyncio.iscoroutinefunction(self.func):
await self.func(**values)
else:
@ -111,7 +98,8 @@ class Handler:
raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
sub_dependant = get_parameterless_sub_dependant(
depends=dependency, allow_types=self.allow_types)
self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends):

View File

@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
from .models import Depends
from .handler import Handler
from nonebot.rule import Rule
from .params import ParamTypes
from nonebot import get_driver
from nonebot.log import logger
from nonebot.permission import USER, Permission
@ -153,6 +154,10 @@ class Matcher(metaclass=MatcherMeta):
:说明: 事件响应器权限更新函数
"""
HANDLER_PARAM_TYPES = [
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER
]
def __init__(self):
"""实例化 Matcher 以便运行"""
self.handlers = self.handlers.copy()
@ -230,7 +235,9 @@ class Matcher(metaclass=MatcherMeta):
permission or Permission(),
"handlers": [
handler if isinstance(handler, Handler) else Handler(
handler, dependency_overrides_provider=get_driver())
handler,
dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers
] if handlers else [],
"temp":
@ -348,7 +355,8 @@ class Matcher(metaclass=MatcherMeta):
dependencies: Optional[List[Depends]] = None) -> Handler:
handler_ = Handler(handler,
dependencies=dependencies,
dependency_overrides_provider=get_driver())
dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES)
cls.handlers.append(handler_)
return handler_

View File

@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
from typing import Any, List, Callable, Optional
from pydantic.fields import ModelField
from nonebot.utils import get_name
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Depends:
@ -27,22 +26,12 @@ class Dependent:
*,
func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None,
bot_param_name: Optional[str] = None,
bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
event_param_name: Optional[str] = None,
event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
state_param_name: Optional[str] = None,
matcher_param_name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None:
self.func = func
self.name = name
self.bot_param_name = bot_param_name
self.bot_param_type = bot_param_type
self.event_param_name = event_param_name
self.event_param_type = event_param_type
self.state_param_name = state_param_name
self.matcher_param_name = matcher_param_name
self.params = params or []
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = self.func

View File

@ -0,0 +1,91 @@
import abc
import inspect
from enum import Enum
from typing import Any, Dict, Optional
from pydantic.fields import FieldInfo
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from .utils import generic_check_issubclass
class Param(FieldInfo, abc.ABC):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
raise NotImplementedError
@abc.abstractmethod
def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
class BotParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Bot)
def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot
class EventParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Event)
def _solve(self, event: Event, **kwargs: Any) -> Any:
return event
class StateParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Dict)
def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
class MatcherParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Matcher)
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
return matcher
class ExceptionParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Exception)
def _solve(self,
exception: Optional[Exception] = None,
**kwargs: Any) -> Any:
return exception
class ParamTypes(Enum):
BOT = BotParam
EVENT = EventParam
STATE = StateParam
MATCHER = MatcherParam
EXCEPTION = ExceptionParam
from .matcher import Matcher