mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 09:38:21 +08:00
✨ add bot event type check
This commit is contained in:
parent
9d708a6723
commit
7495fee2a2
@ -98,30 +98,38 @@ async def solve_dependencies(
|
||||
sub_dependents: Optional[List[Dependent]] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any]]:
|
||||
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = dependency_cache or {}
|
||||
|
||||
# solve sub dependencies
|
||||
sub_dependant: Dependent
|
||||
for sub_dependant in chain(sub_dependents or tuple(),
|
||||
sub_dependent: Dependent
|
||||
for sub_dependent in chain(sub_dependents or tuple(),
|
||||
dependent.dependencies):
|
||||
sub_dependant.func = cast(Callable[..., Any], sub_dependant.func)
|
||||
sub_dependant.cache_key = cast(Tuple[Callable[..., Any]],
|
||||
sub_dependant.cache_key)
|
||||
func = sub_dependant.func
|
||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
|
||||
sub_dependent.cache_key)
|
||||
func = sub_dependent.func
|
||||
|
||||
# check bot and event type
|
||||
if sub_dependent.bot_param_type and not isinstance(
|
||||
bot, sub_dependent.bot_param_type):
|
||||
return values, dependency_cache, True
|
||||
elif sub_dependent.event_param_type and not isinstance(
|
||||
event, sub_dependent.event_param_type):
|
||||
return values, dependency_cache, True
|
||||
|
||||
# dependency overrides
|
||||
use_sub_dependant = sub_dependant
|
||||
use_sub_dependant = sub_dependent
|
||||
if (dependency_overrides_provider and
|
||||
hasattr(dependency_overrides_provider, "dependency_overrides")):
|
||||
original_call = sub_dependant.func
|
||||
original_call = sub_dependent.func
|
||||
func = getattr(dependency_overrides_provider,
|
||||
"dependency_overrides",
|
||||
{}).get(original_call, original_call)
|
||||
use_sub_dependant = get_dependent(
|
||||
func=func,
|
||||
name=sub_dependant.name,
|
||||
name=sub_dependent.name,
|
||||
)
|
||||
|
||||
# solve sub dependency with current cache
|
||||
@ -134,24 +142,26 @@ async def solve_dependencies(
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
sub_values, sub_dependency_cache = solved_result
|
||||
sub_values, sub_dependency_cache, ignored = solved_result
|
||||
if ignored:
|
||||
return values, dependency_cache, True
|
||||
# update cache?
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
|
||||
# run dependency function
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
solved = await run_sync(func)(**sub_values)
|
||||
|
||||
# parameter dependency
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependent.name is not None:
|
||||
values[sub_dependent.name] = solved
|
||||
# save current dependency to cache
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependant.cache_key] = solved
|
||||
if sub_dependent.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
|
||||
# usual dependency
|
||||
if dependent.bot_param_name is not None:
|
||||
@ -162,7 +172,7 @@ async def solve_dependencies(
|
||||
values[dependent.state_param_name] = state
|
||||
if dependent.matcher_param_name is not None:
|
||||
values[dependent.matcher_param_name] = matcher
|
||||
return values, dependency_cache
|
||||
return values, dependency_cache, False
|
||||
|
||||
|
||||
def Depends(dependency: Optional[Callable[..., Any]] = None,
|
||||
|
@ -50,7 +50,7 @@ class Handler:
|
||||
|
||||
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
|
||||
state: T_State):
|
||||
values, _ = await solve_dependencies(
|
||||
values, _, ignored = await solve_dependencies(
|
||||
dependent=self.dependent,
|
||||
bot=bot,
|
||||
event=event,
|
||||
@ -62,6 +62,17 @@ class Handler:
|
||||
],
|
||||
dependency_overrides_provider=self.dependency_overrides_provider)
|
||||
|
||||
if ignored:
|
||||
return
|
||||
|
||||
# check bot and event type
|
||||
if self.dependent.bot_param_type and not isinstance(
|
||||
bot, self.dependent.bot_param_type):
|
||||
return
|
||||
elif self.dependent.event_param_type and not isinstance(
|
||||
event, self.dependent.event_param_type):
|
||||
return
|
||||
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
await self.func(**values)
|
||||
else:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Tuple, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
|
||||
|
||||
from nonebot.utils import get_name
|
||||
|
||||
@ -28,9 +28,9 @@ class Dependent:
|
||||
func: Optional[Callable[..., Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
bot_param_name: Optional[str] = None,
|
||||
bot_param_type: Optional[Tuple["Bot", ...]] = None,
|
||||
bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
|
||||
event_param_name: Optional[str] = None,
|
||||
event_param_type: Optional[Tuple["Event", ...]] = None,
|
||||
event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
|
||||
state_param_name: Optional[str] = None,
|
||||
matcher_param_name: Optional[str] = None,
|
||||
dependencies: Optional[List["Dependent"]] = None,
|
||||
|
Loading…
Reference in New Issue
Block a user