add bot event type check

This commit is contained in:
yanyongyu 2021-11-14 01:34:25 +08:00
parent 9d708a6723
commit 7495fee2a2
3 changed files with 43 additions and 22 deletions

View File

@ -98,30 +98,38 @@ async def solve_dependencies(
sub_dependents: Optional[List[Dependent]] = None, sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any]]: ) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
# solve sub dependencies # solve sub dependencies
sub_dependant: Dependent sub_dependent: Dependent
for sub_dependant in chain(sub_dependents or tuple(), for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies): dependent.dependencies):
sub_dependant.func = cast(Callable[..., Any], sub_dependant.func) sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependant.cache_key = cast(Tuple[Callable[..., Any]], sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
sub_dependant.cache_key) sub_dependent.cache_key)
func = sub_dependant.func func = sub_dependent.func
# check bot and event type
if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type):
return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type):
return values, dependency_cache, True
# dependency overrides # dependency overrides
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependent
if (dependency_overrides_provider and if (dependency_overrides_provider and
hasattr(dependency_overrides_provider, "dependency_overrides")): hasattr(dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependant.func original_call = sub_dependent.func
func = getattr(dependency_overrides_provider, func = getattr(dependency_overrides_provider,
"dependency_overrides", "dependency_overrides",
{}).get(original_call, original_call) {}).get(original_call, original_call)
use_sub_dependant = get_dependent( use_sub_dependant = get_dependent(
func=func, func=func,
name=sub_dependant.name, name=sub_dependent.name,
) )
# solve sub dependency with current cache # solve sub dependency with current cache
@ -134,24 +142,26 @@ async def solve_dependencies(
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache, dependency_cache=dependency_cache,
) )
sub_values, sub_dependency_cache = solved_result sub_values, sub_dependency_cache, ignored = solved_result
if ignored:
return values, dependency_cache, True
# update cache? # update cache?
dependency_cache.update(sub_dependency_cache) dependency_cache.update(sub_dependency_cache)
# run dependency function # run dependency function
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key] solved = dependency_cache[sub_dependent.cache_key]
elif is_coroutine_callable(func): elif is_coroutine_callable(func):
solved = await func(**sub_values) solved = await func(**sub_values)
else: else:
solved = await run_sync(func)(**sub_values) solved = await run_sync(func)(**sub_values)
# parameter dependency # parameter dependency
if sub_dependant.name is not None: if sub_dependent.name is not None:
values[sub_dependant.name] = solved values[sub_dependent.name] = solved
# save current dependency to cache # save current dependency to cache
if sub_dependant.cache_key not in dependency_cache: if sub_dependent.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved dependency_cache[sub_dependent.cache_key] = solved
# usual dependency # usual dependency
if dependent.bot_param_name is not None: if dependent.bot_param_name is not None:
@ -162,7 +172,7 @@ async def solve_dependencies(
values[dependent.state_param_name] = state values[dependent.state_param_name] = state
if dependent.matcher_param_name is not None: if dependent.matcher_param_name is not None:
values[dependent.matcher_param_name] = matcher values[dependent.matcher_param_name] = matcher
return values, dependency_cache return values, dependency_cache, False
def Depends(dependency: Optional[Callable[..., Any]] = None, def Depends(dependency: Optional[Callable[..., Any]] = None,

View File

@ -50,7 +50,7 @@ class Handler:
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event, async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
state: T_State): state: T_State):
values, _ = await solve_dependencies( values, _, ignored = await solve_dependencies(
dependent=self.dependent, dependent=self.dependent,
bot=bot, bot=bot,
event=event, event=event,
@ -62,6 +62,17 @@ class Handler:
], ],
dependency_overrides_provider=self.dependency_overrides_provider) dependency_overrides_provider=self.dependency_overrides_provider)
if ignored:
return
# check bot and event type
if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type):
return
elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type):
return
if asyncio.iscoroutinefunction(self.func): if asyncio.iscoroutinefunction(self.func):
await self.func(**values) await self.func(**values)
else: else:

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Tuple, Callable, Optional from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
from nonebot.utils import get_name from nonebot.utils import get_name
@ -28,9 +28,9 @@ class Dependent:
func: Optional[Callable[..., Any]] = None, func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
bot_param_name: Optional[str] = None, bot_param_name: Optional[str] = None,
bot_param_type: Optional[Tuple["Bot", ...]] = None, bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
event_param_name: Optional[str] = None, event_param_name: Optional[str] = None,
event_param_type: Optional[Tuple["Event", ...]] = None, event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
state_param_name: Optional[str] = None, state_param_name: Optional[str] = None,
matcher_param_name: Optional[str] = None, matcher_param_name: Optional[str] = None,
dependencies: Optional[List["Dependent"]] = None, dependencies: Optional[List["Dependent"]] = None,