mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 09:38:21 +08:00
✨ add bot event type check
This commit is contained in:
parent
9d708a6723
commit
7495fee2a2
@ -98,30 +98,38 @@ async def solve_dependencies(
|
|||||||
sub_dependents: Optional[List[Dependent]] = None,
|
sub_dependents: Optional[List[Dependent]] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
|
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
|
||||||
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any]]:
|
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
dependency_cache = dependency_cache or {}
|
dependency_cache = dependency_cache or {}
|
||||||
|
|
||||||
# solve sub dependencies
|
# solve sub dependencies
|
||||||
sub_dependant: Dependent
|
sub_dependent: Dependent
|
||||||
for sub_dependant in chain(sub_dependents or tuple(),
|
for sub_dependent in chain(sub_dependents or tuple(),
|
||||||
dependent.dependencies):
|
dependent.dependencies):
|
||||||
sub_dependant.func = cast(Callable[..., Any], sub_dependant.func)
|
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||||
sub_dependant.cache_key = cast(Tuple[Callable[..., Any]],
|
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
|
||||||
sub_dependant.cache_key)
|
sub_dependent.cache_key)
|
||||||
func = sub_dependant.func
|
func = sub_dependent.func
|
||||||
|
|
||||||
|
# check bot and event type
|
||||||
|
if sub_dependent.bot_param_type and not isinstance(
|
||||||
|
bot, sub_dependent.bot_param_type):
|
||||||
|
return values, dependency_cache, True
|
||||||
|
elif sub_dependent.event_param_type and not isinstance(
|
||||||
|
event, sub_dependent.event_param_type):
|
||||||
|
return values, dependency_cache, True
|
||||||
|
|
||||||
# dependency overrides
|
# dependency overrides
|
||||||
use_sub_dependant = sub_dependant
|
use_sub_dependant = sub_dependent
|
||||||
if (dependency_overrides_provider and
|
if (dependency_overrides_provider and
|
||||||
hasattr(dependency_overrides_provider, "dependency_overrides")):
|
hasattr(dependency_overrides_provider, "dependency_overrides")):
|
||||||
original_call = sub_dependant.func
|
original_call = sub_dependent.func
|
||||||
func = getattr(dependency_overrides_provider,
|
func = getattr(dependency_overrides_provider,
|
||||||
"dependency_overrides",
|
"dependency_overrides",
|
||||||
{}).get(original_call, original_call)
|
{}).get(original_call, original_call)
|
||||||
use_sub_dependant = get_dependent(
|
use_sub_dependant = get_dependent(
|
||||||
func=func,
|
func=func,
|
||||||
name=sub_dependant.name,
|
name=sub_dependent.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# solve sub dependency with current cache
|
# solve sub dependency with current cache
|
||||||
@ -134,24 +142,26 @@ async def solve_dependencies(
|
|||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
)
|
)
|
||||||
sub_values, sub_dependency_cache = solved_result
|
sub_values, sub_dependency_cache, ignored = solved_result
|
||||||
|
if ignored:
|
||||||
|
return values, dependency_cache, True
|
||||||
# update cache?
|
# update cache?
|
||||||
dependency_cache.update(sub_dependency_cache)
|
dependency_cache.update(sub_dependency_cache)
|
||||||
|
|
||||||
# run dependency function
|
# run dependency function
|
||||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependant.cache_key]
|
solved = dependency_cache[sub_dependent.cache_key]
|
||||||
elif is_coroutine_callable(func):
|
elif is_coroutine_callable(func):
|
||||||
solved = await func(**sub_values)
|
solved = await func(**sub_values)
|
||||||
else:
|
else:
|
||||||
solved = await run_sync(func)(**sub_values)
|
solved = await run_sync(func)(**sub_values)
|
||||||
|
|
||||||
# parameter dependency
|
# parameter dependency
|
||||||
if sub_dependant.name is not None:
|
if sub_dependent.name is not None:
|
||||||
values[sub_dependant.name] = solved
|
values[sub_dependent.name] = solved
|
||||||
# save current dependency to cache
|
# save current dependency to cache
|
||||||
if sub_dependant.cache_key not in dependency_cache:
|
if sub_dependent.cache_key not in dependency_cache:
|
||||||
dependency_cache[sub_dependant.cache_key] = solved
|
dependency_cache[sub_dependent.cache_key] = solved
|
||||||
|
|
||||||
# usual dependency
|
# usual dependency
|
||||||
if dependent.bot_param_name is not None:
|
if dependent.bot_param_name is not None:
|
||||||
@ -162,7 +172,7 @@ async def solve_dependencies(
|
|||||||
values[dependent.state_param_name] = state
|
values[dependent.state_param_name] = state
|
||||||
if dependent.matcher_param_name is not None:
|
if dependent.matcher_param_name is not None:
|
||||||
values[dependent.matcher_param_name] = matcher
|
values[dependent.matcher_param_name] = matcher
|
||||||
return values, dependency_cache
|
return values, dependency_cache, False
|
||||||
|
|
||||||
|
|
||||||
def Depends(dependency: Optional[Callable[..., Any]] = None,
|
def Depends(dependency: Optional[Callable[..., Any]] = None,
|
||||||
|
@ -50,7 +50,7 @@ class Handler:
|
|||||||
|
|
||||||
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
|
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
|
||||||
state: T_State):
|
state: T_State):
|
||||||
values, _ = await solve_dependencies(
|
values, _, ignored = await solve_dependencies(
|
||||||
dependent=self.dependent,
|
dependent=self.dependent,
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
@ -62,6 +62,17 @@ class Handler:
|
|||||||
],
|
],
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider)
|
dependency_overrides_provider=self.dependency_overrides_provider)
|
||||||
|
|
||||||
|
if ignored:
|
||||||
|
return
|
||||||
|
|
||||||
|
# check bot and event type
|
||||||
|
if self.dependent.bot_param_type and not isinstance(
|
||||||
|
bot, self.dependent.bot_param_type):
|
||||||
|
return
|
||||||
|
elif self.dependent.event_param_type and not isinstance(
|
||||||
|
event, self.dependent.event_param_type):
|
||||||
|
return
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(self.func):
|
if asyncio.iscoroutinefunction(self.func):
|
||||||
await self.func(**values)
|
await self.func(**values)
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Any, List, Tuple, Callable, Optional
|
from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional
|
||||||
|
|
||||||
from nonebot.utils import get_name
|
from nonebot.utils import get_name
|
||||||
|
|
||||||
@ -28,9 +28,9 @@ class Dependent:
|
|||||||
func: Optional[Callable[..., Any]] = None,
|
func: Optional[Callable[..., Any]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
bot_param_name: Optional[str] = None,
|
bot_param_name: Optional[str] = None,
|
||||||
bot_param_type: Optional[Tuple["Bot", ...]] = None,
|
bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None,
|
||||||
event_param_name: Optional[str] = None,
|
event_param_name: Optional[str] = None,
|
||||||
event_param_type: Optional[Tuple["Event", ...]] = None,
|
event_param_type: Optional[Tuple[Type["Event"], ...]] = None,
|
||||||
state_param_name: Optional[str] = None,
|
state_param_name: Optional[str] = None,
|
||||||
matcher_param_name: Optional[str] = None,
|
matcher_param_name: Optional[str] = None,
|
||||||
dependencies: Optional[List["Dependent"]] = None,
|
dependencies: Optional[List["Dependent"]] = None,
|
||||||
|
Loading…
Reference in New Issue
Block a user