♻️ allow dynamic param types

This commit is contained in:
yanyongyu 2021-11-15 21:44:24 +08:00
parent cafe5c9af0
commit d1c6eeb6c2
7 changed files with 260 additions and 157 deletions

View File

@ -55,6 +55,16 @@ class Bot(abc.ABC):
def __getattr__(self, name: str) -> _ApiCall: def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name) return partial(self.call_api, name)
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not isinstance(v, cls):
raise TypeError(f"{v} is not an instance of {cls}")
return v
@property @property
@abc.abstractmethod @abc.abstractmethod
def type(self) -> str: def type(self) -> str:

View File

@ -7,7 +7,8 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Set, Type, Optional from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Set, Type
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
@ -204,58 +205,63 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.opt(colors=True).success(log_msg) logger.opt(colors=True).success(log_msg)
state = {} state = {}
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
if coros:
try:
if show_log:
logger.debug("Running PreProcessors...")
await asyncio.gather(*coros)
except IgnoredException as e:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>")
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>")
return
# Trie Match # TODO
_, _ = TrieRule.get_value(bot, event, state) async with AsyncExitStack() as stack:
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
break_flag = False if coros:
for priority in sorted(matchers.keys()): try:
if break_flag: if show_log:
break logger.debug("Running PreProcessors...")
await asyncio.gather(*coros)
if show_log: except IgnoredException as e:
logger.debug(f"Checking for matchers in priority {priority}...") logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy())
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
) )
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>")
return
# Trie Match
_, _ = TrieRule.get_value(bot, event, state)
break_flag = False
for priority in sorted(matchers.keys()):
if break_flag:
break
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
if coros:
try:
if show_log: if show_log:
logger.debug("Running PostProcessors...") logger.debug(f"Checking for matchers in priority {priority}...")
await asyncio.gather(*coros)
except Exception as e: pending_tasks = [
logger.opt(colors=True, exception=e).error( _check_matcher(priority, matcher, bot, event, state.copy())
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>" for matcher in matchers[priority]
) ]
results = await asyncio.gather(*pending_tasks,
return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
)
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
if coros:
try:
if show_log:
logger.debug("Running PostProcessors...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)

View File

@ -1,15 +1,18 @@
import inspect import inspect
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info
from .models import Dependent from .models import Dependent
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import T_State from nonebot.typing import T_State
from .utils import get_typed_signature
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass from .models import Depends as DependsClass
from .utils import (generic_get_types, get_typed_signature,
generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable) is_async_gen_callable, is_coroutine_callable)
@ -27,33 +30,42 @@ def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
) )
def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent: def get_parameterless_sub_dependant(
*,
depends: DependsClass,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
assert callable( assert callable(
depends.dependency depends.dependency
), "A parameter-less dependency must have a callable dependency" ), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency) return get_sub_dependant(depends=depends,
dependency=depends.dependency,
allow_types=allow_types)
def get_sub_dependant( def get_sub_dependant(
*, *,
depends: DependsClass, depends: DependsClass,
dependency: Callable[..., Any], dependency: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
) -> Dependent: allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
sub_dependant = get_dependent( sub_dependant = get_dependent(func=dependency,
func=dependency, name=name,
name=name, use_cache=depends.use_cache,
use_cache=depends.use_cache, allow_types=allow_types)
)
return sub_dependant return sub_dependant
def get_dependent(*, def get_dependent(
func: Callable[..., Any], *,
name: Optional[str] = None, func: Callable[..., Any],
use_cache: bool = True) -> Dependent: name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent:
signature = get_typed_signature(func) signature = get_typed_signature(func)
params = signature.parameters params = signature.parameters
allow_types = allow_types or [
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE
]
dependent = Dependent(func=func, name=name, use_cache=use_cache) dependent = Dependent(func=func, name=name, use_cache=use_cache)
for param_name, param in params.items(): for param_name, param in params.items():
if isinstance(param.default, DependsClass): if isinstance(param.default, DependsClass):
@ -61,33 +73,29 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent) dependent.dependencies.append(sub_dependent)
continue continue
if generic_check_issubclass(param.annotation, Bot): for allow_type in allow_types:
if dependent.bot_param_name is not None: field_info_class: Type[Param] = allow_type.value
raise ValueError(f"{func} has more than one Bot parameter: " if field_info_class._check(param_name, param):
f"{dependent.bot_param_name} / {param_name}") field_info = field_info_class(param.default)
dependent.bot_param_name = param_name break
dependent.bot_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, Event):
if dependent.event_param_name is not None:
raise ValueError(f"{func} has more than one Event parameter: "
f"{dependent.event_param_name} / {param_name}")
dependent.event_param_name = param_name
dependent.event_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, Dict):
if dependent.state_param_name is not None:
raise ValueError(f"{func} has more than one State parameter: "
f"{dependent.state_param_name} / {param_name}")
dependent.state_param_name = param_name
elif generic_check_issubclass(param.annotation, Matcher):
if dependent.matcher_param_name is not None:
raise ValueError(
f"{func} has more than one Matcher parameter: "
f"{dependent.matcher_param_name} / {param_name}")
dependent.matcher_param_name = param_name
else: else:
raise ValueError( raise ValueError(
f"Unknown parameter {param_name} with type {param.annotation}") f"Unknown parameter {param_name} with type {param.annotation}")
annotation: Any = Any
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
param_name)
dependent.params.append(
ModelField(name=param_name,
type_=annotation,
class_validators=None,
model_config=BaseConfig,
default=Required,
required=True,
field_info=field_info))
return dependent return dependent
@ -97,7 +105,8 @@ async def solve_dependencies(
bot: Bot, bot: Bot,
event: Event, event: Event,
state: T_State, state: T_State,
matcher: Optional["Matcher"], matcher: Optional["Matcher"] = None,
exception: Optional[Exception] = None,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None, sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
@ -115,20 +124,6 @@ async def solve_dependencies(
sub_dependent.cache_key) sub_dependent.cache_key)
func = sub_dependent.func func = sub_dependent.func
# check bot and event type
if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type):
logger.debug(
f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
f"annotation {sub_dependent.bot_param_type}, ignored")
return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match depends {func} "
f"annotation {sub_dependent.event_param_type}, ignored")
return values, dependency_cache, True
# dependency overrides # dependency overrides
use_sub_dependant = sub_dependent use_sub_dependant = sub_dependent
if (dependency_overrides_provider and if (dependency_overrides_provider and
@ -183,14 +178,28 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved dependency_cache[sub_dependent.cache_key] = solved
# usual dependency # usual dependency
if dependent.bot_param_name is not None: for field in dependent.params:
values[dependent.bot_param_name] = bot field_info = field.field_info
if dependent.event_param_name is not None: assert isinstance(field_info,
values[dependent.event_param_name] = event Param), "Params must be subclasses of Param"
if dependent.state_param_name is not None: value = field_info._solve(bot=bot,
values[dependent.state_param_name] = state event=event,
if dependent.matcher_param_name is not None: state=state,
values[dependent.matcher_param_name] = matcher matcher=matcher,
exception=exception)
_, errs_ = field.validate(value,
values,
loc=(ParamTypes(type(field_info)).name,
field.alias))
if errs_:
logger.debug(
f"Matcher {matcher} {ParamTypes(type(field_info)).name} "
f"type {type(value)} not match depends {dependent.func} "
f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True
else:
values[field.name] = value
return values, dependency_cache, False return values, dependency_cache, False
@ -200,6 +209,8 @@ def Depends(dependency: Optional[Callable[..., Any]] = None,
return DependsClass(dependency=dependency, use_cache=use_cache) return DependsClass(dependency=dependency, use_cache=use_cache)
from .params import Param
from .handler import Handler as Handler from .handler import Handler as Handler
from .matcher import Matcher as Matcher from .matcher import Matcher as Matcher
from .matcher import matchers as matchers from .matcher import matchers as matchers
from .params import ParamTypes as ParamTypes

View File

@ -9,7 +9,6 @@ import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
from .models import Depends, Dependent from .models import Depends, Dependent
from nonebot.utils import get_name, run_sync from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler from nonebot.typing import T_State, T_Handler
@ -17,6 +16,7 @@ from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant
if TYPE_CHECKING: if TYPE_CHECKING:
from .matcher import Matcher from .matcher import Matcher
from .params import ParamTypes
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
@ -28,6 +28,7 @@ class Handler:
*, *,
name: Optional[str] = None, name: Optional[str] = None,
dependencies: Optional[List[Depends]] = None, dependencies: Optional[List[Depends]] = None,
allow_types: Optional[List["ParamTypes"]] = None,
dependency_overrides_provider: Optional[Any] = None): dependency_overrides_provider: Optional[Any] = None):
"""装饰事件处理函数以便根据动态参数运行""" """装饰事件处理函数以便根据动态参数运行"""
self.func: T_Handler = func self.func: T_Handler = func
@ -36,6 +37,7 @@ class Handler:
:说明: 事件处理函数 :说明: 事件处理函数
""" """
self.name = get_name(func) if name is None else name self.name = get_name(func) if name is None else name
self.allow_types = allow_types
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
@ -45,18 +47,16 @@ class Handler:
raise ValueError(f"{depends} has no dependency") raise ValueError(f"{depends} has no dependency")
if depends.dependency in self.sub_dependents: if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies") raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=depends) sub_dependant = get_parameterless_sub_dependant(
depends=depends, allow_types=self.allow_types)
self.sub_dependents[depends.dependency] = sub_dependant self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func) self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Handler {self.func}(" f"<Handler {self.func}({', '.join(map(str, self.dependent.params))})>"
f"[bot {self.dependent.bot_param_name}]: {self.dependent.bot_param_type}, " )
f"[event {self.dependent.event_param_name}]: {self.dependent.event_param_type}, "
f"[state {self.dependent.state_param_name}], "
f"[matcher {self.dependent.matcher_param_name}])>")
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
@ -88,19 +88,6 @@ class Handler:
if ignored: if ignored:
return return
# check bot and event type
if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type):
logger.debug(f"Matcher {matcher} bot type {type(bot)} not match "
f"annotation {self.dependent.bot_param_type}, ignored")
return
elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match "
f"annotation {self.dependent.event_param_type}, ignored")
return
if asyncio.iscoroutinefunction(self.func): if asyncio.iscoroutinefunction(self.func):
await self.func(**values) await self.func(**values)
else: else:
@ -111,7 +98,8 @@ class Handler:
raise ValueError(f"{dependency} has no dependency") raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents: if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies") raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=dependency) sub_dependant = get_parameterless_sub_dependant(
depends=dependency, allow_types=self.allow_types)
self.sub_dependents[dependency.dependency] = sub_dependant self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends): def prepend_dependency(self, dependency: Depends):

View File

@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
from .models import Depends from .models import Depends
from .handler import Handler from .handler import Handler
from nonebot.rule import Rule from nonebot.rule import Rule
from .params import ParamTypes
from nonebot import get_driver from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
@ -153,6 +154,10 @@ class Matcher(metaclass=MatcherMeta):
:说明: 事件响应器权限更新函数 :说明: 事件响应器权限更新函数
""" """
HANDLER_PARAM_TYPES = [
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER
]
def __init__(self): def __init__(self):
"""实例化 Matcher 以便运行""" """实例化 Matcher 以便运行"""
self.handlers = self.handlers.copy() self.handlers = self.handlers.copy()
@ -230,7 +235,9 @@ class Matcher(metaclass=MatcherMeta):
permission or Permission(), permission or Permission(),
"handlers": [ "handlers": [
handler if isinstance(handler, Handler) else Handler( handler if isinstance(handler, Handler) else Handler(
handler, dependency_overrides_provider=get_driver()) handler,
dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers for handler in handlers
] if handlers else [], ] if handlers else [],
"temp": "temp":
@ -348,7 +355,8 @@ class Matcher(metaclass=MatcherMeta):
dependencies: Optional[List[Depends]] = None) -> Handler: dependencies: Optional[List[Depends]] = None) -> Handler:
handler_ = Handler(handler, handler_ = Handler(handler,
dependencies=dependencies, dependencies=dependencies,
dependency_overrides_provider=get_driver()) dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES)
cls.handlers.append(handler_) cls.handlers.append(handler_)
return handler_ return handler_

View File

@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional from typing import Any, List, Callable, Optional
from pydantic.fields import ModelField
from nonebot.utils import get_name from nonebot.utils import get_name
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Depends: class Depends:
@ -27,22 +26,12 @@ class Dependent:
*, *,
func: Optional[Callable[..., Any]] = None, func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
bot_param_name: Optional[str] = None, params: Optional[List[ModelField]] = None,
bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
event_param_name: Optional[str] = None,
event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
state_param_name: Optional[str] = None,
matcher_param_name: Optional[str] = None,
dependencies: Optional[List["Dependent"]] = None, dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None: use_cache: bool = True) -> None:
self.func = func self.func = func
self.name = name self.name = name
self.bot_param_name = bot_param_name self.params = params or []
self.bot_param_type = bot_param_type
self.event_param_name = event_param_name
self.event_param_type = event_param_type
self.state_param_name = state_param_name
self.matcher_param_name = matcher_param_name
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.use_cache = use_cache self.use_cache = use_cache
self.cache_key = self.func self.cache_key = self.func

View File

@ -0,0 +1,91 @@
import abc
import inspect
from enum import Enum
from typing import Any, Dict, Optional
from pydantic.fields import FieldInfo
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from .utils import generic_check_issubclass
class Param(FieldInfo, abc.ABC):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
raise NotImplementedError
@abc.abstractmethod
def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
class BotParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Bot)
def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot
class EventParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Event)
def _solve(self, event: Event, **kwargs: Any) -> Any:
return event
class StateParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Dict)
def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
class MatcherParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Matcher)
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
return matcher
class ExceptionParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Exception)
def _solve(self,
exception: Optional[Exception] = None,
**kwargs: Any) -> Any:
return exception
class ParamTypes(Enum):
BOT = BotParam
EVENT = EventParam
STATE = StateParam
MATCHER = MatcherParam
EXCEPTION = ExceptionParam
from .matcher import Matcher