add bot event type check

This commit is contained in:
yanyongyu 2021-11-14 01:34:25 +08:00
parent 9d708a6723
commit 7495fee2a2
3 changed files with 43 additions and 22 deletions

View File

@ -98,30 +98,38 @@ async def solve_dependencies(
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any]]:
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {}
# solve sub dependencies
sub_dependant: Dependent
for sub_dependant in chain(sub_dependents or tuple(),
sub_dependent: Dependent
for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies):
sub_dependant.func = cast(Callable[..., Any], sub_dependant.func)
sub_dependant.cache_key = cast(Tuple[Callable[..., Any]],
sub_dependant.cache_key)
func = sub_dependant.func
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
sub_dependent.cache_key)
func = sub_dependent.func
# check bot and event type
if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type):
return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type):
return values, dependency_cache, True
# dependency overrides
use_sub_dependant = sub_dependant
use_sub_dependant = sub_dependent
if (dependency_overrides_provider and
hasattr(dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependant.func
original_call = sub_dependent.func
func = getattr(dependency_overrides_provider,
"dependency_overrides",
{}).get(original_call, original_call)
use_sub_dependant = get_dependent(
func=func,
name=sub_dependant.name,
name=sub_dependent.name,
)
# solve sub dependency with current cache
@ -134,24 +142,26 @@ async def solve_dependencies(
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
)
sub_values, sub_dependency_cache = solved_result
sub_values, sub_dependency_cache, ignored = solved_result
if ignored:
return values, dependency_cache, True
# update cache?
dependency_cache.update(sub_dependency_cache)
# run dependency function
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:
solved = await run_sync(func)(**sub_values)
# parameter dependency
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependent.name is not None:
values[sub_dependent.name] = solved
# save current dependency to cache
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
if sub_dependent.cache_key not in dependency_cache:
dependency_cache[sub_dependent.cache_key] = solved
# usual dependency
if dependent.bot_param_name is not None:
@ -162,7 +172,7 @@ async def solve_dependencies(
values[dependent.state_param_name] = state
if dependent.matcher_param_name is not None:
values[dependent.matcher_param_name] = matcher
return values, dependency_cache
return values, dependency_cache, False
def Depends(dependency: Optional[Callable[..., Any]] = None,

View File

@ -50,7 +50,7 @@ class Handler:
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
state: T_State):
values, _ = await solve_dependencies(
values, _, ignored = await solve_dependencies(
dependent=self.dependent,
bot=bot,
event=event,
@ -62,6 +62,17 @@ class Handler:
],
dependency_overrides_provider=self.dependency_overrides_provider)
if ignored:
return
# check bot and event type
if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type):
return
elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type):
return
if asyncio.iscoroutinefunction(self.func):
await self.func(**values)
else:

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Tuple, Callable, Optional
from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
from nonebot.utils import get_name
@ -28,9 +28,9 @@ class Dependent:
func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None,
bot_param_name: Optional[str] = None,
bot_param_type: Optional[Tuple["Bot", ...]] = None,
bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
event_param_name: Optional[str] = None,
event_param_type: Optional[Tuple["Event", ...]] = None,
event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
state_param_name: Optional[str] = None,
matcher_param_name: Optional[str] = None,
dependencies: Optional[List["Dependent"]] = None,