🎨 format code using black and isort

This commit is contained in:
yanyongyu 2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -1,2 +0,0 @@
[style]
based_on_style = google

View File

@ -40,8 +40,7 @@ from nonebot.log import logger, default_filter
from nonebot.drivers import Driver, ReverseDriver from nonebot.drivers import Driver, ReverseDriver
try: try:
_dist: pkg_resources.Distribution = pkg_resources.get_distribution( _dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
"nonebot2")
__version__ = _dist.version __version__ = _dist.version
VERSION = _dist.parsed_version VERSION = _dist.parsed_version
except pkg_resources.DistributionNotFound: except pkg_resources.DistributionNotFound:
@ -100,8 +99,8 @@ def get_app() -> Any:
""" """
driver = get_driver() driver = get_driver()
assert isinstance( assert isinstance(
driver, driver, ReverseDriver
ReverseDriver), "app object is only available for reverse driver" ), "app object is only available for reverse driver"
return driver.server_app return driver.server_app
@ -128,8 +127,8 @@ def get_asgi() -> Any:
""" """
driver = get_driver() driver = get_driver()
assert isinstance( assert isinstance(
driver, driver, ReverseDriver
ReverseDriver), "asgi object is only available for reverse driver" ), "asgi object is only available for reverse driver"
return driver.asgi return driver.asgi
@ -226,17 +225,23 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
if not _driver: if not _driver:
logger.success("NoneBot is initializing...") logger.success("NoneBot is initializing...")
env = Env() env = Env()
config = Config(**kwargs, config = Config(
**kwargs,
_common_config=env.dict(), _common_config=env.dict(),
_env_file=_env_file or f".env.{env.environment}") _env_file=_env_file or f".env.{env.environment}",
)
default_filter.level = ( default_filter.level = (
"DEBUG" if config.debug else ("DEBUG" if config.debug else "INFO")
"INFO") if config.log_level is None else config.log_level if config.log_level is None
else config.log_level
)
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Current <y><b>Env: {escape_tag(env.environment)}</b></y>") f"Current <y><b>Env: {escape_tag(env.environment)}</b></y>"
)
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}") f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
)
modulename, _, cls = config.driver.partition(":") modulename, _, cls = config.driver.partition(":")
module = importlib.import_module(modulename) module = importlib.import_module(modulename)
@ -247,10 +252,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
_driver = DriverClass(env, config) _driver = DriverClass(env, config)
def run(host: Optional[str] = None, def run(host: Optional[str] = None, port: Optional[int] = None, *args, **kwargs):
port: Optional[int] = None,
*args,
**kwargs):
""" """
:说明: :说明:

View File

@ -9,13 +9,13 @@ from typing import Iterable
try: try:
import pkg_resources import pkg_resources
pkg_resources.declare_namespace(__name__) pkg_resources.declare_namespace(__name__)
del pkg_resources del pkg_resources
except ImportError: except ImportError:
import pkgutil import pkgutil
__path__: Iterable[str] = pkgutil.extend_path(
__path__, # type: ignore __path__: Iterable[str] = pkgutil.extend_path(__path__, __name__) # type: ignore
__name__)
del pkgutil del pkgutil
except Exception: except Exception:
pass pass

View File

@ -15,7 +15,6 @@ if TYPE_CHECKING:
class _ApiCall(Protocol): class _ApiCall(Protocol):
async def __call__(self, **kwargs: Any) -> Any: async def __call__(self, **kwargs: Any) -> Any:
... ...
@ -146,7 +145,8 @@ class Bot(abc.ABC):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. " "<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>"
)
exception = None exception = None
result = None result = None
@ -157,8 +157,8 @@ class Bot(abc.ABC):
exception = e exception = e
coros = list( coros = list(
map(lambda x: x(self, exception, api, data, result), map(lambda x: x(self, exception, api, data, result), self._called_api_hook)
self._called_api_hook)) )
if coros: if coros:
try: try:
logger.debug("Running CalledAPI hooks...") logger.debug("Running CalledAPI hooks...")
@ -166,16 +166,17 @@ class Bot(abc.ABC):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. " "<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>"
)
if exception: if exception:
raise exception raise exception
return result return result
@abc.abstractmethod @abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message", async def send(
"MessageSegment"], self, event: "Event", message: Union[str, "Message", "MessageSegment"], **kwargs
**kwargs) -> Any: ) -> Any:
""" """
:说明: :说明:

View File

@ -2,9 +2,8 @@ import abc
from pydantic import BaseModel from pydantic import BaseModel
from nonebot.utils import DataclassEncoder
from ._message import Message from ._message import Message
from nonebot.utils import DataclassEncoder
class Event(abc.ABC, BaseModel): class Event(abc.ABC, BaseModel):

View File

@ -1,8 +1,17 @@
import abc import abc
from copy import deepcopy from copy import deepcopy
from dataclasses import field, asdict, dataclass from dataclasses import field, asdict, dataclass
from typing import (Any, Dict, List, Type, Union, Generic, Mapping, TypeVar, from typing import (
Iterable) Any,
Dict,
List,
Type,
Union,
Generic,
Mapping,
TypeVar,
Iterable,
)
from ._template import MessageTemplate from ._template import MessageTemplate
@ -14,6 +23,7 @@ TM = TypeVar("TM", bound="Message")
@dataclass @dataclass
class MessageSegment(Mapping, abc.ABC, Generic[TM]): class MessageSegment(Mapping, abc.ABC, Generic[TM]):
"""消息段基类""" """消息段基类"""
type: str type: str
""" """
- 类型: ``str`` - 类型: ``str``
@ -82,11 +92,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]):
class Message(List[TMS], abc.ABC): class Message(List[TMS], abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self: TM, def __init__(
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, self: TM,
Any] = None, message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, Any] = None,
*args, *args,
**kwargs): **kwargs,
):
""" """
:参数: :参数:
@ -103,8 +114,7 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message)) self.extend(self._construct(message))
@classmethod @classmethod
def template(cls: Type[TM], def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
format_string: Union[str, TM]) -> MessageTemplate[TM]:
""" """
:说明: :说明:
@ -156,8 +166,7 @@ class Message(List[TMS], abc.ABC):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct( def _construct(msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
raise NotImplementedError raise NotImplementedError
def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:

View File

@ -1,8 +1,21 @@
import inspect import inspect
import functools import functools
from string import Formatter from string import Formatter
from typing import (TYPE_CHECKING, Any, Set, List, Type, Tuple, Union, Generic, from typing import (
Mapping, TypeVar, Sequence, cast, overload) TYPE_CHECKING,
Any,
Set,
List,
Type,
Tuple,
Union,
Generic,
Mapping,
TypeVar,
Sequence,
cast,
overload,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Message, MessageSegment from . import Message, MessageSegment
@ -15,14 +28,15 @@ class MessageTemplate(Formatter, Generic[TF]):
"""消息模板格式化实现类""" """消息模板格式化实现类"""
@overload @overload
def __init__(self: "MessageTemplate[str]", def __init__(
template: str, self: "MessageTemplate[str]", template: str, factory: Type[str] = str
factory: Type[str] = str) -> None: ) -> None:
... ...
@overload @overload
def __init__(self: "MessageTemplate[TM]", template: Union[str, TM], def __init__(
factory: Type[TM]) -> None: self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
) -> None:
... ...
def __init__(self, template, factory=str) -> None: def __init__(self, template, factory=str) -> None:
@ -51,15 +65,15 @@ class MessageTemplate(Formatter, Generic[TF]):
elif isinstance(self.template, self.factory): elif isinstance(self.template, self.factory):
template = cast("Message[MessageSegment]", self.template) template = cast("Message[MessageSegment]", self.template)
for seg in template: for seg in template:
msg += self.vformat(str(seg), args, msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
kwargs) if seg.is_text() else seg
else: else:
raise TypeError('template must be a string or instance of Message!') raise TypeError("template must be a string or instance of Message!")
return msg return msg
def vformat(self, format_string: str, args: Sequence[Any], def vformat(
kwargs: Mapping[str, Any]) -> TF: self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> TF:
used_args = set() used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2) result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs) self.check_unused_args(list(used_args), args, kwargs)
@ -79,8 +93,9 @@ class MessageTemplate(Formatter, Generic[TF]):
results: List[Any] = [] results: List[Any] = []
for (literal_text, field_name, format_spec, for (literal_text, field_name, format_spec, conversion) in self.parse(
conversion) in self.parse(format_string): format_string
):
# output the literal text # output the literal text
if literal_text: if literal_text:
@ -96,14 +111,16 @@ class MessageTemplate(Formatter, Generic[TF]):
if auto_arg_index is False: if auto_arg_index is False:
raise ValueError( raise ValueError(
"cannot switch from manual field specification to " "cannot switch from manual field specification to "
"automatic field numbering") "automatic field numbering"
)
field_name = str(auto_arg_index) field_name = str(auto_arg_index)
auto_arg_index += 1 auto_arg_index += 1
elif field_name.isdigit(): elif field_name.isdigit():
if auto_arg_index: if auto_arg_index:
raise ValueError( raise ValueError(
"cannot switch from manual field specification to " "cannot switch from manual field specification to "
"automatic field numbering") "automatic field numbering"
)
# disable auto arg incrementing, if it gets # disable auto arg incrementing, if it gets
# used later on, then an exception will be raised # used later on, then an exception will be raised
auto_arg_index = False auto_arg_index = False
@ -132,8 +149,10 @@ class MessageTemplate(Formatter, Generic[TF]):
formatted_text = self.format_field(obj, str(format_control)) formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text) results.append(formatted_text)
return self.factory(functools.reduce(self._add, results or return (
[""])), auto_arg_index self.factory(functools.reduce(self._add, results or [""])),
auto_arg_index,
)
def format_field(self, value: Any, format_spec: str) -> Any: def format_field(self, value: Any, format_spec: str) -> Any:
if issubclass(self.factory, str): if issubclass(self.factory, str):
@ -142,11 +161,20 @@ class MessageTemplate(Formatter, Generic[TF]):
segment_class: Type[MessageSegment] = self.factory.get_segment_class() segment_class: Type[MessageSegment] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None) method = getattr(segment_class, format_spec, None)
method_type = inspect.getattr_static(segment_class, format_spec, None) method_type = inspect.getattr_static(segment_class, format_spec, None)
return (super().format_field(value, format_spec) if return (
((method is None) or (
(not isinstance(method_type, (classmethod, staticmethod)) super().format_field(value, format_spec)
if (
(method is None)
or (
not isinstance(method_type, (classmethod, staticmethod))
) # Only Call staticmethod or classmethod ) # Only Call staticmethod or classmethod
) else method(value)) if format_spec else value )
else method(value)
)
if format_spec
else value
)
def _add(self, a: Any, b: Any) -> Any: def _add(self, a: Any, b: Any) -> Any:
try: try:

View File

@ -20,13 +20,17 @@ from ipaddress import IPv4Address
from typing import Any, Set, Dict, Tuple, Union, Mapping, Optional from typing import Any, Set, Dict, Tuple, Union, Mapping, Optional
from pydantic import BaseSettings, IPvAnyAddress from pydantic import BaseSettings, IPvAnyAddress
from pydantic.env_settings import (SettingsError, EnvSettingsSource, from pydantic.env_settings import (
InitSettingsSource, SettingsSourceCallable, SettingsError,
read_env_file, env_file_sentinel) EnvSettingsSource,
InitSettingsSource,
SettingsSourceCallable,
read_env_file,
env_file_sentinel,
)
class CustomEnvSettings(EnvSettingsSource): class CustomEnvSettings(EnvSettingsSource):
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
""" """
Build environment variables suitable for passing to the Model. Build environment variables suitable for passing to the Model.
@ -39,15 +43,24 @@ class CustomEnvSettings(EnvSettingsSource):
env_vars = {k.lower(): v for k, v in os.environ.items()} env_vars = {k.lower(): v for k, v in os.environ.items()}
env_file_vars: Dict[str, Optional[str]] = {} env_file_vars: Dict[str, Optional[str]] = {}
env_file = self.env_file if self.env_file != env_file_sentinel else settings.__config__.env_file env_file = (
env_file_encoding = self.env_file_encoding if self.env_file_encoding is not None else settings.__config__.env_file_encoding self.env_file
if self.env_file != env_file_sentinel
else settings.__config__.env_file
)
env_file_encoding = (
self.env_file_encoding
if self.env_file_encoding is not None
else settings.__config__.env_file_encoding
)
if env_file is not None: if env_file is not None:
env_path = Path(env_file) env_path = Path(env_file)
if env_path.is_file(): if env_path.is_file():
env_file_vars = read_env_file( env_file_vars = read_env_file(
env_path, env_path,
encoding=env_file_encoding, encoding=env_file_encoding,
case_sensitive=settings.__config__.case_sensitive) case_sensitive=settings.__config__.case_sensitive,
)
env_vars = {**env_file_vars, **env_vars} env_vars = {**env_file_vars, **env_vars}
for field in settings.__fields__.values(): for field in settings.__fields__.values():
@ -66,14 +79,12 @@ class CustomEnvSettings(EnvSettingsSource):
try: try:
env_val = settings.__config__.json_loads(env_val) env_val = settings.__config__.json_loads(env_val)
except ValueError as e: except ValueError as e:
raise SettingsError( raise SettingsError(f'error parsing JSON for "{env_name}"') from e
f'error parsing JSON for "{env_name}"') from e
d[field.alias] = env_val d[field.alias] = env_val
if env_file_vars: if env_file_vars:
for env_name, env_val in env_file_vars.items(): for env_name, env_val in env_file_vars.items():
if (env_val is None or if (env_val is None or len(env_val) == 0) and env_name in env_vars:
len(env_val) == 0) and env_name in env_vars:
env_val = env_vars[env_name] env_val = env_vars[env_name]
try: try:
if env_val: if env_val:
@ -87,12 +98,10 @@ class CustomEnvSettings(EnvSettingsSource):
class BaseConfig(BaseSettings): class BaseConfig(BaseSettings):
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
return self.__dict__.get(name) return self.__dict__.get(name)
class Config: class Config:
@classmethod @classmethod
def customise_sources( def customise_sources(
cls, cls,
@ -101,10 +110,14 @@ class BaseConfig(BaseSettings):
file_secret_settings: SettingsSourceCallable, file_secret_settings: SettingsSourceCallable,
) -> Tuple[SettingsSourceCallable, ...]: ) -> Tuple[SettingsSourceCallable, ...]:
common_config = init_settings.init_kwargs.pop("_common_config", {}) common_config = init_settings.init_kwargs.pop("_common_config", {})
return (init_settings, return (
CustomEnvSettings(env_settings.env_file, init_settings,
env_settings.env_file_encoding), CustomEnvSettings(
InitSettingsSource(common_config), file_secret_settings) env_settings.env_file, env_settings.env_file_encoding
),
InitSettingsSource(common_config),
file_secret_settings,
)
class Env(BaseConfig): class Env(BaseConfig):
@ -135,6 +148,7 @@ class Config(BaseConfig):
除了 NoneBot 的配置项外还可以自行添加配置项到 ``.env.{environment}`` 文件中 除了 NoneBot 的配置项外还可以自行添加配置项到 ``.env.{environment}`` 文件中
这些配置将会在 json 反序列化后一起带入 ``Config`` 类中 这些配置将会在 json 反序列化后一起带入 ``Config`` 类中
""" """
# nonebot configs # nonebot configs
driver: str = "nonebot.drivers.fastapi" driver: str = "nonebot.drivers.fastapi"
""" """
@ -210,7 +224,7 @@ class Config(BaseConfig):
API_ROOT={"123456": "http://127.0.0.1:5700"} API_ROOT={"123456": "http://127.0.0.1:5700"}
""" """
api_timeout: Optional[float] = 30. api_timeout: Optional[float] = 30.0
""" """
- **类型**: ``Optional[float]`` - **类型**: ``Optional[float]``
- **默认值**: ``30.`` - **默认值**: ``30.``

View File

@ -21,9 +21,14 @@ from .models import Dependent as Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (CacheLock, run_sync, is_gen_callable, from nonebot.utils import (
run_sync_ctx_manager, is_async_gen_callable, CacheLock,
is_coroutine_callable) run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
)
cache_lock = CacheLock() cache_lock = CacheLock()
@ -33,30 +38,27 @@ class CustomConfig(BaseConfig):
def get_param_sub_dependent( def get_param_sub_dependent(
*, *, param: inspect.Parameter, allow_types: Optional[List[Type[Param]]] = None
param: inspect.Parameter, ) -> Dependent:
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
depends: DependsWrapper = param.default depends: DependsWrapper = param.default
if depends.dependency: if depends.dependency:
dependency = depends.dependency dependency = depends.dependency
else: else:
dependency = param.annotation dependency = param.annotation
return get_sub_dependant(depends=depends, return get_sub_dependant(
dependency=dependency, depends=depends, dependency=dependency, name=param.name, allow_types=allow_types
name=param.name, )
allow_types=allow_types)
def get_parameterless_sub_dependant( def get_parameterless_sub_dependant(
*, *, depends: DependsWrapper, allow_types: Optional[List[Type[Param]]] = None
depends: DependsWrapper, ) -> Dependent:
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
assert callable( assert callable(
depends.dependency depends.dependency
), "A parameter-less dependency must have a callable dependency" ), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, return get_sub_dependant(
dependency=depends.dependency, depends=depends, dependency=depends.dependency, allow_types=allow_types
allow_types=allow_types) )
def get_sub_dependant( def get_sub_dependant(
@ -64,29 +66,31 @@ def get_sub_dependant(
depends: DependsWrapper, depends: DependsWrapper,
dependency: T_Handler, dependency: T_Handler,
name: Optional[str] = None, name: Optional[str] = None,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent: allow_types: Optional[List[Type[Param]]] = None,
sub_dependant = get_dependent(func=dependency, ) -> Dependent:
name=name, sub_dependant = get_dependent(
use_cache=depends.use_cache, func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
allow_types=allow_types) )
return sub_dependant return sub_dependant
def get_dependent(*, def get_dependent(
*,
func: T_Handler, func: T_Handler,
name: Optional[str] = None, name: Optional[str] = None,
use_cache: bool = True, use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent: allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
signature = get_typed_signature(func) signature = get_typed_signature(func)
params = signature.parameters params = signature.parameters
dependent = Dependent(func=func, dependent = Dependent(
name=name, func=func, name=name, allow_types=allow_types, use_cache=use_cache
allow_types=allow_types, )
use_cache=use_cache)
for param_name, param in params.items(): for param_name, param in params.items():
if isinstance(param.default, DependsWrapper): if isinstance(param.default, DependsWrapper):
sub_dependent = get_param_sub_dependent(param=param, sub_dependent = get_param_sub_dependent(
allow_types=allow_types) param=param, allow_types=allow_types
)
dependent.dependencies.append(sub_dependent) dependent.dependencies.append(sub_dependent)
continue continue
@ -111,16 +115,18 @@ def get_dependent(*,
required = default_value == Required required = default_value == Required
if param.annotation != param.empty: if param.annotation != param.empty:
annotation = param.annotation annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info, annotation = get_annotation_from_field_info(annotation, field_info, param_name)
param_name)
dependent.params.append( dependent.params.append(
ModelField(name=param_name, ModelField(
name=param_name,
type_=annotation, type_=annotation,
class_validators=None, class_validators=None,
model_config=CustomConfig, model_config=CustomConfig,
default=None if required else default_value, default=None if required else default_value,
required=required, required=required,
field_info=field_info)) field_info=field_info,
)
)
return dependent return dependent
@ -131,24 +137,22 @@ async def solve_dependencies(
_stack: Optional[AsyncExitStack] = None, _stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None, _sub_dependents: Optional[List[Dependent]] = None,
_dependency_cache: Optional[T_DependencyCache] = None, _dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]: **params: Any,
) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = {} if _dependency_cache is None else _dependency_cache dependency_cache = {} if _dependency_cache is None else _dependency_cache
# solve sub dependencies # solve sub dependencies
sub_dependent: Dependent sub_dependent: Dependent
for sub_dependent in chain(_sub_dependents or tuple(), for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
_dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
sub_dependent.cache_key)
func = sub_dependent.func func = sub_dependent.func
# solve sub dependency with current cache # solve sub dependency with current cache
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
_dependent=sub_dependent, _dependent=sub_dependent, _dependency_cache=dependency_cache, **params
_dependency_cache=dependency_cache, )
**params)
sub_values, sub_dependency_cache = solved_result sub_values, sub_dependency_cache = solved_result
# update cache? # update cache?
# dependency_cache.update(sub_dependency_cache) # dependency_cache.update(sub_dependency_cache)
@ -162,8 +166,7 @@ async def solve_dependencies(
_stack, AsyncExitStack _stack, AsyncExitStack
), "Generator dependency should be called in context" ), "Generator dependency should be called in context"
if is_gen_callable(func): if is_gen_callable(func):
cm = run_sync_ctx_manager( cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
contextmanager(func)(**sub_values))
else: else:
cm = asynccontextmanager(func)(**sub_values) cm = asynccontextmanager(func)(**sub_values)
solved = await _stack.enter_async_context(cm) solved = await _stack.enter_async_context(cm)
@ -182,19 +185,17 @@ async def solve_dependencies(
# usual dependency # usual dependency
for field in _dependent.params: for field in _dependent.params:
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, assert isinstance(field_info, Param), "Params must be subclasses of Param"
Param), "Params must be subclasses of Param"
value = field_info._solve(**params) value = field_info._solve(**params)
if value == Undefined: if value == Undefined:
value = field.get_default() value = field.get_default()
_, errs_ = field.validate(value, _, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
values,
loc=(str(field_info), field.alias))
if errs_: if errs_:
logger.debug( logger.debug(
f"{field_info} " f"{field_info} "
f"type {type(value)} not match depends {_dependent.func} " f"type {type(value)} not match depends {_dependent.func} "
f"annotation {field._type_display()}, ignored") f"annotation {field._type_display()}, ignored"
)
raise SkippedException raise SkippedException
else: else:
values[field.name] = value values[field.name] = value
@ -202,9 +203,7 @@ async def solve_dependencies(
return values, dependency_cache return values, dependency_cache
def Depends(dependency: Optional[T_Handler] = None, def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any:
*,
use_cache: bool = True) -> Any:
""" """
:说明: :说明:

View File

@ -9,7 +9,6 @@ from nonebot.typing import T_Handler
class Param(abc.ABC, FieldInfo): class Param(abc.ABC, FieldInfo):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
@ -21,11 +20,9 @@ class Param(abc.ABC, FieldInfo):
class DependsWrapper: class DependsWrapper:
def __init__(
def __init__(self, self, dependency: Optional[T_Handler] = None, *, use_cache: bool = True
dependency: Optional[T_Handler] = None, ) -> None:
*,
use_cache: bool = True) -> None:
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
@ -36,15 +33,16 @@ class DependsWrapper:
class Dependent: class Dependent:
def __init__(
def __init__(self, self,
*, *,
func: Optional[T_Handler] = None, func: Optional[T_Handler] = None,
name: Optional[str] = None, name: Optional[str] = None,
params: Optional[List[ModelField]] = None, params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None, allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None, dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None: use_cache: bool = True,
) -> None:
self.func = func self.func = func
self.name = name self.name = name
self.params = params or [] self.params = params or []

View File

@ -16,14 +16,14 @@ def get_typed_signature(func: T_Handler) -> inspect.Signature:
kind=param.kind, kind=param.kind,
default=param.default, default=param.default,
annotation=get_typed_annotation(param, globalns), annotation=get_typed_annotation(param, globalns),
) for param in signature.parameters.values() )
for param in signature.parameters.values()
] ]
typed_signature = inspect.Signature(typed_params) typed_signature = inspect.Signature(typed_params)
return typed_signature return typed_signature
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
Any]) -> Any:
annotation = param.annotation annotation = param.annotation
if isinstance(annotation, str): if isinstance(annotation, str):
annotation = ForwardRef(annotation) annotation = ForwardRef(annotation)
@ -31,7 +31,7 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
annotation = evaluate_forwardref(annotation, globalns, globalns) annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).warning( logger.opt(colors=True, exception=e).warning(
f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}" f'Unknown ForwardRef["{param.annotation}"] for parameter {param.name}'
) )
return inspect.Parameter.empty return inspect.Parameter.empty
return annotation return annotation

View File

@ -8,8 +8,17 @@
import abc import abc
import asyncio import asyncio
from dataclasses import field, dataclass from dataclasses import field, dataclass
from typing import (TYPE_CHECKING, Any, Set, Dict, Type, Union, Callable, from typing import (
Optional, Awaitable) TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
Callable,
Optional,
Awaitable,
)
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
@ -90,12 +99,14 @@ class Driver(abc.ABC):
""" """
if name in self._adapters: if name in self._adapters:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists') f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return return
self._adapters[name] = adapter self._adapters[name] = adapter
adapter.register(self, self.config, **kwargs) adapter.register(self, self.config, **kwargs)
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"') f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -121,7 +132,8 @@ class Driver(abc.ABC):
* ``**kwargs`` * ``**kwargs``
""" """
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>") f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
@abc.abstractmethod @abc.abstractmethod
def on_startup(self, func: Callable) -> Callable: def on_startup(self, func: Callable) -> Callable:
@ -146,8 +158,7 @@ class Driver(abc.ABC):
self._bot_connection_hook.add(func) self._bot_connection_hook.add(func)
return func return func
def on_bot_disconnect( def on_bot_disconnect(self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
""" """
:说明: :说明:
@ -172,7 +183,8 @@ class Driver(abc.ABC):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. " "<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot)) asyncio.create_task(_run_hook(bot))
@ -189,7 +201,8 @@ class Driver(abc.ABC):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. " "<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot)) asyncio.create_task(_run_hook(bot))
@ -201,8 +214,8 @@ class ForwardDriver(Driver):
@abc.abstractmethod @abc.abstractmethod
def setup_http_polling( def setup_http_polling(
self, setup: Union["HTTPPollingSetup", self,
Callable[[], Awaitable["HTTPPollingSetup"]]] setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
) -> None: ) -> None:
""" """
:说明: :说明:
@ -217,8 +230,7 @@ class ForwardDriver(Driver):
@abc.abstractmethod @abc.abstractmethod
def setup_websocket( def setup_websocket(
self, setup: Union["WebSocketSetup", self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
Callable[[], Awaitable["WebSocketSetup"]]]
) -> None: ) -> None:
""" """
:说明: :说明:
@ -288,6 +300,7 @@ class HTTPRequest(HTTPConnection):
.. _asgi http scope: .. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
""" """
method: str = "GET" method: str = "GET"
"""The HTTP method name, uppercased.""" """The HTTP method name, uppercased."""
body: bytes = b"" body: bytes = b""
@ -309,6 +322,7 @@ class HTTPResponse:
.. _asgi http scope: .. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
""" """
status: int status: int
"""HTTP status code.""" """HTTP status code."""
body: Optional[bytes] = None body: Optional[bytes] = None
@ -416,5 +430,5 @@ class WebSocketSetup:
"""URL""" """URL"""
headers: Dict[str, str] = field(default_factory=dict) headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers""" """HTTP headers"""
reconnect_interval: float = 3. reconnect_interval: float = 3.0
"""WebSocket 重连间隔""" """WebSocket 重连间隔"""

View File

@ -20,13 +20,16 @@ from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (HTTPRequest, ForwardDriver, WebSocketSetup, from nonebot.drivers import (
HTTPPollingSetup) HTTPRequest,
ForwardDriver,
WebSocketSetup,
HTTPPollingSetup,
)
STARTUP_FUNC = Callable[[], Awaitable[None]] STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@ -146,7 +149,8 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. " "<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>") "Ignored!</bg #f8bbd0></r>"
)
async def main_loop(self): async def main_loop(self):
await self.should_exit.wait() await self.should_exit.wait()
@ -160,24 +164,20 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. " "<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>") "Ignored!</bg #f8bbd0></r>"
)
for task in self.connections: for task in self.connections:
if not task.done(): if not task.done():
task.cancel() task.cancel()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
tasks = [ tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
t for t in asyncio.all_tasks() if t is not asyncio.current_task()
]
if tasks and not self.force_exit: if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)") logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit: while tasks and not self.force_exit:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
tasks = [ tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
t for t in asyncio.all_tasks()
if t is not asyncio.current_task()
]
for task in tasks: for task in tasks:
task.cancel() task.cancel()
@ -209,9 +209,7 @@ class Driver(ForwardDriver):
self.should_exit.set() self.should_exit.set()
async def _http_loop(self, setup: HTTPPOLLING_SETUP): async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = URL(setup.url) url = URL(setup.url)
if not url.is_absolute() or not url.host: if not url.is_absolute() or not url.host:
logger.opt(colors=True).error( logger.opt(colors=True).error(
@ -219,10 +217,15 @@ class Driver(ForwardDriver):
) )
return return
host = f"{url.host}:{url.port}" if url.port else url.host host = f"{url.host}:{url.port}" if url.port else url.host
return HTTPRequest(setup.http_version, url.scheme, url.path, return HTTPRequest(
url.raw_query_string.encode("latin-1"), { setup.http_version,
**setup.headers, "host": host url.scheme,
}, setup.method, setup.body) url.path,
url.raw_query_string.encode("latin-1"),
{**setup.headers, "host": host},
setup.method,
setup.body,
)
bot: Optional[Bot] = None bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None request: Optional[HTTPRequest] = None
@ -230,7 +233,8 @@ class Driver(ForwardDriver):
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} " f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>") f"Bot {escape_tag(setup.self_id)}</y>"
)
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -244,7 +248,8 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup " "<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>") f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
@ -286,19 +291,22 @@ class Driver(ForwardDriver):
) )
try: try:
async with session.request(request.method, async with session.request(
request.method,
setup_.url, setup_.url,
data=request.body, data=request.body,
headers=headers, headers=headers,
timeout=timeout, timeout=timeout,
version=version) as response: version=version,
) as response:
response.raise_for_status() response.raise_for_status()
data = await response.read() data = await response.read()
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e: except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. " f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval) await asyncio.sleep(setup_.poll_interval)
@ -307,7 +315,8 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>") "while http polling</bg #f8bbd0></r>"
)
finally: finally:
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
@ -327,7 +336,8 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup " "<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>") f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
@ -346,17 +356,21 @@ class Driver(ForwardDriver):
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
) )
try: try:
async with session.ws_connect(url, async with session.ws_connect(
headers=headers, url, headers=headers, timeout=30.0
timeout=30.) as ws: ) as ws:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} " f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!" f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
) )
request = WebSocket( request = WebSocket(
"1.1", url.scheme, url.path, "1.1",
url.raw_query_string.encode("latin-1"), headers, url.scheme,
ws) url.path,
url.raw_query_string.encode("latin-1"),
headers,
ws,
)
BotClass = self._adapters[setup_.adapter] BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request) bot = BotClass(setup_.self_id, request)
@ -365,25 +379,30 @@ class Driver(ForwardDriver):
msg = await ws.receive() msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.text: if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task( asyncio.create_task(
bot.handle_message(msg.data.encode())) bot.handle_message(msg.data.encode())
)
elif msg.type == aiohttp.WSMsgType.binary: elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task( asyncio.create_task(bot.handle_message(msg.data))
bot.handle_message(msg.data))
elif msg.type == aiohttp.WSMsgType.error: elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error( logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. " "<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
break break
else: else:
logger.opt(colors=True).error( logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. " "<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
break break
except (aiohttp.ClientResponseError, except (
aiohttp.ClientConnectionError) as e: aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
) as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. " f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
finally: finally:
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
@ -395,7 +414,8 @@ class Driver(ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>") "while websocket loop</bg #f8bbd0></r>"
)
@dataclass @dataclass

View File

@ -32,11 +32,15 @@ from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (HTTPRequest, ForwardDriver, ReverseDriver, from nonebot.drivers import (
WebSocketSetup, HTTPPollingSetup) HTTPRequest,
ForwardDriver,
ReverseDriver,
WebSocketSetup,
HTTPPollingSetup,
)
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@ -44,6 +48,7 @@ class Config(BaseSettings):
""" """
FastAPI 驱动框架设置详情参考 FastAPI 文档 FastAPI 驱动框架设置详情参考 FastAPI 文档
""" """
fastapi_openapi_url: Optional[str] = None fastapi_openapi_url: Optional[str] = None
""" """
:类型: :类型:
@ -226,12 +231,14 @@ class Driver(ReverseDriver, ForwardDriver):
self.websockets.append(setup) self.websockets.append(setup)
@overrides(ReverseDriver) @overrides(ReverseDriver)
def run(self, def run(
self,
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
*, *,
app: Optional[str] = None, app: Optional[str] = None,
**kwargs): **kwargs,
):
"""使用 ``uvicorn`` 启动 FastAPI""" """使用 ``uvicorn`` 启动 FastAPI"""
super().run(host, port, app, **kwargs) super().run(host, port, app, **kwargs)
LOGGING_CONFIG = { LOGGING_CONFIG = {
@ -243,10 +250,7 @@ class Driver(ReverseDriver, ForwardDriver):
}, },
}, },
"loggers": { "loggers": {
"uvicorn.error": { "uvicorn.error": {"handlers": ["default"], "level": "INFO"},
"handlers": ["default"],
"level": "INFO"
},
"uvicorn.access": { "uvicorn.access": {
"handlers": ["default"], "handlers": ["default"],
"level": "INFO", "level": "INFO",
@ -258,15 +262,16 @@ class Driver(ReverseDriver, ForwardDriver):
host=host or str(self.config.host), host=host or str(self.config.host),
port=port or self.config.port, port=port or self.config.port,
reload=self.fastapi_config.fastapi_reload reload=self.fastapi_config.fastapi_reload
if self.fastapi_config.fastapi_reload is not None else if self.fastapi_config.fastapi_reload is not None
(bool(app) and self.config.debug), else (bool(app) and self.config.debug),
reload_dirs=self.fastapi_config.fastapi_reload_dirs, reload_dirs=self.fastapi_config.fastapi_reload_dirs,
reload_delay=self.fastapi_config.fastapi_reload_delay, reload_delay=self.fastapi_config.fastapi_reload_delay,
reload_includes=self.fastapi_config.fastapi_reload_includes, reload_includes=self.fastapi_config.fastapi_reload_includes,
reload_excludes=self.fastapi_config.fastapi_reload_excludes, reload_excludes=self.fastapi_config.fastapi_reload_excludes,
debug=self.config.debug, debug=self.config.debug,
log_config=LOGGING_CONFIG, log_config=LOGGING_CONFIG,
**kwargs) **kwargs,
)
def _run_forward(self): def _run_forward(self):
for setup in self.http_pollings: for setup in self.http_pollings:
@ -287,39 +292,49 @@ class Driver(ReverseDriver, ForwardDriver):
logger.warning( logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use." f"Unknown adapter {adapter}. Please register the adapter before use."
) )
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(
detail="adapter not found") status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
)
# 创建 Bot 对象 # 创建 Bot 对象
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
http_request = HTTPRequest(request.scope["http_version"], http_request = HTTPRequest(
request.url.scheme, request.url.path, request.scope["http_version"],
request.url.scheme,
request.url.path,
request.scope["query_string"], request.scope["query_string"],
dict(request.headers), request.method, data) dict(request.headers),
x_self_id, response = await BotClass.check_permission( request.method,
self, http_request) data,
)
x_self_id, response = await BotClass.check_permission(self, http_request)
if not x_self_id: if not x_self_id:
raise HTTPException( raise HTTPException(
response and response.status or 401, response and response and response.status or 401,
response.body and response.body.decode("utf-8")) response and response.body and response.body.decode("utf-8"),
)
if x_self_id in self._clients: if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection," logger.warning(
"so the event may be handled twice.") "There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(x_self_id, http_request) bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
return Response(response and response.body, return Response(response and response.body, response and response.status or 200)
response and response.status or 200)
async def _handle_ws_reverse(self, adapter: str, async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
websocket: FastAPIWebSocket): ws = WebSocket(
ws = WebSocket(websocket.scope.get("http_version", websocket.scope.get("http_version", "1.1"),
"1.1"), websocket.url.scheme, websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"], websocket.url.path,
dict(websocket.headers), websocket) websocket.scope["query_string"],
dict(websocket.headers),
websocket,
)
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
@ -349,7 +364,8 @@ class Driver(ReverseDriver, ForwardDriver):
await ws.accept() await ws.accept()
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} " f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!") f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot) self._bot_connect(bot)
@ -362,7 +378,8 @@ class Driver(ReverseDriver, ForwardDriver):
break break
except Exception as e: except Exception as e:
logger.opt(exception=e).error( logger.opt(exception=e).error(
"Error when receiving data from websocket.") "Error when receiving data from websocket."
)
break break
asyncio.create_task(bot.handle_message(data.encode())) asyncio.create_task(bot.handle_message(data.encode()))
@ -370,9 +387,7 @@ class Driver(ReverseDriver, ForwardDriver):
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _http_loop(self, setup: HTTPPOLLING_SETUP): async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url) url = httpx.URL(setup.url)
if not url.netloc: if not url.netloc:
logger.opt(colors=True).error( logger.opt(colors=True).error(
@ -380,9 +395,14 @@ class Driver(ReverseDriver, ForwardDriver):
) )
return return
return HTTPRequest( return HTTPRequest(
setup.http_version, url.scheme, url.path, url.query, { setup.http_version,
**setup.headers, "host": url.netloc.decode("ascii") url.scheme,
}, setup.method, setup.body) url.path,
url.query,
{**setup.headers, "host": url.netloc.decode("ascii")},
setup.method,
setup.body,
)
bot: Optional[Bot] = None bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None request: Optional[HTTPRequest] = None
@ -390,11 +410,11 @@ class Driver(ReverseDriver, ForwardDriver):
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} " f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>") f"Bot {escape_tag(setup.self_id)}</y>"
)
try: try:
async with httpx.AsyncClient(http2=True, async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
follow_redirects=True) as session:
while not self.shutdown.is_set(): while not self.shutdown.is_set():
try: try:
@ -405,7 +425,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup " "<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>") f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
@ -432,18 +453,21 @@ class Driver(ReverseDriver, ForwardDriver):
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
) )
try: try:
response = await session.request(request.method, response = await session.request(
request.method,
setup_.url, setup_.url,
content=request.body, content=request.body,
headers=headers, headers=headers,
timeout=30.) timeout=30.0,
)
response.raise_for_status() response.raise_for_status()
data = response.read() data = response.read()
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e: except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. " f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval) await asyncio.sleep(setup_.poll_interval)
@ -452,7 +476,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>") "while http polling</bg #f8bbd0></r>"
)
finally: finally:
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
@ -471,7 +496,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup " "<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>") f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
@ -491,9 +517,11 @@ class Driver(ReverseDriver, ForwardDriver):
async with connection as ws: async with connection as ws:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} " f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!") f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
request = WebSocket("1.1", url.scheme, url.path, )
url.query, headers, ws) request = WebSocket(
"1.1", url.scheme, url.path, url.query, headers, ws
)
BotClass = self._adapters[setup_.adapter] BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request) bot = BotClass(setup_.self_id, request)
@ -506,12 +534,14 @@ class Driver(ReverseDriver, ForwardDriver):
except ConnectionClosed: except ConnectionClosed:
logger.opt(colors=True).error( logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. " "<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
break break
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. " f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>"
)
finally: finally:
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
@ -523,21 +553,22 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>") "while websocket loop</bg #f8bbd0></r>"
)
@dataclass @dataclass
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: Union[FastAPIWebSocket, websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None # type: ignore
WebSocketClientProtocol] = None # type: ignore
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self) -> bool: def closed(self) -> bool:
if isinstance(self.websocket, FastAPIWebSocket): if isinstance(self.websocket, FastAPIWebSocket):
return ( return (
self.websocket.client_state == WebSocketState.DISCONNECTED or self.websocket.client_state == WebSocketState.DISCONNECTED
self.websocket.application_state == WebSocketState.DISCONNECTED) or self.websocket.application_state == WebSocketState.DISCONNECTED
)
else: else:
return self.websocket.closed return self.websocket.closed

View File

@ -30,8 +30,7 @@ try:
from quart import Quart, Request, Response from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket from quart import Websocket as QuartWebSocket
except ImportError: except ImportError:
raise ValueError( raise ValueError("Please install Quart by using `pip install nonebot2[quart]`")
'Please install Quart by using `pip install nonebot2[quart]`')
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) _AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
@ -40,6 +39,7 @@ class Config(BaseSettings):
""" """
Quart 驱动框架设置 Quart 驱动框架设置
""" """
quart_reload: Optional[bool] = None quart_reload: Optional[bool] = None
""" """
:类型: :类型:
@ -111,11 +111,12 @@ class Driver(ReverseDriver):
self.quart_config = Config(**config.dict()) self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__) self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule("/<adapter>/http", self._server_app.add_url_rule(
methods=["POST"], "/<adapter>/http", methods=["POST"], view_func=self._handle_http
view_func=self._handle_http) )
self._server_app.add_websocket("/<adapter>/ws", self._server_app.add_websocket(
view_func=self._handle_ws_reverse) "/<adapter>/ws", view_func=self._handle_ws_reverse
)
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
@ -156,12 +157,14 @@ class Driver(ReverseDriver):
return self.server_app.after_serving(func) # type: ignore return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver) @overrides(ReverseDriver)
def run(self, def run(
self,
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
*, *,
app: Optional[str] = None, app: Optional[str] = None,
**kwargs): **kwargs,
):
"""使用 ``uvicorn`` 启动 Quart""" """使用 ``uvicorn`` 启动 Quart"""
super().run(host, port, app, **kwargs) super().run(host, port, app, **kwargs)
LOGGING_CONFIG = { LOGGING_CONFIG = {
@ -173,10 +176,7 @@ class Driver(ReverseDriver):
}, },
}, },
"loggers": { "loggers": {
"uvicorn.error": { "uvicorn.error": {"handlers": ["default"], "level": "INFO"},
"handlers": ["default"],
"level": "INFO"
},
"uvicorn.access": { "uvicorn.access": {
"handlers": ["default"], "handlers": ["default"],
"level": "INFO", "level": "INFO",
@ -188,52 +188,69 @@ class Driver(ReverseDriver):
host=host or str(self.config.host), host=host or str(self.config.host),
port=port or self.config.port, port=port or self.config.port,
reload=self.quart_config.quart_reload reload=self.quart_config.quart_reload
if self.quart_config.quart_reload is not None else if self.quart_config.quart_reload is not None
(bool(app) and self.config.debug), else (bool(app) and self.config.debug),
reload_dirs=self.quart_config.quart_reload_dirs, reload_dirs=self.quart_config.quart_reload_dirs,
reload_delay=self.quart_config.quart_reload_delay, reload_delay=self.quart_config.quart_reload_delay,
reload_includes=self.quart_config.quart_reload_includes, reload_includes=self.quart_config.quart_reload_includes,
reload_excludes=self.quart_config.quart_reload_excludes, reload_excludes=self.quart_config.quart_reload_excludes,
debug=self.config.debug, debug=self.config.debug,
log_config=LOGGING_CONFIG, log_config=LOGGING_CONFIG,
**kwargs) **kwargs,
)
async def _handle_http(self, adapter: str): async def _handle_http(self, adapter: str):
request: Request = _request request: Request = _request
data: bytes = await request.get_data() # type: ignore data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. ' logger.warning(
'Please register the adapter before use.') f"Unknown adapter {adapter}. " "Please register the adapter before use."
)
raise exceptions.NotFound() raise exceptions.NotFound()
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
http_request = HTTPRequest(request.http_version, request.scheme, http_request = HTTPRequest(
request.path, request.query_string, request.http_version,
dict(request.headers), request.method, data) request.scheme,
request.path,
request.query_string,
dict(request.headers),
request.method,
data,
)
self_id, response = await BotClass.check_permission(self, http_request) self_id, response = await BotClass.check_permission(self, http_request)
if not self_id: if not self_id:
raise exceptions.Unauthorized( raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()) description=(response and response.body or b"").decode()
)
if self_id in self._clients: if self_id in self._clients:
logger.warning("There's already a reverse websocket connection," logger.warning(
"so the event may be handled twice.") "There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(self_id, http_request) bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
return Response(response and response.body or "", return Response(
response and response.status or 200) response and response.body or "", response and response.status or 200
)
async def _handle_ws_reverse(self, adapter: str): async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme, ws = WebSocket(
websocket.path, websocket.query_string, websocket.http_version,
dict(websocket.headers), websocket) websocket.scheme,
websocket.path,
websocket.query_string,
dict(websocket.headers),
websocket,
)
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.' f"Unknown adapter {adapter}. Please register the adapter before use."
) )
raise exceptions.NotFound() raise exceptions.NotFound()
@ -242,20 +259,22 @@ class Driver(ReverseDriver):
if not self_id: if not self_id:
raise exceptions.Unauthorized( raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()) description=(response and response.body or b"").decode()
)
if self_id in self._clients: if self_id in self._clients:
logger.opt(colors=True).warning( logger.opt(colors=True).warning(
"There's already a websocket connection, " "There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored." f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
) )
raise exceptions.Forbidden(description='Client already exists.') raise exceptions.Forbidden(description="Client already exists.")
bot = BotClass(self_id, ws) bot = BotClass(self_id, ws)
await ws.accept() await ws.accept()
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} " f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!") f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot) self._bot_connect(bot)
try: try:
@ -267,7 +286,8 @@ class Driver(ReverseDriver):
break break
except Exception as e: except Exception as e:
logger.opt(exception=e).error( logger.opt(exception=e).error(
"Error when receiving data from websocket.") "Error when receiving data from websocket."
)
break break
asyncio.create_task(bot.handle_message(data.encode())) asyncio.create_task(bot.handle_message(data.encode()))

View File

@ -157,6 +157,7 @@ class NoLogException(AdapterException):
指示 NoneBot 对当前 ``Event`` 进行处理但不显示 Log 信息可在 ``get_log_string`` 时抛出 指示 NoneBot 对当前 ``Event`` 进行处理但不显示 Log 信息可在 ``get_log_string`` 时抛出
""" """
pass pass
@ -166,6 +167,7 @@ class ApiNotAvailable(AdapterException):
API 连接不可用时抛出 API 连接不可用时抛出
""" """
pass pass
@ -175,6 +177,7 @@ class NetworkError(AdapterException):
在网络出现问题时抛出: API 请求地址不正确, API 请求无返回或返回状态非正常等 在网络出现问题时抛出: API 请求地址不正确, API 请求无返回或返回状态非正常等
""" """
pass pass
@ -184,4 +187,5 @@ class ActionFailed(AdapterException):
API 请求成功返回数据 API 操作失败 API 请求成功返回数据 API 操作失败
""" """
pass pass

View File

@ -10,20 +10,27 @@ from contextlib import AsyncExitStack
from typing import Any, Dict, List, Type, Callable, Optional from typing import Any, Dict, List, Type, Callable, Optional
from nonebot.utils import get_name, run_sync from nonebot.utils import get_name, run_sync
from nonebot.dependencies import (Param, Dependent, DependsWrapper, from nonebot.dependencies import (
get_dependent, solve_dependencies, Param,
get_parameterless_sub_dependant) Dependent,
DependsWrapper,
get_dependent,
solve_dependencies,
get_parameterless_sub_dependant,
)
class Handler: class Handler:
"""事件处理器类。支持依赖注入。""" """事件处理器类。支持依赖注入。"""
def __init__(self, def __init__(
self,
func: Callable[..., Any], func: Callable[..., Any],
*, *,
name: Optional[str] = None, name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None, dependencies: Optional[List[DependsWrapper]] = None,
allow_types: Optional[List[Type[Param]]] = None): allow_types: Optional[List[Type[Param]]] = None,
):
""" """
:说明: :说明:
@ -64,19 +71,18 @@ class Handler:
self.dependent = get_dependent(func=func, allow_types=self.allow_types) self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
)
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
async def __call__(self, async def __call__(
self,
*, *,
_stack: Optional[AsyncExitStack] = None, _stack: Optional[AsyncExitStack] = None,
_dependency_cache: Optional[Dict[Callable[..., Any], _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
Any]] = None, **params,
**params) -> Any: ) -> Any:
values, _ = await solve_dependencies( values, _ = await solve_dependencies(
_dependent=self.dependent, _dependent=self.dependent,
_stack=_stack, _stack=_stack,
@ -85,7 +91,8 @@ class Handler:
for dependency in self.dependencies for dependency in self.dependencies
], ],
_dependency_cache=_dependency_cache, _dependency_cache=_dependency_cache,
**params) **params,
)
if asyncio.iscoroutinefunction(self.func): if asyncio.iscoroutinefunction(self.func):
return await self.func(**values) return await self.func(**values)
@ -98,7 +105,8 @@ class Handler:
if dependency.dependency in self.sub_dependents: if dependency.dependency in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies") raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=dependency, allow_types=self.allow_types) depends=dependency, allow_types=self.allow_types
)
self.sub_dependents[dependency.dependency] = sub_dependant self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: DependsWrapper): def prepend_dependency(self, dependency: DependsWrapper):

View File

@ -48,7 +48,6 @@ logger: "Logger" = loguru.logger
class Filter: class Filter:
def __init__(self) -> None: def __init__(self) -> None:
self.level: Union[int, str] = "DEBUG" self.level: Union[int, str] = "DEBUG"
@ -58,13 +57,13 @@ class Filter:
if module: if module:
module_name = getattr(module, "__module_name__", module_name) module_name = getattr(module, "__module_name__", module_name)
record["name"] = module_name.split(".")[0] record["name"] = module_name.split(".")[0]
levelno = logger.level(self.level).no if isinstance(self.level, levelno = (
str) else self.level logger.level(self.level).no if isinstance(self.level, str) else self.level
)
return record["level"].no >= levelno return record["level"].no >= levelno
class LoguruHandler(logging.Handler): class LoguruHandler(logging.Handler):
def emit(self, record): def emit(self, record):
try: try:
level = logger.level(record.levelname).name level = logger.level(record.levelname).name
@ -76,8 +75,9 @@ class LoguruHandler(logging.Handler):
frame = frame.f_back frame = frame.f_back
depth += 1 depth += 1
logger.opt(depth=depth, logger.opt(depth=depth, exception=record.exc_info).log(
exception=record.exc_info).log(level, record.getMessage()) level, record.getMessage()
)
logger.remove() logger.remove()
@ -87,9 +87,12 @@ default_format = (
"[<lvl>{level}</lvl>] " "[<lvl>{level}</lvl>] "
"<c><u>{name}</u></c> | " "<c><u>{name}</u></c> | "
# "<c>{function}:{line}</c>| " # "<c>{function}:{line}</c>| "
"{message}") "{message}"
logger_id = logger.add(sys.stdout, )
logger_id = logger.add(
sys.stdout,
colorize=True, colorize=True,
diagnose=False, diagnose=False,
filter=default_filter, filter=default_filter,
format=default_format) format=default_format,
)

View File

@ -10,8 +10,17 @@ from datetime import datetime
from contextvars import ContextVar from contextvars import ContextVar
from collections import defaultdict from collections import defaultdict
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, from typing import (
NoReturn, Optional) TYPE_CHECKING,
Any,
Dict,
List,
Type,
Union,
Callable,
NoReturn,
Optional,
)
from nonebot import params from nonebot import params
from nonebot.rule import Rule from nonebot.rule import Rule
@ -19,14 +28,29 @@ from nonebot.log import logger
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot.dependencies import DependsWrapper from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment, from nonebot.adapters import (
MessageTemplate) Bot,
from nonebot.exception import (PausedException, StopPropagation, Event,
SkippedException, FinishedException, Message,
RejectedException) MessageSegment,
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, MessageTemplate,
T_StateFactory, T_DependencyCache, )
T_PermissionUpdater) from nonebot.exception import (
PausedException,
StopPropagation,
SkippedException,
FinishedException,
RejectedException,
)
from nonebot.typing import (
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_StateFactory,
T_DependencyCache,
T_PermissionUpdater,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
@ -57,9 +81,11 @@ class MatcherMeta(type):
expire_time: Optional[datetime] expire_time: Optional[datetime]
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"<Matcher from {self.module_name or 'unknown'}, " return (
f"<Matcher from {self.module_name or 'unknown'}, "
f"type={self.type}, priority={self.priority}, " f"type={self.type}, priority={self.priority}, "
f"temp={self.temp}>") f"temp={self.temp}>"
)
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
@ -67,6 +93,7 @@ class MatcherMeta(type):
class Matcher(metaclass=MatcherMeta): class Matcher(metaclass=MatcherMeta):
"""事件响应器类""" """事件响应器类"""
plugin: Optional["Plugin"] = None plugin: Optional["Plugin"] = None
""" """
:类型: ``Optional[Plugin]`` :类型: ``Optional[Plugin]``
@ -157,8 +184,11 @@ class Matcher(metaclass=MatcherMeta):
""" """
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam, params.BotParam,
params.MatcherParam, params.DefaultParam params.EventParam,
params.StateParam,
params.MatcherParam,
params.DefaultParam,
] ]
def __init__(self): def __init__(self):
@ -169,7 +199,8 @@ class Matcher(metaclass=MatcherMeta):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Matcher from {self.module_name or 'unknown'}, type={self.type}, " f"<Matcher from {self.module_name or 'unknown'}, type={self.type}, "
f"priority={self.priority}, temp={self.temp}>") f"priority={self.priority}, temp={self.temp}>"
)
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
@ -180,8 +211,9 @@ class Matcher(metaclass=MatcherMeta):
type_: str = "", type_: str = "",
rule: Optional[Rule] = None, rule: Optional[Rule] = None,
permission: Optional[Permission] = None, permission: Optional[Permission] = None,
handlers: Optional[Union[List[T_Handler], List[Handler], handlers: Optional[
List[Union[T_Handler, Handler]]]] = None, Union[List[T_Handler], List[Handler], List[Union[T_Handler, Handler]]]
] = None,
temp: bool = False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
@ -193,7 +225,7 @@ class Matcher(metaclass=MatcherMeta):
default_state_factory: Optional[T_StateFactory] = None, default_state_factory: Optional[T_StateFactory] = None,
default_parser: Optional[T_ArgsParser] = None, default_parser: Optional[T_ArgsParser] = None,
default_type_updater: Optional[T_TypeUpdater] = None, default_type_updater: Optional[T_TypeUpdater] = None,
default_permission_updater: Optional[T_PermissionUpdater] = None default_permission_updater: Optional[T_PermissionUpdater] = None,
) -> Type["Matcher"]: ) -> Type["Matcher"]:
""" """
:说明: :说明:
@ -221,46 +253,37 @@ class Matcher(metaclass=MatcherMeta):
""" """
NewMatcher = type( NewMatcher = type(
"Matcher", (Matcher,), { "Matcher",
"plugin": (Matcher,),
plugin, {
"module": "plugin": plugin,
module, "module": module,
"plugin_name": "plugin_name": plugin and plugin.name,
plugin and plugin.name, "module_name": module and module.__name__,
"module_name": "type": type_,
module and module.__name__, "rule": rule or Rule(),
"type": "permission": permission or Permission(),
type_,
"rule":
rule or Rule(),
"permission":
permission or Permission(),
"handlers": [ "handlers": [
handler if isinstance(handler, Handler) else Handler( handler
handler, allow_types=cls.HANDLER_PARAM_TYPES) if isinstance(handler, Handler)
else Handler(handler, allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers for handler in handlers
] if handlers else [], ]
"temp": if handlers
temp, else [],
"expire_time": "temp": temp,
expire_time, "expire_time": expire_time,
"priority": "priority": priority,
priority, "block": block,
"block": "_default_state": default_state or {},
block, "_default_state_factory": staticmethod(default_state_factory)
"_default_state": if default_state_factory
default_state or {}, else None,
"_default_state_factory": "_default_parser": default_parser,
staticmethod(default_state_factory) "_default_type_updater": default_type_updater,
if default_state_factory else None, "_default_permission_updater": default_permission_updater,
"_default_parser": },
default_parser, )
"_default_type_updater":
default_type_updater,
"_default_permission_updater":
default_permission_updater
})
matchers[priority].append(NewMatcher) matchers[priority].append(NewMatcher)
@ -272,8 +295,8 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot, bot: Bot,
event: Event, event: Event,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
Any]] = None) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -289,8 +312,9 @@ class Matcher(metaclass=MatcherMeta):
- ``bool``: 是否满足权限 - ``bool``: 是否满足权限
""" """
event_type = event.get_type() event_type = event.get_type()
return (event_type == (cls.type or event_type) and return event_type == (cls.type or event_type) and await cls.permission(
await cls.permission(bot, event, stack, dependency_cache)) bot, event, stack, dependency_cache
)
@classmethod @classmethod
async def check_rule( async def check_rule(
@ -299,8 +323,8 @@ class Matcher(metaclass=MatcherMeta):
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
Any]] = None) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -317,8 +341,9 @@ class Matcher(metaclass=MatcherMeta):
- ``bool``: 是否满足匹配规则 - ``bool``: 是否满足匹配规则
""" """
event_type = event.get_type() event_type = event.get_type()
return (event_type == (cls.type or event_type) and return event_type == (cls.type or event_type) and await cls.rule(
await cls.rule(bot, event, state, stack, dependency_cache)) bot, event, state, stack, dependency_cache
)
@classmethod @classmethod
def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser: def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser:
@ -349,8 +374,7 @@ class Matcher(metaclass=MatcherMeta):
return func return func
@classmethod @classmethod
def permission_updater(cls, def permission_updater(cls, func: T_PermissionUpdater) -> T_PermissionUpdater:
func: T_PermissionUpdater) -> T_PermissionUpdater:
""" """
:说明: :说明:
@ -365,12 +389,11 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
def append_handler( def append_handler(
cls, cls, handler: T_Handler, dependencies: Optional[List[DependsWrapper]] = None
handler: T_Handler, ) -> Handler:
dependencies: Optional[List[DependsWrapper]] = None) -> Handler: handler_ = Handler(
handler_ = Handler(handler, handler, dependencies=dependencies, allow_types=cls.HANDLER_PARAM_TYPES
dependencies=dependencies, )
allow_types=cls.HANDLER_PARAM_TYPES)
cls.handlers.append(handler_) cls.handlers.append(handler_)
return handler_ return handler_
@ -418,8 +441,7 @@ class Matcher(metaclass=MatcherMeta):
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
func_handler.prepend_dependency(depend) func_handler.prepend_dependency(depend)
else: else:
cls.append_handler( cls.append_handler(func, dependencies=[depend] if cls.handlers else [])
func, dependencies=[depend] if cls.handlers else [])
return func return func
@ -429,9 +451,8 @@ class Matcher(metaclass=MatcherMeta):
def got( def got(
cls, cls,
key: str, key: str,
prompt: Optional[Union[str, Message, MessageSegment, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
MessageTemplate]] = None, args_parser: Optional[T_ArgsParser] = None,
args_parser: Optional[T_ArgsParser] = None
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
""" """
:说明: :说明:
@ -483,16 +504,16 @@ class Matcher(metaclass=MatcherMeta):
func_handler.prepend_dependency(parser_depend) func_handler.prepend_dependency(parser_depend)
func_handler.prepend_dependency(get_depend) func_handler.prepend_dependency(get_depend)
else: else:
cls.append_handler(func, cls.append_handler(func, dependencies=[get_depend, parser_depend])
dependencies=[get_depend, parser_depend])
return func return func
return _decorator return _decorator
@classmethod @classmethod
async def send(cls, message: Union[str, Message, MessageSegment, async def send(
MessageTemplate], **kwargs) -> Any: cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs
) -> Any:
""" """
:说明: :说明:
@ -513,10 +534,11 @@ class Matcher(metaclass=MatcherMeta):
return await bot.send(event=event, message=_message, **kwargs) return await bot.send(event=event, message=_message, **kwargs)
@classmethod @classmethod
async def finish(cls, async def finish(
message: Optional[Union[str, Message, MessageSegment, cls,
MessageTemplate]] = None, message: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs) -> NoReturn: **kwargs,
) -> NoReturn:
""" """
:说明: :说明:
@ -539,10 +561,11 @@ class Matcher(metaclass=MatcherMeta):
raise FinishedException raise FinishedException
@classmethod @classmethod
async def pause(cls, async def pause(
prompt: Optional[Union[str, Message, MessageSegment, cls,
MessageTemplate]] = None, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs) -> NoReturn: **kwargs,
) -> NoReturn:
""" """
:说明: :说明:
@ -565,10 +588,9 @@ class Matcher(metaclass=MatcherMeta):
raise PausedException raise PausedException
@classmethod @classmethod
async def reject(cls, async def reject(
prompt: Optional[Union[str, Message, cls, prompt: Optional[Union[str, Message, MessageSegment]] = None, **kwargs
MessageSegment]] = None, ) -> NoReturn:
**kwargs) -> NoReturn:
""" """
:说明: :说明:
@ -601,31 +623,38 @@ class Matcher(metaclass=MatcherMeta):
self.block = True self.block = True
# 运行handlers # 运行handlers
async def run(self, async def run(
self,
bot: Bot, bot: Bot,
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None): dependency_cache: Optional[T_DependencyCache] = None,
):
b_t = current_bot.set(bot) b_t = current_bot.set(bot)
e_t = current_event.set(event) e_t = current_event.set(event)
s_t = current_state.set(self.state) s_t = current_state.set(self.state)
try: try:
# Refresh preprocess state # Refresh preprocess state
self.state = await self._default_state_factory( self.state = (
bot, event) if self._default_state_factory else self.state await self._default_state_factory(bot, event)
if self._default_state_factory
else self.state
)
self.state.update(state) self.state.update(state)
while self.handlers: while self.handlers:
handler = self.handlers.pop(0) handler = self.handlers.pop(0)
logger.debug(f"Running handler {handler}") logger.debug(f"Running handler {handler}")
try: try:
await handler(matcher=self, await handler(
matcher=self,
bot=bot, bot=bot,
event=event, event=event,
state=self.state, state=self.state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache) _dependency_cache=dependency_cache,
)
except SkippedException: except SkippedException:
pass pass
@ -633,18 +662,13 @@ class Matcher(metaclass=MatcherMeta):
self.handlers.insert(0, handler) # type: ignore self.handlers.insert(0, handler) # type: ignore
updater = self.__class__._default_type_updater updater = self.__class__._default_type_updater
if updater: if updater:
type_ = await updater( type_ = await updater(bot, event, self.state, self.type) # type: ignore
bot,
event,
self.state, # type: ignore
self.type)
else: else:
type_ = "message" type_ = "message"
updater = self.__class__._default_permission_updater updater = self.__class__._default_permission_updater
if updater: if updater:
permission = await updater(bot, event, self.state, permission = await updater(bot, event, self.state, self.permission)
self.permission)
else: else:
permission = USER(event.get_session_id(), perm=self.permission) permission = USER(event.get_session_id(), perm=self.permission)
@ -662,23 +686,18 @@ class Matcher(metaclass=MatcherMeta):
default_state=self.state, default_state=self.state,
default_parser=self.__class__._default_parser, default_parser=self.__class__._default_parser,
default_type_updater=self.__class__._default_type_updater, default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__. default_permission_updater=self.__class__._default_permission_updater,
_default_permission_updater) )
except PausedException: except PausedException:
updater = self.__class__._default_type_updater updater = self.__class__._default_type_updater
if updater: if updater:
type_ = await updater( type_ = await updater(bot, event, self.state, self.type) # type: ignore
bot,
event,
self.state, # type: ignore
self.type)
else: else:
type_ = "message" type_ = "message"
updater = self.__class__._default_permission_updater updater = self.__class__._default_permission_updater
if updater: if updater:
permission = await updater(bot, event, self.state, permission = await updater(bot, event, self.state, self.permission)
self.permission)
else: else:
permission = USER(event.get_session_id(), perm=self.permission) permission = USER(event.get_session_id(), perm=self.permission)
@ -696,8 +715,8 @@ class Matcher(metaclass=MatcherMeta):
default_state=self.state, default_state=self.state,
default_parser=self.__class__._default_parser, default_parser=self.__class__._default_parser,
default_type_updater=self.__class__._default_type_updater, default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__. default_permission_updater=self.__class__._default_permission_updater,
_default_permission_updater) )
except FinishedException: except FinishedException:
pass pass
except StopPropagation: except StopPropagation:

View File

@ -17,9 +17,14 @@ from nonebot.handler import Handler
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.matcher import Matcher, matchers from nonebot.matcher import Matcher, matchers
from nonebot.exception import NoLogException, StopPropagation, IgnoredException from nonebot.exception import NoLogException, StopPropagation, IgnoredException
from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor, from nonebot.typing import (
T_RunPostProcessor, T_EventPreProcessor, T_State,
T_EventPostProcessor) T_DependencyCache,
T_RunPreProcessor,
T_RunPostProcessor,
T_EventPreProcessor,
T_EventPostProcessor,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
@ -30,15 +35,25 @@ _run_preprocessors: Set[Handler] = set()
_run_postprocessors: Set[Handler] = set() _run_postprocessors: Set[Handler] = set()
EVENT_PCS_PARAMS = [ EVENT_PCS_PARAMS = [
params.BotParam, params.EventParam, params.StateParam, params.DefaultParam params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
] ]
RUN_PREPCS_PARAMS = [ RUN_PREPCS_PARAMS = [
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam, params.MatcherParam,
params.DefaultParam params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
] ]
RUN_POSTPCS_PARAMS = [ RUN_POSTPCS_PARAMS = [
params.MatcherParam, params.ExceptionParam, params.BotParam, params.MatcherParam,
params.EventParam, params.StateParam, params.DefaultParam params.ExceptionParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
] ]
@ -89,7 +104,8 @@ async def _check_matcher(
event: "Event", event: "Event",
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None: dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
if Matcher.expire_time and datetime.now() > Matcher.expire_time: if Matcher.expire_time and datetime.now() > Matcher.expire_time:
try: try:
matchers[priority].remove(Matcher) matchers[priority].remove(Matcher)
@ -99,13 +115,13 @@ async def _check_matcher(
try: try:
if not await Matcher.check_perm( if not await Matcher.check_perm(
bot, event, stack, bot, event, stack, dependency_cache
dependency_cache) or not await Matcher.check_rule( ) or not await Matcher.check_rule(bot, event, state, stack, dependency_cache):
bot, event, state, stack, dependency_cache):
return return
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>") f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>"
)
return return
if Matcher.temp: if Matcher.temp:
@ -123,31 +139,38 @@ async def _run_matcher(
event: "Event", event: "Event",
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None: dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
logger.info(f"Event will be handled by {Matcher}") logger.info(f"Event will be handled by {Matcher}")
matcher = Matcher() matcher = Matcher()
coros = list( coros = list(
map( map(
lambda x: x(matcher=matcher, lambda x: x(
matcher=matcher,
bot=bot, bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache), _dependency_cache=dependency_cache,
_run_preprocessors)) ),
_run_preprocessors,
)
)
if coros: if coros:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
except IgnoredException: except IgnoredException:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Matcher {matcher} running is <b>cancelled</b>") f"Matcher {matcher} running is <b>cancelled</b>"
)
return return
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. " "<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>"
)
return return
exception = None exception = None
@ -163,14 +186,18 @@ async def _run_matcher(
coros = list( coros = list(
map( map(
lambda x: x(matcher=matcher, lambda x: x(
matcher=matcher,
exception=exception, exception=exception,
bot=bot, bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache), _dependency_cache=dependency_cache,
_run_postprocessors)) ),
_run_postprocessors,
)
)
if coros: if coros:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
@ -217,12 +244,16 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
coros = list( coros = list(
map( map(
lambda x: x(bot=bot, lambda x: x(
bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache), _dependency_cache=dependency_cache,
_event_preprocessors)) ),
_event_preprocessors,
)
)
if coros: if coros:
try: try:
if show_log: if show_log:
@ -236,7 +267,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. " "<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>") "Event ignored!</bg #f8bbd0></r>"
)
return return
# Trie Match # Trie Match
@ -251,13 +283,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.debug(f"Checking for matchers in priority {priority}...") logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [ pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy(), _check_matcher(
stack, dependency_cache) priority, matcher, bot, event, state.copy(), stack, dependency_cache
)
for matcher in matchers[priority] for matcher in matchers[priority]
] ]
results = await asyncio.gather(*pending_tasks, results = await asyncio.gather(*pending_tasks, return_exceptions=True)
return_exceptions=True)
for result in results: for result in results:
if not isinstance(result, Exception): if not isinstance(result, Exception):
@ -272,12 +304,16 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
coros = list( coros = list(
map( map(
lambda x: x(bot=bot, lambda x: x(
bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache), _dependency_cache=dependency_cache,
_event_postprocessors)) ),
_event_postprocessors,
)
)
if coros: if coros:
try: try:
if show_log: if show_log:

View File

@ -10,69 +10,61 @@ from nonebot.utils import generic_check_issubclass
class BotParam(Param): class BotParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass( return generic_check_issubclass(param.annotation, Bot) or (
param.annotation, Bot) or (param.annotation == param.empty and param.annotation == param.empty and name == "bot"
name == "bot") )
def _solve(self, bot: Bot, **kwargs: Any) -> Any: def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot return bot
class EventParam(Param): class EventParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass( return generic_check_issubclass(param.annotation, Event) or (
param.annotation, Event) or (param.annotation == param.empty and param.annotation == param.empty and name == "event"
name == "event") )
def _solve(self, event: Event, **kwargs: Any) -> Any: def _solve(self, event: Event, **kwargs: Any) -> Any:
return event return event
class StateParam(Param): class StateParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass( return generic_check_issubclass(param.annotation, Dict) or (
param.annotation, Dict) or (param.annotation == param.empty and param.annotation == param.empty and name == "state"
name == "state") )
def _solve(self, state: T_State, **kwargs: Any) -> Any: def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state return state
class MatcherParam(Param): class MatcherParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass( return generic_check_issubclass(param.annotation, Matcher) or (
param.annotation, Matcher) or (param.annotation == param.empty and param.annotation == param.empty and name == "matcher"
name == "matcher") )
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any: def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
return matcher return matcher
class ExceptionParam(Param): class ExceptionParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass( return generic_check_issubclass(param.annotation, Exception) or (
param.annotation, Exception) or (param.annotation == param.empty and param.annotation == param.empty and name == "exception"
name == "exception") )
def _solve(self, def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
exception: Optional[Exception] = None,
**kwargs: Any) -> Any:
return exception return exception
class DefaultParam(Param): class DefaultParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return param.default != param.empty return param.default != param.empty

View File

@ -34,11 +34,10 @@ class Permission:
from nonebot.utils import run_sync from nonebot.utils import run_sync
Permission(async_function, run_sync(sync_function)) Permission(async_function, run_sync(sync_function))
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [params.BotParam, params.EventParam, params.DefaultParam]
params.BotParam, params.EventParam, params.DefaultParam
]
def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None: def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:
""" """
@ -48,9 +47,11 @@ class Permission:
""" """
self.checkers = set( self.checkers = set(
checker if isinstance(checker, Handler) else Handler( checker
checker, allow_types=self.HANDLER_PARAM_TYPES) if isinstance(checker, Handler)
for checker in checkers) else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers
)
""" """
:说明: :说明:
@ -66,8 +67,8 @@ class Permission:
bot: Bot, bot: Bot,
event: Event, event: Event,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
Any]] = None) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -87,19 +88,24 @@ class Permission:
if not self.checkers: if not self.checkers:
return True return True
results = await asyncio.gather( results = await asyncio.gather(
*(checker(bot=bot, *(
checker(
bot=bot,
event=event, event=event,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache) _dependency_cache=dependency_cache,
for checker in self.checkers)) )
for checker in self.checkers
)
)
return any(results) return any(results)
def __and__(self, other) -> NoReturn: def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.") raise RuntimeError("And operation between Permissions is not allowed.")
def __or__( def __or__(
self, other: Optional[Union["Permission", self, other: Optional[Union["Permission", T_PermissionChecker]]
T_PermissionChecker]]) -> "Permission": ) -> "Permission":
if other is None: if other is None:
return self return self
elif isinstance(other, Permission): elif isinstance(other, Permission):
@ -155,15 +161,17 @@ def USER(*user: str, perm: Optional[Permission] = None):
""" """
async def _user(bot: Bot, event: Event) -> bool: async def _user(bot: Bot, event: Event) -> bool:
return bool(event.get_session_id() in user and return bool(
(perm is None or await perm(bot, event))) event.get_session_id() in user and (perm is None or await perm(bot, event))
)
return Permission(_user) return Permission(_user)
async def _superuser(bot: Bot, event: Event) -> bool: async def _superuser(bot: Bot, event: Event) -> bool:
return (event.get_type() == "message" and return (
event.get_user_id() in bot.config.superusers) event.get_type() == "message" and event.get_user_id() in bot.config.superusers
)
SUPERUSER = Permission(_superuser) SUPERUSER = Permission(_superuser)

View File

@ -9,8 +9,9 @@ from typing import List, Optional
from contextvars import ContextVar from contextvars import ContextVar
_managers: List["PluginManager"] = [] _managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
default=None) "_current_plugin", default=None
)
from .on import on as on from .on import on as on
from .manager import PluginManager from .manager import PluginManager

View File

@ -33,8 +33,7 @@ class Export(dict):
return func return func
def __setitem__(self, key, value): def __setitem__(self, key, value):
super().__setitem__(key, super().__setitem__(key, Export(value) if isinstance(value, dict) else value)
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value): def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value self[name] = Export(value) if isinstance(value, dict) else value

View File

@ -49,8 +49,9 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]:
return manager.load_all_plugins() return manager.load_all_plugins()
def load_all_plugins(module_path: Iterable[str], def load_all_plugins(
plugin_dir: Iterable[str]) -> Set[Plugin]: module_path: Iterable[str], plugin_dir: Iterable[str]
) -> Set[Plugin]:
""" """
:说明: :说明:
@ -90,8 +91,7 @@ def load_from_json(file_path: str, encoding: str = "utf-8") -> Set[Plugin]:
plugins = data.get("plugins") plugins = data.get("plugins")
plugin_dirs = data.get("plugin_dirs") plugin_dirs = data.get("plugin_dirs")
assert isinstance(plugins, list), "plugins must be a list of plugin name" assert isinstance(plugins, list), "plugins must be a list of plugin name"
assert isinstance(plugin_dirs, assert isinstance(plugin_dirs, list), "plugin_dirs must be a list of directories"
list), "plugin_dirs must be a list of directories"
return load_all_plugins(set(plugins), set(plugin_dirs)) return load_all_plugins(set(plugins), set(plugin_dirs))
@ -120,14 +120,14 @@ def load_from_toml(file_path: str, encoding: str = "utf-8") -> Set[Plugin]:
if nonebot_data: if nonebot_data:
warnings.warn( warnings.warn(
"[nonebot.plugins] table are now deprecated. Use [tool.nonebot] instead.", "[nonebot.plugins] table are now deprecated. Use [tool.nonebot] instead.",
DeprecationWarning) DeprecationWarning,
)
else: else:
raise ValueError("Cannot find '[tool.nonebot]' in given toml file!") raise ValueError("Cannot find '[tool.nonebot]' in given toml file!")
plugins = nonebot_data.get("plugins", []) plugins = nonebot_data.get("plugins", [])
plugin_dirs = nonebot_data.get("plugin_dirs", []) plugin_dirs = nonebot_data.get("plugin_dirs", [])
assert isinstance(plugins, list), "plugins must be a list of plugin name" assert isinstance(plugins, list), "plugins must be a list of plugin name"
assert isinstance(plugin_dirs, assert isinstance(plugin_dirs, list), "plugin_dirs must be a list of directories"
list), "plugin_dirs must be a list of directories"
return load_all_plugins(plugins, plugin_dirs) return load_all_plugins(plugins, plugin_dirs)
@ -163,5 +163,5 @@ def require(name: str) -> Export:
""" """
plugin = get_plugin(name) or load_plugin(name) plugin = get_plugin(name) or load_plugin(name)
if not plugin: if not plugin:
raise RuntimeError(f"Cannot load plugin \"{name}\"!") raise RuntimeError(f'Cannot load plugin "{name}"!')
return plugin.export return plugin.export

View File

@ -15,7 +15,6 @@ from . import _managers, _current_plugin
class PluginManager: class PluginManager:
def __init__( def __init__(
self, self,
plugins: Optional[Iterable[str]] = None, plugins: Optional[Iterable[str]] = None,
@ -39,14 +38,15 @@ class PluginManager:
def _previous_plugins(self) -> List[str]: def _previous_plugins(self) -> List[str]:
_pre_managers: List[PluginManager] _pre_managers: List[PluginManager]
if self in _managers: if self in _managers:
_pre_managers = _managers[:_managers.index(self)] _pre_managers = _managers[: _managers.index(self)]
else: else:
_pre_managers = _managers[:] _pre_managers = _managers[:]
return [ return [
*chain.from_iterable( *chain.from_iterable(
[*manager.plugins, *manager.searched_plugins.keys()] [*manager.plugins, *manager.searched_plugins.keys()]
for manager in _pre_managers) for manager in _pre_managers
)
] ]
def list_plugins(self) -> Set[str]: def list_plugins(self) -> Set[str]:
@ -57,13 +57,14 @@ class PluginManager:
for module_info in pkgutil.iter_modules(self.search_path): for module_info in pkgutil.iter_modules(self.search_path):
if module_info.name.startswith("_"): if module_info.name.startswith("_"):
continue continue
if module_info.name in searched_plugins.keys( if (
) or module_info.name in previous_plugins: module_info.name in searched_plugins.keys()
or module_info.name in previous_plugins
):
raise RuntimeError( raise RuntimeError(
f"Plugin already exists: {module_info.name}! Check your plugin name" f"Plugin already exists: {module_info.name}! Check your plugin name"
) )
module_spec = module_info.module_finder.find_spec( module_spec = module_info.module_finder.find_spec(module_info.name, None)
module_info.name, None)
if not module_spec: if not module_spec:
continue continue
module_path = module_spec.origin module_path = module_spec.origin
@ -80,14 +81,15 @@ class PluginManager:
if name in self.plugins: if name in self.plugins:
module = importlib.import_module(name) module = importlib.import_module(name)
elif name not in self.searched_plugins: elif name not in self.searched_plugins:
raise RuntimeError( raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
f"Plugin not found: {name}! Check your plugin name")
else: else:
module = importlib.import_module( module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name])) self._path_to_module_name(self.searched_plugins[name])
)
logger.opt(colors=True).success( logger.opt(colors=True).success(
f'Succeeded to import "<y>{escape_tag(name)}</y>"') f'Succeeded to import "<y>{escape_tag(name)}</y>"'
)
return getattr(module, "__plugin__", None) return getattr(module, "__plugin__", None)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
@ -96,16 +98,17 @@ class PluginManager:
def load_all_plugins(self) -> Set[Plugin]: def load_all_plugins(self) -> Set[Plugin]:
return set( return set(
filter(None, filter(None, (self.load_plugin(name) for name in self.list_plugins()))
(self.load_plugin(name) for name in self.list_plugins()))) )
class PluginFinder(MetaPathFinder): class PluginFinder(MetaPathFinder):
def find_spec(
def find_spec(self, self,
fullname: str, fullname: str,
path: Optional[Sequence[Union[bytes, str]]], path: Optional[Sequence[Union[bytes, str]]],
target: Optional[ModuleType] = None): target: Optional[ModuleType] = None,
):
if _managers: if _managers:
index = -1 index = -1
module_spec = PathFinder.find_spec(fullname, path, target) module_spec = PathFinder.find_spec(fullname, path, target)
@ -119,10 +122,11 @@ class PluginFinder(MetaPathFinder):
while -index <= len(_managers): while -index <= len(_managers):
manager = _managers[index] manager = _managers[index]
if fullname in manager.plugins or module_path in manager.searched_plugins.values( if (
fullname in manager.plugins
or module_path in manager.searched_plugins.values()
): ):
module_spec.loader = PluginLoader(manager, fullname, module_spec.loader = PluginLoader(manager, fullname, module_origin)
module_origin)
return module_spec return module_spec
index -= 1 index -= 1
@ -130,7 +134,6 @@ class PluginFinder(MetaPathFinder):
class PluginLoader(SourceFileLoader): class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None: def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager self.manager = manager
self.loaded = False self.loaded = False

View File

@ -10,8 +10,18 @@ from nonebot.matcher import Matcher
from .manager import _current_plugin from .manager import _current_plugin
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex, from nonebot.rule import (
command, keyword, endswith, startswith, shell_command) PREFIX_KEY,
RAW_CMD_KEY,
Rule,
ArgumentParser,
regex,
command,
keyword,
endswith,
startswith,
shell_command,
)
def _store_matcher(matcher: Type[Matcher]) -> None: def _store_matcher(matcher: Type[Matcher]) -> None:
@ -30,7 +40,8 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
return sys.modules.get(module_name) return sys.modules.get(module_name)
def on(type: str = "", def on(
type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None, permission: Optional[Permission] = None,
*, *,
@ -40,7 +51,8 @@ def on(type: str = "",
block: bool = False, block: bool = False,
state: Optional[T_State] = None, state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None, state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]: _depth: int = 0,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -62,7 +74,8 @@ def on(type: str = "",
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
matcher = Matcher.new(type, matcher = Matcher.new(
type,
Rule() & rule, Rule() & rule,
permission or Permission(), permission or Permission(),
temp=temp, temp=temp,
@ -72,12 +85,14 @@ def on(type: str = "",
plugin=_current_plugin.get(), plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory,
)
_store_matcher(matcher) _store_matcher(matcher)
return matcher return matcher
def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None, def on_metaevent(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = None, handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False, temp: bool = False,
@ -85,7 +100,8 @@ def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None,
block: bool = False, block: bool = False,
state: Optional[T_State] = None, state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None, state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]: _depth: int = 0,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -105,7 +121,8 @@ def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
matcher = Matcher.new("meta_event", matcher = Matcher.new(
"meta_event",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
@ -115,12 +132,14 @@ def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None,
plugin=_current_plugin.get(), plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory,
)
_store_matcher(matcher) _store_matcher(matcher)
return matcher return matcher
def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None, def on_message(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None, permission: Optional[Permission] = None,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = None, handlers: Optional[List[Union[T_Handler, Handler]]] = None,
@ -129,7 +148,8 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
block: bool = True, block: bool = True,
state: Optional[T_State] = None, state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None, state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]: _depth: int = 0,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -150,7 +170,8 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
matcher = Matcher.new("message", matcher = Matcher.new(
"message",
Rule() & rule, Rule() & rule,
permission or Permission(), permission or Permission(),
temp=temp, temp=temp,
@ -160,12 +181,14 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
plugin=_current_plugin.get(), plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory,
)
_store_matcher(matcher) _store_matcher(matcher)
return matcher return matcher
def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None, def on_notice(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = None, handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False, temp: bool = False,
@ -173,7 +196,8 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
block: bool = False, block: bool = False,
state: Optional[T_State] = None, state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None, state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]: _depth: int = 0,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -193,7 +217,8 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
matcher = Matcher.new("notice", matcher = Matcher.new(
"notice",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
@ -203,12 +228,14 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
plugin=_current_plugin.get(), plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory,
)
_store_matcher(matcher) _store_matcher(matcher)
return matcher return matcher
def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None, def on_request(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = None, handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False, temp: bool = False,
@ -216,7 +243,8 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
block: bool = False, block: bool = False,
state: Optional[T_State] = None, state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None, state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]: _depth: int = 0,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -236,7 +264,8 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
matcher = Matcher.new("request", matcher = Matcher.new(
"request",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
@ -246,16 +275,19 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
plugin=_current_plugin.get(), plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory,
)
_store_matcher(matcher) _store_matcher(matcher)
return matcher return matcher
def on_startswith(msg: Union[str, Tuple[str, ...]], def on_startswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None, rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False, ignorecase: bool = False,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -278,16 +310,16 @@ def on_startswith(msg: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
return on_message(startswith(msg, ignorecase) & rule, return on_message(startswith(msg, ignorecase) & rule, **kwargs, _depth=_depth + 1)
**kwargs,
_depth=_depth + 1)
def on_endswith(msg: Union[str, Tuple[str, ...]], def on_endswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None, rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False, ignorecase: bool = False,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -310,15 +342,15 @@ def on_endswith(msg: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
return on_message(endswith(msg, ignorecase) & rule, return on_message(endswith(msg, ignorecase) & rule, **kwargs, _depth=_depth + 1)
**kwargs,
_depth=_depth + 1)
def on_keyword(keywords: Set[str], def on_keyword(
keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -343,11 +375,13 @@ def on_keyword(keywords: Set[str],
return on_message(keyword(*keywords) & rule, **kwargs, _depth=_depth + 1) return on_message(keyword(*keywords) & rule, **kwargs, _depth=_depth + 1)
def on_command(cmd: Union[str, Tuple[str, ...]], def on_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -382,7 +416,8 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]): if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return return
new_message = message.__class__( new_message = message.__class__(
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip()) segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
)
for new_segment in reversed(new_message): for new_segment in reversed(new_message):
message.insert(0, new_segment) message.insert(0, new_segment)
@ -390,18 +425,19 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
handlers.insert(0, _strip_cmd) handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set()) commands = set([cmd]) | (aliases or set())
return on_message(command(*commands) & rule, return on_message(
handlers=handlers, command(*commands) & rule, handlers=handlers, **kwargs, _depth=_depth + 1
**kwargs, )
_depth=_depth + 1)
def on_shell_command(cmd: Union[str, Tuple[str, ...]], def on_shell_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
parser: Optional[ArgumentParser] = None, parser: Optional[ArgumentParser] = None,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -434,7 +470,8 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
message = event.get_message() message = event.get_message()
segment = message.pop(0) segment = message.pop(0)
new_message = message.__class__( new_message = message.__class__(
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]):].strip()) str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].strip()
)
for new_segment in reversed(new_message): for new_segment in reversed(new_message):
message.insert(0, new_segment) message.insert(0, new_segment)
@ -442,17 +479,21 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
handlers.insert(0, _strip_cmd) handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set()) commands = set([cmd]) | (aliases or set())
return on_message(shell_command(*commands, parser=parser) & rule, return on_message(
shell_command(*commands, parser=parser) & rule,
handlers=handlers, handlers=handlers,
**kwargs, **kwargs,
_depth=_depth + 1) _depth=_depth + 1,
)
def on_regex(pattern: str, def on_regex(
pattern: str,
flags: Union[int, re.RegexFlag] = 0, flags: Union[int, re.RegexFlag] = 0,
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0, _depth: int = 0,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -503,8 +544,7 @@ class CommandGroup:
- **说明**: 其他传递给 ``on_command`` 的参数默认值 - **说明**: 其他传递给 ``on_command`` 的参数默认值
""" """
def command(self, cmd: Union[str, Tuple[str, ...]], def command(self, cmd: Union[str, Tuple[str, ...]], **kwargs) -> Type[Matcher]:
**kwargs) -> Type[Matcher]:
""" """
:说明: :说明:
@ -526,8 +566,9 @@ class CommandGroup:
final_kwargs.update(kwargs) final_kwargs.update(kwargs)
return on_command(cmd, **final_kwargs, _depth=1) return on_command(cmd, **final_kwargs, _depth=1)
def shell_command(self, cmd: Union[str, Tuple[str, ...]], def shell_command(
**kwargs) -> Type[Matcher]: self, cmd: Union[str, Tuple[str, ...]], **kwargs
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -708,8 +749,9 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_startswith(self, msg: Union[str, Tuple[str, ...]], def on_startswith(
**kwargs) -> Type[Matcher]: self, msg: Union[str, Tuple[str, ...]], **kwargs
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -739,8 +781,7 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_endswith(self, msg: Union[str, Tuple[str, ...]], def on_endswith(self, msg: Union[str, Tuple[str, ...]], **kwargs) -> Type[Matcher]:
**kwargs) -> Type[Matcher]:
""" """
:说明: :说明:
@ -799,10 +840,12 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_command(self, def on_command(
self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -834,12 +877,13 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_shell_command(self, def on_shell_command(
self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
...]]]] = None,
parser: Optional[ArgumentParser] = None, parser: Optional[ArgumentParser] = None,
**kwargs) -> Type[Matcher]: **kwargs,
) -> Type[Matcher]:
""" """
:说明: :说明:
@ -870,18 +914,15 @@ class MatcherGroup:
final_kwargs = self.base_kwargs.copy() final_kwargs = self.base_kwargs.copy()
final_kwargs.update(kwargs) final_kwargs.update(kwargs)
final_kwargs.pop("type", None) final_kwargs.pop("type", None)
matcher = on_shell_command(cmd, matcher = on_shell_command(
aliases=aliases, cmd, aliases=aliases, parser=parser, **final_kwargs, _depth=1
parser=parser, )
**final_kwargs,
_depth=1)
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_regex(self, def on_regex(
pattern: str, self, pattern: str, flags: Union[int, re.RegexFlag] = 0, **kwargs
flags: Union[int, re.RegexFlag] = 0, ) -> Type[Matcher]:
**kwargs) -> Type[Matcher]:
""" """
:说明: :说明:

View File

@ -7,8 +7,8 @@ from nonebot.permission import Permission
from nonebot.rule import Rule, ArgumentParser from nonebot.rule import Rule, ArgumentParser
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
def on(
def on(type: str = "", type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ..., permission: Optional[Permission] = ...,
*, *,
@ -17,10 +17,8 @@ def on(type: str = "",
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_metaevent( def on_metaevent(
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*, *,
@ -29,11 +27,10 @@ def on_metaevent(
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_message(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ..., permission: Optional[Permission] = ...,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = ..., handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
@ -41,32 +38,28 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_notice(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = ..., handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ..., temp: bool = ...,
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_request(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*, *,
handlers: Optional[List[Union[T_Handler, Handler]]] = ..., handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ..., temp: bool = ...,
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_startswith( def on_startswith(
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
@ -78,11 +71,10 @@ def on_startswith(
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_endswith(
msg: Union[str, Tuple[str, ...]],
def on_endswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ..., ignorecase: bool = ...,
*, *,
@ -92,11 +84,10 @@ def on_endswith(msg: Union[str, Tuple[str, ...]],
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_keyword(
keywords: Set[str],
def on_keyword(keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*, *,
permission: Optional[Permission] = ..., permission: Optional[Permission] = ...,
@ -105,11 +96,10 @@ def on_keyword(keywords: Set[str],
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_command(
cmd: Union[str, Tuple[str, ...]],
def on_command(cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
*, *,
@ -119,10 +109,8 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_shell_command( def on_shell_command(
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
@ -135,11 +123,10 @@ def on_shell_command(
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_regex(
pattern: str,
def on_regex(pattern: str,
flags: Union[int, re.RegexFlag] = ..., flags: Union[int, re.RegexFlag] = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*, *,
@ -149,13 +136,12 @@ def on_regex(pattern: str,
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
class CommandGroup: class CommandGroup:
def __init__(
def __init__(self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
*, *,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
@ -165,10 +151,10 @@ class CommandGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...): state_factory: Optional[T_StateFactory] = ...,
... ): ...
def command(
def command(self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
*, *,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]], aliases: Optional[Set[Union[str, Tuple[str, ...]]]],
@ -179,9 +165,8 @@ class CommandGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def shell_command( def shell_command(
self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
@ -195,13 +180,12 @@ class CommandGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
class MatcherGroup: class MatcherGroup:
def __init__(
def __init__(self, self,
*, *,
type: str = ..., type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
@ -211,10 +195,10 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...): state_factory: Optional[T_StateFactory] = ...,
... ): ...
def on(
def on(self, self,
*, *,
type: str = ..., type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
@ -224,9 +208,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_metaevent( def on_metaevent(
self, self,
*, *,
@ -236,9 +219,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_message( def on_message(
self, self,
*, *,
@ -249,9 +231,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_notice( def on_notice(
self, self,
*, *,
@ -261,9 +242,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_request( def on_request(
self, self,
*, *,
@ -273,9 +253,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_startswith( def on_startswith(
self, self,
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
@ -288,9 +267,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_endswith( def on_endswith(
self, self,
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
@ -303,9 +281,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_keyword( def on_keyword(
self, self,
keywords: Set[str], keywords: Set[str],
@ -317,9 +294,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_command( def on_command(
self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
@ -332,9 +308,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_shell_command( def on_shell_command(
self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
@ -348,9 +323,8 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...
def on_regex( def on_regex(
self, self,
pattern: str, pattern: str,
@ -363,5 +337,5 @@ class MatcherGroup:
priority: int = ..., priority: int = ...,
block: bool = ..., block: bool = ...,
state: Optional[T_State] = ..., state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]: state_factory: Optional[T_StateFactory] = ...,
... ) -> Type[Matcher]: ...

View File

@ -15,6 +15,7 @@ plugins: Dict[str, "Plugin"] = {}
@dataclass(eq=False) @dataclass(eq=False)
class Plugin(object): class Plugin(object):
"""存储插件信息""" """存储插件信息"""
name: str name: str
""" """
- **类型**: ``str`` - **类型**: ``str``

View File

@ -3,15 +3,18 @@ from functools import reduce
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.plugin import on_command from nonebot.plugin import on_command
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.adapters.cqhttp import (Message, MessageEvent, MessageSegment, from nonebot.adapters.cqhttp import (
unescape) Message,
MessageEvent,
MessageSegment,
unescape,
)
say = on_command("say", to_me(), permission=SUPERUSER) say = on_command("say", to_me(), permission=SUPERUSER)
@say.handle() @say.handle()
async def say_unescape(event: MessageEvent): async def say_unescape(event: MessageEvent):
def _unescape(message: Message, segment: MessageSegment): def _unescape(message: Message, segment: MessageSegment):
if segment.is_text(): if segment.is_text():
return message.append(unescape(str(segment))) return message.append(unescape(str(segment)))

View File

@ -1,8 +1,11 @@
from typing import Dict from typing import Dict
from nonebot.adapters import Event from nonebot.adapters import Event
from nonebot.message import (IgnoredException, run_preprocessor, from nonebot.message import (
run_postprocessor) IgnoredException,
run_preprocessor,
run_postprocessor,
)
_running_matcher: Dict[str, int] = {} _running_matcher: Dict[str, int] = {}

View File

@ -17,8 +17,18 @@ from argparse import Namespace
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing_extensions import TypedDict from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser from argparse import ArgumentParser as ArgParser
from typing import (Any, Dict, List, Type, Tuple, Union, Callable, NoReturn, from typing import (
Optional, Sequence) Any,
Dict,
List,
Type,
Tuple,
Union,
Callable,
NoReturn,
Optional,
Sequence,
)
from pygtrie import CharTrie from pygtrie import CharTrie
@ -33,10 +43,9 @@ PREFIX_KEY = "_prefix"
SUFFIX_KEY = "_suffix" SUFFIX_KEY = "_suffix"
CMD_KEY = "command" CMD_KEY = "command"
RAW_CMD_KEY = "raw_command" RAW_CMD_KEY = "raw_command"
CMD_RESULT = TypedDict("CMD_RESULT", { CMD_RESULT = TypedDict(
"command": Optional[Tuple[str, ...]], "CMD_RESULT", {"command": Optional[Tuple[str, ...]], "raw_command": Optional[str]}
"raw_command": Optional[str] )
})
SHELL_ARGS = "_args" SHELL_ARGS = "_args"
SHELL_ARGV = "_argv" SHELL_ARGV = "_argv"
@ -61,11 +70,14 @@ class Rule:
from nonebot.utils import run_sync from nonebot.utils import run_sync
Rule(async_function, run_sync(sync_function)) Rule(async_function, run_sync(sync_function))
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam, params.BotParam,
params.DefaultParam params.EventParam,
params.StateParam,
params.DefaultParam,
] ]
def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None:
@ -76,9 +88,11 @@ class Rule:
""" """
self.checkers = set( self.checkers = set(
checker if isinstance(checker, Handler) else Handler( checker
checker, allow_types=self.HANDLER_PARAM_TYPES) if isinstance(checker, Handler)
for checker in checkers) else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers
)
""" """
:说明: :说明:
@ -95,8 +109,8 @@ class Rule:
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
Any]] = None) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -117,12 +131,17 @@ class Rule:
if not self.checkers: if not self.checkers:
return True return True
results = await asyncio.gather( results = await asyncio.gather(
*(checker(bot=bot, *(
checker(
bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache) _dependency_cache=dependency_cache,
for checker in self.checkers)) )
for checker in self.checkers
)
)
return all(results) return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
@ -156,8 +175,9 @@ class TrieRule:
cls.suffix[suffix[::-1]] = value cls.suffix[suffix[::-1]] = value
@classmethod @classmethod
def get_value(cls, bot: Bot, event: Event, def get_value(
state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]: cls, bot: Bot, event: Event, state: T_State
) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None) prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None) suffix = CMD_RESULT(command=None, raw_command=None)
state[PREFIX_KEY] = prefix state[PREFIX_KEY] = prefix
@ -180,8 +200,7 @@ class TrieRule:
return prefix, suffix return prefix, suffix
def startswith(msg: Union[str, Tuple[str, ...]], def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
ignorecase: bool = False) -> Rule:
""" """
:说明: :说明:
@ -196,7 +215,8 @@ def startswith(msg: Union[str, Tuple[str, ...]],
pattern = re.compile( pattern = re.compile(
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})", f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0) re.IGNORECASE if ignorecase else 0,
)
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool: async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
@ -207,8 +227,7 @@ def startswith(msg: Union[str, Tuple[str, ...]],
return Rule(_startswith) return Rule(_startswith)
def endswith(msg: Union[str, Tuple[str, ...]], def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
ignorecase: bool = False) -> Rule:
""" """
:说明: :说明:
@ -223,7 +242,8 @@ def endswith(msg: Union[str, Tuple[str, ...]],
pattern = re.compile( pattern = re.compile(
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$", f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0) re.IGNORECASE if ignorecase else 0,
)
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool: async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
@ -314,19 +334,22 @@ class ArgumentParser(ArgParser):
setattr(self, "message", old_message) setattr(self, "message", old_message)
def exit(self, status: int = 0, message: Optional[str] = None): def exit(self, status: int = 0, message: Optional[str] = None):
raise ParserExit(status=status, raise ParserExit(
message=message or getattr(self, "message", None)) status=status, message=message or getattr(self, "message", None)
)
def parse_args(self, def parse_args(
self,
args: Optional[Sequence[str]] = None, args: Optional[Sequence[str]] = None,
namespace: Optional[Namespace] = None) -> Namespace: namespace: Optional[Namespace] = None,
) -> Namespace:
setattr(self, "message", "") setattr(self, "message", "")
return super().parse_args(args=args, return super().parse_args(args=args, namespace=namespace) # type: ignore
namespace=namespace) # type: ignore
def shell_command(*cmds: Union[str, Tuple[str, ...]], def shell_command(
parser: Optional[ArgumentParser] = None) -> Rule: *cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
) -> Rule:
r""" r"""
:说明: :说明:
@ -361,8 +384,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
\:\:\: \:\:\:
""" """
if not isinstance(parser, ArgumentParser): if not isinstance(parser, ArgumentParser):
raise TypeError( raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
"`parser` must be an instance of nonebot.rule.ArgumentParser")
config = get_driver().config config = get_driver().config
command_start = config.command_start command_start = config.command_start
@ -382,8 +404,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
async def _shell_command(event: Event, state: T_State) -> bool: async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands: if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message()) message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY] strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
):].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message) state[SHELL_ARGV] = shlex.split(strip_message)
if parser: if parser:
try: try:

View File

@ -18,8 +18,17 @@
https://docs.python.org/3/library/typing.html https://docs.python.org/3/library/typing.html
""" """
from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable, from typing import (
NoReturn, Optional, Awaitable) TYPE_CHECKING,
Any,
Dict,
Union,
TypeVar,
Callable,
NoReturn,
Optional,
Awaitable,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
@ -29,10 +38,8 @@ T_Wrapped = TypeVar("T_Wrapped", bound=Callable)
def overrides(InterfaceClass: object): def overrides(InterfaceClass: object):
def overrider(func: T_Wrapped) -> T_Wrapped: def overrider(func: T_Wrapped) -> T_Wrapped:
assert func.__name__ in dir( assert func.__name__ in dir(InterfaceClass), f"Error method: {func.__name__}"
InterfaceClass), f"Error method: {func.__name__}"
return func return func
return overrider return overrider
@ -80,7 +87,8 @@ T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]]
``bot.call_api`` 时执行的函数 ``bot.call_api`` 时执行的函数
""" """
T_CalledAPIHook = Callable[ T_CalledAPIHook = Callable[
["Bot", Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]] ["Bot", Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]
]
""" """
:类型: ``Callable[[Bot, Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]`` :类型: ``Callable[[Bot, Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]``
@ -193,8 +201,9 @@ T_DependencyCache = Dict[T_Handler, Any]
依赖缓存, 用于存储依赖函数的返回值 依赖缓存, 用于存储依赖函数的返回值
""" """
T_ArgsParser = Callable[["Bot", "Event", T_State], Union[Awaitable[None], T_ArgsParser = Callable[
Awaitable[NoReturn]]] ["Bot", "Event", T_State], Union[Awaitable[None], Awaitable[NoReturn]]
]
""" """
:类型: ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` :类型: ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
@ -210,8 +219,9 @@ T_TypeUpdater = Callable[["Bot", "Event", T_State, str], Awaitable[str]]
TypeUpdater Matcher.pause, Matcher.reject 时被运行用于更新响应的事件类型默认会更新为 ``message`` TypeUpdater Matcher.pause, Matcher.reject 时被运行用于更新响应的事件类型默认会更新为 ``message``
""" """
T_PermissionUpdater = Callable[["Bot", "Event", T_State, "Permission"], T_PermissionUpdater = Callable[
Awaitable["Permission"]] ["Bot", "Event", T_State, "Permission"], Awaitable["Permission"]
]
""" """
:类型: ``Callable[[Bot, Event, T_State, Permission], Awaitable[Permission]]`` :类型: ``Callable[[Bot, Event, T_State, Permission], Awaitable[Permission]]``

View File

@ -8,8 +8,19 @@ from functools import wraps, partial
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing_extensions import GenericAlias # type: ignore from typing_extensions import GenericAlias # type: ignore
from typing_extensions import ParamSpec, get_args, get_origin from typing_extensions import ParamSpec, get_args, get_origin
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional, from typing import (
Awaitable, AsyncGenerator, ContextManager) Any,
Type,
Deque,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Awaitable,
AsyncGenerator,
ContextManager,
)
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
@ -37,15 +48,16 @@ def escape_tag(s: str) -> str:
def generic_check_issubclass( def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
...]]) -> bool: ) -> bool:
try: try:
return issubclass(cls, class_or_tuple) return issubclass(cls, class_or_tuple)
except TypeError: except TypeError:
if get_origin(cls) is Union: if get_origin(cls) is Union:
for type_ in get_args(cls): for type_ in get_args(cls):
if type_ is not type(None) and not generic_check_issubclass( if type_ is not type(None) and not generic_check_issubclass(
type_, class_or_tuple): type_, class_or_tuple
):
return False return False
return True return True
elif isinstance(cls, GenericAlias): elif isinstance(cls, GenericAlias):
@ -104,7 +116,8 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
@asynccontextmanager @asynccontextmanager
async def run_sync_ctx_manager( async def run_sync_ctx_manager(
cm: ContextManager[T],) -> AsyncGenerator[T, None]: cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
try: try:
yield await run_sync(cm.__enter__)() yield await run_sync(cm.__enter__)()
except Exception as e: except Exception as e:
@ -122,7 +135,6 @@ def get_name(obj: Any) -> str:
class CacheLock: class CacheLock:
def __init__(self): def __init__(self):
self._waiters: Optional[Deque[asyncio.Future]] = None self._waiters: Optional[Deque[asyncio.Future]] = None
self._locked = False self._locked = False
@ -144,8 +156,9 @@ class CacheLock:
return self._locked return self._locked
async def acquire(self): async def acquire(self):
if (not self._locked and (self._waiters is None or if not self._locked and (
all(w.cancelled() for w in self._waiters))): self._waiters is None or all(w.cancelled() for w in self._waiters)
):
self._locked = True self._locked = True
return True return True
@ -223,6 +236,7 @@ def logger_wrapper(logger_name: str):
def log(level: str, message: str, exception: Optional[Exception] = None): def log(level: str, message: str, exception: Optional[Exception] = None):
return logger.opt(colors=True, exception=exception).log( return logger.opt(colors=True, exception=exception).log(
level, f"<m>{escape_tag(logger_name)}</m> | " + message) level, f"<m>{escape_tag(logger_name)}</m> | " + message
)
return log return log

View File

@ -12,8 +12,15 @@ from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.utils import DataclassEncoder, escape_tag from nonebot.utils import DataclassEncoder, escape_tag
from nonebot.drivers import (Driver, WebSocket, HTTPRequest, HTTPResponse, from nonebot.drivers import (
ForwardDriver, HTTPConnection, WebSocketSetup) Driver,
WebSocket,
HTTPRequest,
HTTPResponse,
ForwardDriver,
HTTPConnection,
WebSocketSetup,
)
from .utils import log, escape from .utils import log, escape
from .config import Config as CQHTTPConfig from .config import Config as CQHTTPConfig
@ -49,15 +56,12 @@ async def _check_reply(bot: "Bot", event: "Event"):
return return
try: try:
index = list(map(lambda x: x.type == "reply", index = list(map(lambda x: x.type == "reply", event.message)).index(True)
event.message)).index(True)
except ValueError: except ValueError:
return return
msg_seg = event.message[index] msg_seg = event.message[index]
try: try:
event.reply = Reply.parse_obj(await event.reply = Reply.parse_obj(await bot.get_msg(message_id=msg_seg.data["id"]))
bot.get_msg(message_id=msg_seg.data["id"]
))
except Exception as e: except Exception as e:
log("WARNING", f"Error when getting message reply info: {repr(e)}", e) log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
return return
@ -68,8 +72,7 @@ async def _check_reply(bot: "Bot", event: "Event"):
if len(event.message) > index and event.message[index].type == "at": if len(event.message) > index and event.message[index].type == "at":
del event.message[index] del event.message[index]
if len(event.message) > index and event.message[index].type == "text": if len(event.message) > index and event.message[index].type == "text":
event.message[index].data["text"] = event.message[index].data[ event.message[index].data["text"] = event.message[index].data["text"].lstrip()
"text"].lstrip()
if not event.message[index].data["text"]: if not event.message[index].data["text"]:
del event.message[index] del event.message[index]
if not event.message: if not event.message:
@ -99,23 +102,24 @@ def _check_at_me(bot: "Bot", event: "Event"):
else: else:
def _is_at_me_seg(segment: MessageSegment): def _is_at_me_seg(segment: MessageSegment):
return segment.type == "at" and str(segment.data.get( return segment.type == "at" and str(segment.data.get("qq", "")) == str(
"qq", "")) == str(event.self_id) event.self_id
)
# check the first segment # check the first segment
if _is_at_me_seg(event.message[0]): if _is_at_me_seg(event.message[0]):
event.to_me = True event.to_me = True
event.message.pop(0) event.message.pop(0)
if event.message and event.message[0].type == "text": if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[ event.message[0].data["text"] = event.message[0].data["text"].lstrip()
"text"].lstrip()
if not event.message[0].data["text"]: if not event.message[0].data["text"]:
del event.message[0] del event.message[0]
if event.message and _is_at_me_seg(event.message[0]): if event.message and _is_at_me_seg(event.message[0]):
event.message.pop(0) event.message.pop(0)
if event.message and event.message[0].type == "text": if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[ event.message[0].data["text"] = (
"text"].lstrip() event.message[0].data["text"].lstrip()
)
if not event.message[0].data["text"]: if not event.message[0].data["text"]:
del event.message[0] del event.message[0]
@ -123,9 +127,11 @@ def _check_at_me(bot: "Bot", event: "Event"):
# check the last segment # check the last segment
i = -1 i = -1
last_msg_seg = event.message[i] last_msg_seg = event.message[i]
if last_msg_seg.type == "text" and \ if (
not last_msg_seg.data["text"].strip() and \ last_msg_seg.type == "text"
len(event.message) >= 2: and not last_msg_seg.data["text"].strip()
and len(event.message) >= 2
):
i -= 1 i -= 1
last_msg_seg = event.message[i] last_msg_seg = event.message[i]
@ -161,13 +167,12 @@ def _check_nickname(bot: "Bot", event: "Event"):
if nicknames: if nicknames:
# check if the user is calling me with my nickname # check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames) nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, re.IGNORECASE)
re.IGNORECASE)
if m: if m:
nickname = m.group(1) nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}") log("DEBUG", f"User is calling me {nickname}")
event.to_me = True event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():] first_msg_seg.data["text"] = first_text[m.end() :]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
@ -206,8 +211,9 @@ class ResultStore:
@classmethod @classmethod
def add_result(cls, result: Dict[str, Any]): def add_result(cls, result: Dict[str, Any]):
if isinstance(result.get("echo"), dict) and \ if isinstance(result.get("echo"), dict) and isinstance(
isinstance(result["echo"].get("seq"), int): result["echo"].get("seq"), int
):
future = cls._futures.get(result["echo"]["seq"]) future = cls._futures.get(result["echo"]["seq"])
if future: if future:
future.set_result(result) future.set_result(result)
@ -228,6 +234,7 @@ class Bot(BaseBot):
""" """
CQHTTP 协议 Bot 适配继承属性参考 `BaseBot <./#class-basebot>`_ 。 CQHTTP 协议 Bot 适配继承属性参考 `BaseBot <./#class-basebot>`_ 。
""" """
cqhttp_config: CQHTTPConfig cqhttp_config: CQHTTPConfig
@property @property
@ -249,22 +256,25 @@ class Bot(BaseBot):
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls: elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items(): for self_id, url in cls.cqhttp_config.ws_urls.items():
try: try:
headers = { headers = (
"authorization": {"authorization": f"Bearer {cls.cqhttp_config.access_token}"}
f"Bearer {cls.cqhttp_config.access_token}" if cls.cqhttp_config.access_token
} if cls.cqhttp_config.access_token else {} else {}
)
driver.setup_websocket( driver.setup_websocket(
WebSocketSetup("cqhttp", self_id, url, headers=headers)) WebSocketSetup("cqhttp", self_id, url, headers=headers)
)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {escape_tag(url)} for bot {escape_tag(self_id)} " f"<r><bg #f8bbd0>Bad url {escape_tag(url)} for bot {escape_tag(self_id)} "
"in cqhttp forward websocket</bg #f8bbd0></r>") "in cqhttp forward websocket</bg #f8bbd0></r>"
)
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission( async def check_permission(
cls, driver: Driver, cls, driver: Driver, request: HTTPConnection
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: ) -> Tuple[Optional[str], HTTPResponse]:
""" """
:说明: :说明:
@ -286,22 +296,26 @@ class Bot(BaseBot):
if not x_signature: if not x_signature:
log("WARNING", "Missing Signature Header") log("WARNING", "Missing Signature Header")
return None, HTTPResponse(401, b"Missing Signature") return None, HTTPResponse(401, b"Missing Signature")
sig = hmac.new(secret.encode("utf-8"), request.body, sig = hmac.new(secret.encode("utf-8"), request.body, "sha1").hexdigest()
"sha1").hexdigest()
if x_signature != "sha1=" + sig: if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid") log("WARNING", "Signature Header is invalid")
return None, HTTPResponse(403, b"Signature is invalid") return None, HTTPResponse(403, b"Signature is invalid")
access_token = cqhttp_config.access_token access_token = cqhttp_config.access_token
if access_token and access_token != token and isinstance( if access_token and access_token != token and isinstance(request, WebSocket):
request, WebSocket):
log( log(
"WARNING", "Authorization Header is invalid" "WARNING",
if token else "Missing Authorization Header") "Authorization Header is invalid"
if token
else "Missing Authorization Header",
)
return None, HTTPResponse( return None, HTTPResponse(
403, b"Authorization Header is invalid" 403,
if token else b"Missing Authorization Header") b"Authorization Header is invalid"
return str(x_self_id), HTTPResponse(204, b'') if token
else b"Missing Authorization Header",
)
return str(x_self_id), HTTPResponse(204, b"")
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: bytes): async def handle_message(self, message: bytes):
@ -320,7 +334,7 @@ class Bot(BaseBot):
return return
try: try:
post_type = data['post_type'] post_type = data["post_type"]
detail_type = data.get(f"{post_type}_type") detail_type = data.get(f"{post_type}_type")
detail_type = f".{detail_type}" if detail_type else "" detail_type = f".{detail_type}" if detail_type else ""
sub_type = data.get("sub_type") sub_type = data.get("sub_type")
@ -352,17 +366,13 @@ class Bot(BaseBot):
if isinstance(self.request, WebSocket): if isinstance(self.request, WebSocket):
seq = ResultStore.get_seq() seq = ResultStore.get_seq()
json_data = json.dumps( json_data = json.dumps(
{ {"action": api, "params": data, "echo": {"seq": seq}},
"action": api, cls=DataclassEncoder,
"params": data, )
"echo": {
"seq": seq
}
},
cls=DataclassEncoder)
await self.request.send(json_data) await self.request.send(json_data)
return _handle_api_result(await ResultStore.fetch( return _handle_api_result(
seq, self.config.api_timeout)) await ResultStore.fetch(seq, self.config.api_timeout)
)
elif isinstance(self.request, HTTPRequest): elif isinstance(self.request, HTTPRequest):
api_root = self.config.api_root.get(self.self_id) api_root = self.config.api_root.get(self.self_id)
@ -373,22 +383,25 @@ class Bot(BaseBot):
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
if self.cqhttp_config.access_token is not None: if self.cqhttp_config.access_token is not None:
headers[ headers["Authorization"] = "Bearer " + self.cqhttp_config.access_token
"Authorization"] = "Bearer " + self.cqhttp_config.access_token
try: try:
async with httpx.AsyncClient(headers=headers, async with httpx.AsyncClient(
follow_redirects=True) as client: headers=headers, follow_redirects=True
) as client:
response = await client.post( response = await client.post(
api_root + api, api_root + api,
content=json.dumps(data, cls=DataclassEncoder), content=json.dumps(data, cls=DataclassEncoder),
timeout=self.config.api_timeout) timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
result = response.json() result = response.json()
return _handle_api_result(result) return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected " raise NetworkError(
f"status code: {response.status_code}") f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL: except httpx.InvalidURL:
raise NetworkError("API root url invalid") raise NetworkError("API root url invalid")
except httpx.HTTPError: except httpx.HTTPError:
@ -418,11 +431,13 @@ class Bot(BaseBot):
return await super().call_api(api, **data) return await super().call_api(api, **data)
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(
self,
event: Event, event: Event,
message: Union[str, Message, MessageSegment], message: Union[str, Message, MessageSegment],
at_sender: bool = False, at_sender: bool = False,
**kwargs) -> Any: **kwargs,
) -> Any:
""" """
:说明: :说明:
@ -445,8 +460,9 @@ class Bot(BaseBot):
- ``NetworkError``: 网络错误 - ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败 - ``ActionFailed``: API 调用失败
""" """
message = escape(message, escape_comma=False) if isinstance( message = (
message, str) else message escape(message, escape_comma=False) if isinstance(message, str) else message
)
msg = message if isinstance(message, Message) else Message(message) msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(getattr(event, "user_id", None)) at_sender = at_sender and bool(getattr(event, "user_id", None))

View File

@ -8,7 +8,6 @@ from nonebot.drivers import Driver, WebSocket
from .event import Event from .event import Event
from .message import Message, MessageSegment from .message import Message, MessageSegment
def get_auth_bearer(access_token: Optional[str] = ...) -> Optional[str]: def get_auth_bearer(access_token: Optional[str] = ...) -> Optional[str]:
... ...

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional from typing import Dict, Optional
from pydantic import Field, BaseModel, AnyUrl from pydantic import Field, AnyUrl, BaseModel
# priority: alias > origin # priority: alias > origin
@ -14,11 +14,10 @@ class Config(BaseModel):
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令 - ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID目标 URL 字典 - ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID目标 URL 字典
""" """
access_token: Optional[str] = Field(default=None,
alias="cqhttp_access_token") access_token: Optional[str] = Field(default=None, alias="cqhttp_access_token")
secret: Optional[str] = Field(default=None, alias="cqhttp_secret") secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set, ws_urls: Dict[str, AnyUrl] = Field(default_factory=set, alias="cqhttp_ws_urls")
alias="cqhttp_ws_urls")
class Config: class Config:
extra = "ignore" extra = "ignore"

View File

@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, List, Type, Optional
from pydantic import BaseModel from pydantic import BaseModel
from pygtrie import StringTrie from pygtrie import StringTrie
from .message import Message
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from .exception import NoLogException
from nonebot.adapters import Event as BaseEvent from nonebot.adapters import Event as BaseEvent
from .message import Message
from .exception import NoLogException
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot from .bot import Bot
@ -22,6 +23,7 @@ class Event(BaseEvent):
.. _CQHTTP 文档: .. _CQHTTP 文档:
https://github.com/howmanybots/onebot/blob/master/README.md https://github.com/howmanybots/onebot/blob/master/README.md
""" """
__event__ = "" __event__ = ""
time: int time: int
self_id: int self_id: int
@ -118,6 +120,7 @@ class Status(BaseModel):
# Message Events # Message Events
class MessageEvent(Event): class MessageEvent(Event):
"""消息事件""" """消息事件"""
__event__ = "message" __event__ = "message"
post_type: Literal["message"] post_type: Literal["message"]
sub_type: str sub_type: str
@ -144,8 +147,9 @@ class MessageEvent(Event):
@overrides(Event) @overrides(Event)
def get_event_name(self) -> str: def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None) sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.message_type}" + (f".{sub_type}" return f"{self.post_type}.{self.message_type}" + (
if sub_type else "") f".{sub_type}" if sub_type else ""
)
@overrides(Event) @overrides(Event)
def get_message(self) -> Message: def get_message(self) -> Message:
@ -170,20 +174,29 @@ class MessageEvent(Event):
class PrivateMessageEvent(MessageEvent): class PrivateMessageEvent(MessageEvent):
"""私聊消息""" """私聊消息"""
__event__ = "message.private" __event__ = "message.private"
message_type: Literal["private"] message_type: Literal["private"]
@overrides(Event) @overrides(Event)
def get_event_description(self) -> str: def get_event_description(self) -> str:
return (f'Message {self.message_id} from {self.user_id} "' + "".join( return (
f'Message {self.message_id} from {self.user_id} "'
+ "".join(
map( map(
lambda x: escape_tag(str(x)) lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>", if x.is_text()
self.message)) + '"') else f"<le>{escape_tag(str(x))}</le>",
self.message,
)
)
+ '"'
)
class GroupMessageEvent(MessageEvent): class GroupMessageEvent(MessageEvent):
"""群消息""" """群消息"""
__event__ = "message.group" __event__ = "message.group"
message_type: Literal["group"] message_type: Literal["group"]
group_id: int group_id: int
@ -196,8 +209,13 @@ class GroupMessageEvent(MessageEvent):
+ "".join( + "".join(
map( map(
lambda x: escape_tag(str(x)) lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>", if x.is_text()
self.message)) + '"') else f"<le>{escape_tag(str(x))}</le>",
self.message,
)
)
+ '"'
)
@overrides(MessageEvent) @overrides(MessageEvent)
def get_session_id(self) -> str: def get_session_id(self) -> str:
@ -207,6 +225,7 @@ class GroupMessageEvent(MessageEvent):
# Notice Events # Notice Events
class NoticeEvent(Event): class NoticeEvent(Event):
"""通知事件""" """通知事件"""
__event__ = "notice" __event__ = "notice"
post_type: Literal["notice"] post_type: Literal["notice"]
notice_type: str notice_type: str
@ -214,12 +233,14 @@ class NoticeEvent(Event):
@overrides(Event) @overrides(Event)
def get_event_name(self) -> str: def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None) sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.notice_type}" + (f".{sub_type}" return f"{self.post_type}.{self.notice_type}" + (
if sub_type else "") f".{sub_type}" if sub_type else ""
)
class GroupUploadNoticeEvent(NoticeEvent): class GroupUploadNoticeEvent(NoticeEvent):
"""群文件上传事件""" """群文件上传事件"""
__event__ = "notice.group_upload" __event__ = "notice.group_upload"
notice_type: Literal["group_upload"] notice_type: Literal["group_upload"]
user_id: int user_id: int
@ -237,6 +258,7 @@ class GroupUploadNoticeEvent(NoticeEvent):
class GroupAdminNoticeEvent(NoticeEvent): class GroupAdminNoticeEvent(NoticeEvent):
"""群管理员变动""" """群管理员变动"""
__event__ = "notice.group_admin" __event__ = "notice.group_admin"
notice_type: Literal["group_admin"] notice_type: Literal["group_admin"]
sub_type: str sub_type: str
@ -258,6 +280,7 @@ class GroupAdminNoticeEvent(NoticeEvent):
class GroupDecreaseNoticeEvent(NoticeEvent): class GroupDecreaseNoticeEvent(NoticeEvent):
"""群成员减少事件""" """群成员减少事件"""
__event__ = "notice.group_decrease" __event__ = "notice.group_decrease"
notice_type: Literal["group_decrease"] notice_type: Literal["group_decrease"]
sub_type: str sub_type: str
@ -280,6 +303,7 @@ class GroupDecreaseNoticeEvent(NoticeEvent):
class GroupIncreaseNoticeEvent(NoticeEvent): class GroupIncreaseNoticeEvent(NoticeEvent):
"""群成员增加事件""" """群成员增加事件"""
__event__ = "notice.group_increase" __event__ = "notice.group_increase"
notice_type: Literal["group_increase"] notice_type: Literal["group_increase"]
sub_type: str sub_type: str
@ -302,6 +326,7 @@ class GroupIncreaseNoticeEvent(NoticeEvent):
class GroupBanNoticeEvent(NoticeEvent): class GroupBanNoticeEvent(NoticeEvent):
"""群禁言事件""" """群禁言事件"""
__event__ = "notice.group_ban" __event__ = "notice.group_ban"
notice_type: Literal["group_ban"] notice_type: Literal["group_ban"]
sub_type: str sub_type: str
@ -325,6 +350,7 @@ class GroupBanNoticeEvent(NoticeEvent):
class FriendAddNoticeEvent(NoticeEvent): class FriendAddNoticeEvent(NoticeEvent):
"""好友添加事件""" """好友添加事件"""
__event__ = "notice.friend_add" __event__ = "notice.friend_add"
notice_type: Literal["friend_add"] notice_type: Literal["friend_add"]
user_id: int user_id: int
@ -340,6 +366,7 @@ class FriendAddNoticeEvent(NoticeEvent):
class GroupRecallNoticeEvent(NoticeEvent): class GroupRecallNoticeEvent(NoticeEvent):
"""群消息撤回事件""" """群消息撤回事件"""
__event__ = "notice.group_recall" __event__ = "notice.group_recall"
notice_type: Literal["group_recall"] notice_type: Literal["group_recall"]
user_id: int user_id: int
@ -362,6 +389,7 @@ class GroupRecallNoticeEvent(NoticeEvent):
class FriendRecallNoticeEvent(NoticeEvent): class FriendRecallNoticeEvent(NoticeEvent):
"""好友消息撤回事件""" """好友消息撤回事件"""
__event__ = "notice.friend_recall" __event__ = "notice.friend_recall"
notice_type: Literal["friend_recall"] notice_type: Literal["friend_recall"]
user_id: int user_id: int
@ -378,6 +406,7 @@ class FriendRecallNoticeEvent(NoticeEvent):
class NotifyEvent(NoticeEvent): class NotifyEvent(NoticeEvent):
"""提醒事件""" """提醒事件"""
__event__ = "notice.notify" __event__ = "notice.notify"
notice_type: Literal["notify"] notice_type: Literal["notify"]
sub_type: str sub_type: str
@ -395,6 +424,7 @@ class NotifyEvent(NoticeEvent):
class PokeNotifyEvent(NotifyEvent): class PokeNotifyEvent(NotifyEvent):
"""戳一戳提醒事件""" """戳一戳提醒事件"""
__event__ = "notice.notify.poke" __event__ = "notice.notify.poke"
sub_type: Literal["poke"] sub_type: Literal["poke"]
target_id: int target_id: int
@ -413,6 +443,7 @@ class PokeNotifyEvent(NotifyEvent):
class LuckyKingNotifyEvent(NotifyEvent): class LuckyKingNotifyEvent(NotifyEvent):
"""群红包运气王提醒事件""" """群红包运气王提醒事件"""
__event__ = "notice.notify.lucky_king" __event__ = "notice.notify.lucky_king"
sub_type: Literal["lucky_king"] sub_type: Literal["lucky_king"]
target_id: int target_id: int
@ -432,6 +463,7 @@ class LuckyKingNotifyEvent(NotifyEvent):
class HonorNotifyEvent(NotifyEvent): class HonorNotifyEvent(NotifyEvent):
"""群荣誉变更提醒事件""" """群荣誉变更提醒事件"""
__event__ = "notice.notify.honor" __event__ = "notice.notify.honor"
sub_type: Literal["honor"] sub_type: Literal["honor"]
honor_type: str honor_type: str
@ -444,6 +476,7 @@ class HonorNotifyEvent(NotifyEvent):
# Request Events # Request Events
class RequestEvent(Event): class RequestEvent(Event):
"""请求事件""" """请求事件"""
__event__ = "request" __event__ = "request"
post_type: Literal["request"] post_type: Literal["request"]
request_type: str request_type: str
@ -451,12 +484,14 @@ class RequestEvent(Event):
@overrides(Event) @overrides(Event)
def get_event_name(self) -> str: def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None) sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.request_type}" + (f".{sub_type}" return f"{self.post_type}.{self.request_type}" + (
if sub_type else "") f".{sub_type}" if sub_type else ""
)
class FriendRequestEvent(RequestEvent): class FriendRequestEvent(RequestEvent):
"""加好友请求事件""" """加好友请求事件"""
__event__ = "request.friend" __event__ = "request.friend"
request_type: Literal["friend"] request_type: Literal["friend"]
user_id: int user_id: int
@ -472,9 +507,9 @@ class FriendRequestEvent(RequestEvent):
return str(self.user_id) return str(self.user_id)
async def approve(self, bot: "Bot", remark: str = ""): async def approve(self, bot: "Bot", remark: str = ""):
return await bot.set_friend_add_request(flag=self.flag, return await bot.set_friend_add_request(
approve=True, flag=self.flag, approve=True, remark=remark
remark=remark) )
async def reject(self, bot: "Bot"): async def reject(self, bot: "Bot"):
return await bot.set_friend_add_request(flag=self.flag, approve=False) return await bot.set_friend_add_request(flag=self.flag, approve=False)
@ -482,6 +517,7 @@ class FriendRequestEvent(RequestEvent):
class GroupRequestEvent(RequestEvent): class GroupRequestEvent(RequestEvent):
"""加群请求/邀请事件""" """加群请求/邀请事件"""
__event__ = "request.group" __event__ = "request.group"
request_type: Literal["group"] request_type: Literal["group"]
sub_type: str sub_type: str
@ -499,20 +535,20 @@ class GroupRequestEvent(RequestEvent):
return f"group_{self.group_id}_{self.user_id}" return f"group_{self.group_id}_{self.user_id}"
async def approve(self, bot: "Bot"): async def approve(self, bot: "Bot"):
return await bot.set_group_add_request(flag=self.flag, return await bot.set_group_add_request(
sub_type=self.sub_type, flag=self.flag, sub_type=self.sub_type, approve=True
approve=True) )
async def reject(self, bot: "Bot", reason: str = ""): async def reject(self, bot: "Bot", reason: str = ""):
return await bot.set_group_add_request(flag=self.flag, return await bot.set_group_add_request(
sub_type=self.sub_type, flag=self.flag, sub_type=self.sub_type, approve=False, reason=reason
approve=False, )
reason=reason)
# Meta Events # Meta Events
class MetaEvent(Event): class MetaEvent(Event):
"""元事件""" """元事件"""
__event__ = "meta_event" __event__ = "meta_event"
post_type: Literal["meta_event"] post_type: Literal["meta_event"]
meta_event_type: str meta_event_type: str
@ -520,8 +556,9 @@ class MetaEvent(Event):
@overrides(Event) @overrides(Event)
def get_event_name(self) -> str: def get_event_name(self) -> str:
sub_type = getattr(self, "sub_type", None) sub_type = getattr(self, "sub_type", None)
return f"{self.post_type}.{self.meta_event_type}" + (f".{sub_type}" if return f"{self.post_type}.{self.meta_event_type}" + (
sub_type else "") f".{sub_type}" if sub_type else ""
)
@overrides(Event) @overrides(Event)
def get_log_string(self) -> str: def get_log_string(self) -> str:
@ -530,6 +567,7 @@ class MetaEvent(Event):
class LifecycleMetaEvent(MetaEvent): class LifecycleMetaEvent(MetaEvent):
"""生命周期元事件""" """生命周期元事件"""
__event__ = "meta_event.lifecycle" __event__ = "meta_event.lifecycle"
meta_event_type: Literal["lifecycle"] meta_event_type: Literal["lifecycle"]
sub_type: str sub_type: str
@ -537,6 +575,7 @@ class LifecycleMetaEvent(MetaEvent):
class HeartbeatMetaEvent(MetaEvent): class HeartbeatMetaEvent(MetaEvent):
"""心跳元事件""" """心跳元事件"""
__event__ = "meta_event.heartbeat" __event__ = "meta_event.heartbeat"
meta_event_type: Literal["heartbeat"] meta_event_type: Literal["heartbeat"]
status: Status status: Status
@ -567,12 +606,28 @@ def get_event_model(event_name) -> List[Type[Event]]:
__all__ = [ __all__ = [
"Event", "MessageEvent", "PrivateMessageEvent", "GroupMessageEvent", "Event",
"NoticeEvent", "GroupUploadNoticeEvent", "GroupAdminNoticeEvent", "MessageEvent",
"GroupDecreaseNoticeEvent", "GroupIncreaseNoticeEvent", "PrivateMessageEvent",
"GroupBanNoticeEvent", "FriendAddNoticeEvent", "GroupRecallNoticeEvent", "GroupMessageEvent",
"FriendRecallNoticeEvent", "NotifyEvent", "PokeNotifyEvent", "NoticeEvent",
"LuckyKingNotifyEvent", "HonorNotifyEvent", "RequestEvent", "GroupUploadNoticeEvent",
"FriendRequestEvent", "GroupRequestEvent", "MetaEvent", "GroupAdminNoticeEvent",
"LifecycleMetaEvent", "HeartbeatMetaEvent", "get_event_model" "GroupDecreaseNoticeEvent",
"GroupIncreaseNoticeEvent",
"GroupBanNoticeEvent",
"FriendAddNoticeEvent",
"GroupRecallNoticeEvent",
"FriendRecallNoticeEvent",
"NotifyEvent",
"PokeNotifyEvent",
"LuckyKingNotifyEvent",
"HonorNotifyEvent",
"RequestEvent",
"FriendRequestEvent",
"GroupRequestEvent",
"MetaEvent",
"LifecycleMetaEvent",
"HeartbeatMetaEvent",
"get_event_model",
] ]

View File

@ -8,7 +8,6 @@ from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class CQHTTPAdapterException(AdapterException): class CQHTTPAdapterException(AdapterException):
def __init__(self): def __init__(self):
super().__init__("cqhttp") super().__init__("cqhttp")
@ -33,8 +32,11 @@ class ActionFailed(BaseActionFailed, CQHTTPAdapterException):
self.info = kwargs self.info = kwargs
def __repr__(self): def __repr__(self):
return f"<ActionFailed " + ", ".join( return (
f"{k}={v}" for k, v in self.info.items()) + ">" f"<ActionFailed "
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
+ ">"
)
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()

View File

@ -5,10 +5,11 @@ from base64 import b64encode
from typing import Any, Type, Tuple, Union, Mapping, Iterable, Optional, cast from typing import Any, Type, Tuple, Union, Mapping, Iterable, Optional, cast
from nonebot.typing import overrides from nonebot.typing import overrides
from .utils import log, _b2s, escape, unescape
from nonebot.adapters import Message as BaseMessage from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment from nonebot.adapters import MessageSegment as BaseMessageSegment
from .utils import log, _b2s, escape, unescape
class MessageSegment(BaseMessageSegment["Message"]): class MessageSegment(BaseMessageSegment["Message"]):
""" """
@ -27,23 +28,24 @@ class MessageSegment(BaseMessageSegment["Message"]):
# process special types # process special types
if type_ == "text": if type_ == "text":
return escape( return escape(data.get("text", ""), escape_comma=False) # type: ignore
data.get("text", ""), # type: ignore
escape_comma=False)
params = ",".join( params = ",".join(
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None]) [f"{k}={escape(str(v))}" for k, v in data.items() if v is not None]
)
return f"[CQ:{type_}{',' if params else ''}{params}]" return f"[CQ:{type_}{',' if params else ''}{params}]"
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + (MessageSegment.text(other) if isinstance( return Message(self) + (
other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message": def __radd__(self, other) -> "Message":
return (MessageSegment.text(other) return (
if isinstance(other, str) else Message(other)) + self MessageSegment.text(other) if isinstance(other, str) else Message(other)
) + self
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
@ -83,11 +85,13 @@ class MessageSegment(BaseMessageSegment["Message"]):
return MessageSegment("forward", {"id": id_}) return MessageSegment("forward", {"id": id_})
@staticmethod @staticmethod
def image(file: Union[str, bytes, BytesIO, Path], def image(
file: Union[str, bytes, BytesIO, Path],
type_: Optional[str] = None, type_: Optional[str] = None,
cache: bool = True, cache: bool = True,
proxy: bool = True, proxy: bool = True,
timeout: Optional[int] = None) -> "MessageSegment": timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO): if isinstance(file, BytesIO):
file = file.getvalue() file = file.getvalue()
if isinstance(file, bytes): if isinstance(file, bytes):
@ -95,74 +99,85 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path): elif isinstance(file, Path):
file = f"file:///{file.resolve()}" file = f"file:///{file.resolve()}"
return MessageSegment( return MessageSegment(
"image", { "image",
{
"file": file, "file": file,
"type": type_, "type": type_,
"cache": _b2s(cache), "cache": _b2s(cache),
"proxy": _b2s(proxy), "proxy": _b2s(proxy),
"timeout": timeout "timeout": timeout,
}) },
)
@staticmethod @staticmethod
def json(data: str) -> "MessageSegment": def json(data: str) -> "MessageSegment":
return MessageSegment("json", {"data": data}) return MessageSegment("json", {"data": data})
@staticmethod @staticmethod
def location(latitude: float, def location(
latitude: float,
longitude: float, longitude: float,
title: Optional[str] = None, title: Optional[str] = None,
content: Optional[str] = None) -> "MessageSegment": content: Optional[str] = None,
) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"location", { "location",
{
"lat": str(latitude), "lat": str(latitude),
"lon": str(longitude), "lon": str(longitude),
"title": title, "title": title,
"content": content "content": content,
}) },
)
@staticmethod @staticmethod
def music(type_: str, id_: int) -> "MessageSegment": def music(type_: str, id_: int) -> "MessageSegment":
return MessageSegment("music", {"type": type_, "id": id_}) return MessageSegment("music", {"type": type_, "id": id_})
@staticmethod @staticmethod
def music_custom(url: str, def music_custom(
url: str,
audio: str, audio: str,
title: str, title: str,
content: Optional[str] = None, content: Optional[str] = None,
img_url: Optional[str] = None) -> "MessageSegment": img_url: Optional[str] = None,
) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"music", { "music",
{
"type": "custom", "type": "custom",
"url": url, "url": url,
"audio": audio, "audio": audio,
"title": title, "title": title,
"content": content, "content": content,
"image": img_url "image": img_url,
}) },
)
@staticmethod @staticmethod
def node(id_: int) -> "MessageSegment": def node(id_: int) -> "MessageSegment":
return MessageSegment("node", {"id": str(id_)}) return MessageSegment("node", {"id": str(id_)})
@staticmethod @staticmethod
def node_custom(user_id: int, nickname: str, def node_custom(
content: Union[str, "Message"]) -> "MessageSegment": user_id: int, nickname: str, content: Union[str, "Message"]
return MessageSegment("node", { ) -> "MessageSegment":
"user_id": str(user_id), return MessageSegment(
"nickname": nickname, "node", {"user_id": str(user_id), "nickname": nickname, "content": content}
"content": content )
})
@staticmethod @staticmethod
def poke(type_: str, id_: str) -> "MessageSegment": def poke(type_: str, id_: str) -> "MessageSegment":
return MessageSegment("poke", {"type": type_, "id": id_}) return MessageSegment("poke", {"type": type_, "id": id_})
@staticmethod @staticmethod
def record(file: Union[str, bytes, BytesIO, Path], def record(
file: Union[str, bytes, BytesIO, Path],
magic: Optional[bool] = None, magic: Optional[bool] = None,
cache: Optional[bool] = None, cache: Optional[bool] = None,
proxy: Optional[bool] = None, proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment": timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO): if isinstance(file, BytesIO):
file = file.getvalue() file = file.getvalue()
if isinstance(file, bytes): if isinstance(file, bytes):
@ -170,13 +185,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path): elif isinstance(file, Path):
file = f"file:///{file.resolve()}" file = f"file:///{file.resolve()}"
return MessageSegment( return MessageSegment(
"record", { "record",
{
"file": file, "file": file,
"magic": _b2s(magic), "magic": _b2s(magic),
"cache": _b2s(cache), "cache": _b2s(cache),
"proxy": _b2s(proxy), "proxy": _b2s(proxy),
"timeout": timeout "timeout": timeout,
}) },
)
@staticmethod @staticmethod
def reply(id_: int) -> "MessageSegment": def reply(id_: int) -> "MessageSegment":
@ -191,26 +208,27 @@ class MessageSegment(BaseMessageSegment["Message"]):
return MessageSegment("shake", {}) return MessageSegment("shake", {})
@staticmethod @staticmethod
def share(url: str = "", def share(
url: str = "",
title: str = "", title: str = "",
content: Optional[str] = None, content: Optional[str] = None,
image: Optional[str] = None) -> "MessageSegment": image: Optional[str] = None,
return MessageSegment("share", { ) -> "MessageSegment":
"url": url, return MessageSegment(
"title": title, "share", {"url": url, "title": title, "content": content, "image": image}
"content": content, )
"image": image
})
@staticmethod @staticmethod
def text(text: str) -> "MessageSegment": def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text}) return MessageSegment("text", {"text": text})
@staticmethod @staticmethod
def video(file: Union[str, bytes, BytesIO, Path], def video(
file: Union[str, bytes, BytesIO, Path],
cache: Optional[bool] = None, cache: Optional[bool] = None,
proxy: Optional[bool] = None, proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment": timeout: Optional[int] = None,
) -> "MessageSegment":
if isinstance(file, BytesIO): if isinstance(file, BytesIO):
file = file.getvalue() file = file.getvalue()
if isinstance(file, bytes): if isinstance(file, bytes):
@ -218,12 +236,14 @@ class MessageSegment(BaseMessageSegment["Message"]):
elif isinstance(file, Path): elif isinstance(file, Path):
file = f"file:///{file.resolve()}" file = f"file:///{file.resolve()}"
return MessageSegment( return MessageSegment(
"video", { "video",
{
"file": file, "file": file,
"cache": _b2s(cache), "cache": _b2s(cache),
"proxy": _b2s(proxy), "proxy": _b2s(proxy),
"timeout": timeout "timeout": timeout,
}) },
)
@staticmethod @staticmethod
def xml(data: str) -> "MessageSegment": def xml(data: str) -> "MessageSegment":
@ -241,22 +261,22 @@ class Message(BaseMessage[MessageSegment]):
return MessageSegment return MessageSegment
@overrides(BaseMessage) @overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping, def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__( return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessage) @overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping, def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__( return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct( def _construct(
msg: Union[str, Mapping, msg: Union[str, Mapping, Iterable[Mapping]]
Iterable[Mapping]]) -> Iterable[MessageSegment]: ) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping): if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg) msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {}) yield MessageSegment(msg["type"], msg.get("data") or {})
@ -273,11 +293,12 @@ class Message(BaseMessage[MessageSegment]):
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)" r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
r"(?P<params>" r"(?P<params>"
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*" r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
r"),?\]", msg): r"),?\]",
yield "text", msg[text_begin:cqcode.pos + cqcode.start()] msg,
):
yield "text", msg[text_begin : cqcode.pos + cqcode.start()]
text_begin = cqcode.pos + cqcode.end() text_begin = cqcode.pos + cqcode.end()
yield cqcode.group("type"), cqcode.group("params").lstrip( yield cqcode.group("type"), cqcode.group("params").lstrip(",")
",")
yield "text", msg[text_begin:] yield "text", msg[text_begin:]
for type_, data in _iter_message(msg): for type_, data in _iter_message(msg):
@ -287,10 +308,11 @@ class Message(BaseMessage[MessageSegment]):
yield MessageSegment(type_, {"text": unescape(data)}) yield MessageSegment(type_, {"text": unescape(data)})
else: else:
data = { data = {
k: unescape(v) for k, v in map( k: unescape(v)
for k, v in map(
lambda x: x.split("=", maxsplit=1), lambda x: x.split("=", maxsplit=1),
filter(lambda x: x, ( filter(lambda x: x, (x.lstrip() for x in data.split(","))),
x.lstrip() for x in data.split(",")))) )
} }
yield MessageSegment(type_, data) yield MessageSegment(type_, data)

View File

@ -1,5 +1,6 @@
from nonebot.adapters import Event from nonebot.adapters import Event
from nonebot.permission import Permission from nonebot.permission import Permission
from .event import GroupMessageEvent, PrivateMessageEvent from .event import GroupMessageEvent, PrivateMessageEvent
@ -42,8 +43,7 @@ async def _group(event: Event) -> bool:
async def _group_member(event: Event) -> bool: async def _group_member(event: Event) -> bool:
return isinstance(event, return isinstance(event, GroupMessageEvent) and event.sender.role == "member"
GroupMessageEvent) and event.sender.role == "member"
async def _group_admin(event: Event) -> bool: async def _group_admin(event: Event) -> bool:
@ -76,6 +76,12 @@ GROUP_OWNER = Permission(_group_owner)
""" """
__all__ = [ __all__ = [
"PRIVATE", "PRIVATE_FRIEND", "PRIVATE_GROUP", "PRIVATE_OTHER", "GROUP", "PRIVATE",
"GROUP_MEMBER", "GROUP_ADMIN", "GROUP_OWNER" "PRIVATE_FRIEND",
"PRIVATE_GROUP",
"PRIVATE_OTHER",
"GROUP",
"GROUP_MEMBER",
"GROUP_ADMIN",
"GROUP_OWNER",
] ]

View File

@ -16,9 +16,7 @@ def escape(s: str, *, escape_comma: bool = True) -> str:
* ``s: str``: 需要转义的字符串 * ``s: str``: 需要转义的字符串
* ``escape_comma: bool``: 是否转义逗号``,`` * ``escape_comma: bool``: 是否转义逗号``,``
""" """
s = s.replace("&", "&amp;") \ s = s.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
.replace("[", "&#91;") \
.replace("]", "&#93;")
if escape_comma: if escape_comma:
s = s.replace(",", "&#44;") s = s.replace(",", "&#44;")
return s return s
@ -34,10 +32,12 @@ def unescape(s: str) -> str:
* ``s: str``: 需要转义的字符串 * ``s: str``: 需要转义的字符串
""" """
return s.replace("&#44;", ",") \ return (
.replace("&#91;", "[") \ s.replace("&#44;", ",")
.replace("&#93;", "]") \ .replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&") .replace("&amp;", "&")
)
def _b2s(b: Optional[bool]) -> Optional[str]: def _b2s(b: Optional[bool]) -> Optional[str]:

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/" # url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true # default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -16,10 +16,18 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse, HTTPConnection
from .config import Config as DingConfig from .config import Config as DingConfig
from .utils import log, calc_hmac_base64 from .utils import log, calc_hmac_base64
from .message import Message, MessageSegment from .message import Message, MessageSegment
from .exception import (ActionFailed, NetworkError, SessionExpired, from .exception import (
ApiNotAvailable) ActionFailed,
from .event import (MessageEvent, ConversationType, GroupMessageEvent, NetworkError,
PrivateMessageEvent) SessionExpired,
ApiNotAvailable,
)
from .event import (
MessageEvent,
ConversationType,
GroupMessageEvent,
PrivateMessageEvent,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
@ -31,6 +39,7 @@ class Bot(BaseBot):
""" """
钉钉 协议 Bot 适配继承属性参考 `BaseBot <./#class-basebot>`_ 。 钉钉 协议 Bot 适配继承属性参考 `BaseBot <./#class-basebot>`_ 。
""" """
ding_config: DingConfig ding_config: DingConfig
@property @property
@ -48,8 +57,8 @@ class Bot(BaseBot):
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission( async def check_permission(
cls, driver: Driver, cls, driver: Driver, request: HTTPConnection
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: ) -> Tuple[Optional[str], HTTPResponse]:
""" """
:说明: :说明:
@ -61,7 +70,8 @@ class Bot(BaseBot):
# 检查连接方式 # 检查连接方式
if not isinstance(request, HTTPRequest): if not isinstance(request, HTTPRequest):
return None, HTTPResponse( return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`") 405, b"Unsupported connection type, available type: `http`"
)
# 检查 timestamp # 检查 timestamp
if not timestamp: if not timestamp:
@ -74,13 +84,15 @@ class Bot(BaseBot):
log("WARNING", "Missing Signature Header") log("WARNING", "Missing Signature Header")
return None, HTTPResponse(400, b"Missing `sign` Header") return None, HTTPResponse(400, b"Missing `sign` Header")
sign_base64 = calc_hmac_base64(str(timestamp), secret) sign_base64 = calc_hmac_base64(str(timestamp), secret)
if sign != sign_base64.decode('utf-8'): if sign != sign_base64.decode("utf-8"):
log("WARNING", "Signature Header is invalid") log("WARNING", "Signature Header is invalid")
return None, HTTPResponse(403, b"Signature is invalid") return None, HTTPResponse(403, b"Signature is invalid")
else: else:
log("WARNING", "Ding signature check ignored!") log("WARNING", "Ding signature check ignored!")
return (json.loads(request.body.decode())["chatbotUserId"], return (
HTTPResponse(204, b'')) json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b""),
)
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: bytes): async def handle_message(self, message: bytes):
@ -111,10 +123,9 @@ class Bot(BaseBot):
return return
@overrides(BaseBot) @overrides(BaseBot)
async def _call_api(self, async def _call_api(
api: str, self, api: str, event: Optional[MessageEvent] = None, **data
event: Optional[MessageEvent] = None, ) -> Any:
**data) -> Any:
if not isinstance(self.request, HTTPRequest): if not isinstance(self.request, HTTPRequest):
log("ERROR", "Only support http connection.") log("ERROR", "Only support http connection.")
return return
@ -138,7 +149,8 @@ class Bot(BaseBot):
if event: if event:
# 确保 sessionWebhook 没有过期 # 确保 sessionWebhook 没有过期
if int(datetime.now().timestamp()) > int( if int(datetime.now().timestamp()) > int(
event.sessionWebhookExpiredTime / 1000): event.sessionWebhookExpiredTime / 1000
):
raise SessionExpired raise SessionExpired
webhook = event.sessionWebhook webhook = event.sessionWebhook
@ -150,32 +162,37 @@ class Bot(BaseBot):
if not message: if not message:
raise ValueError("Message not found") raise ValueError("Message not found")
try: try:
async with httpx.AsyncClient(headers=headers, async with httpx.AsyncClient(
follow_redirects=True) as client: headers=headers, follow_redirects=True
response = await client.post(webhook, ) as client:
response = await client.post(
webhook,
params=params, params=params,
json=message._produce(), json=message._produce(),
timeout=self.config.api_timeout) timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
result = response.json() result = response.json()
if isinstance(result, dict): if isinstance(result, dict):
if result.get("errcode") != 0: if result.get("errcode") != 0:
raise ActionFailed(errcode=result.get("errcode"), raise ActionFailed(
errmsg=result.get("errmsg")) errcode=result.get("errcode"), errmsg=result.get("errmsg")
)
return result return result
raise NetworkError(f"HTTP request received unexpected " raise NetworkError(
f"status code: {response.status_code}") f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL: except httpx.InvalidURL:
raise NetworkError("API root url invalid") raise NetworkError("API root url invalid")
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, async def call_api(
api: str, self, api: str, event: Optional[MessageEvent] = None, **data
event: Optional[MessageEvent] = None, ) -> Any:
**data) -> Any:
""" """
:说明: :说明:
@ -199,13 +216,15 @@ class Bot(BaseBot):
return await super().call_api(api, event=event, **data) return await super().call_api(api, event=event, **data)
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(
self,
event: MessageEvent, event: MessageEvent,
message: Union[str, "Message", "MessageSegment"], message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False, at_sender: bool = False,
webhook: Optional[str] = None, webhook: Optional[str] = None,
secret: Optional[str] = None, secret: Optional[str] = None,
**kwargs) -> Any: **kwargs,
) -> Any:
""" """
:说明: :说明:
@ -241,9 +260,11 @@ class Bot(BaseBot):
params.update(kwargs) params.update(kwargs)
if at_sender and event.conversationType != ConversationType.private: if at_sender and event.conversationType != ConversationType.private:
params[ params["message"] = (
"message"] = f"@{event.senderId} " + msg + MessageSegment.atDingtalkIds( f"@{event.senderId} "
event.senderId) + msg
+ MessageSegment.atDingtalkIds(event.senderId)
)
else: else:
params["message"] = msg params["message"] = msg

View File

@ -12,6 +12,7 @@ class Config(BaseModel):
- ``access_token`` / ``ding_access_token``: 钉钉令牌 - ``access_token`` / ``ding_access_token``: 钉钉令牌
- ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令 - ``secret`` / ``ding_secret``: 钉钉 HTTP 上报数据签名口令
""" """
secret: Optional[str] = Field(default=None, alias="ding_secret") secret: Optional[str] = Field(default=None, alias="ding_secret")
access_token: Optional[str] = Field(default=None, alias="ding_access_token") access_token: Optional[str] = Field(default=None, alias="ding_access_token")

View File

@ -69,6 +69,7 @@ class ConversationType(str, Enum):
class MessageEvent(Event): class MessageEvent(Event):
"""消息事件""" """消息事件"""
msgtype: str msgtype: str
text: TextMessage text: TextMessage
msgId: str msgId: str
@ -88,11 +89,10 @@ class MessageEvent(Event):
def gen_message(cls, values: dict): def gen_message(cls, values: dict):
assert "msgtype" in values, "msgtype must be specified" assert "msgtype" in values, "msgtype must be specified"
# 其实目前钉钉机器人只能接收到 text 类型的消息 # 其实目前钉钉机器人只能接收到 text 类型的消息
assert values[ assert values["msgtype"] in values, f"{values['msgtype']} must be specified"
"msgtype"] in values, f"{values['msgtype']} must be specified" content = values[values["msgtype"]]["content"]
content = values[values['msgtype']]['content']
# 如果是被 @,第一个字符将会为空格,移除特殊情况 # 如果是被 @,第一个字符将会为空格,移除特殊情况
if content[0] == ' ': if content[0] == " ":
content = content[1:] content = content[1:]
values["message"] = content values["message"] = content
return values return values
@ -128,6 +128,7 @@ class MessageEvent(Event):
class PrivateMessageEvent(MessageEvent): class PrivateMessageEvent(MessageEvent):
"""私聊消息事件""" """私聊消息事件"""
chatbotCorpId: str chatbotCorpId: str
senderStaffId: Optional[str] senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private conversationType: ConversationType = ConversationType.private
@ -135,6 +136,7 @@ class PrivateMessageEvent(MessageEvent):
class GroupMessageEvent(MessageEvent): class GroupMessageEvent(MessageEvent):
"""群消息事件""" """群消息事件"""
atUsers: List[AtUsersItem] atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group conversationType: ConversationType = ConversationType.group
conversationTitle: str conversationTitle: str

View File

@ -1,9 +1,9 @@
from typing import Optional from typing import Optional
from nonebot.exception import (AdapterException, ActionFailed as from nonebot.exception import AdapterException
BaseActionFailed, ApiNotAvailable as from nonebot.exception import ActionFailed as BaseActionFailed
BaseApiNotAvailable, NetworkError as from nonebot.exception import NetworkError as BaseNetworkError
BaseNetworkError) from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class DingAdapterException(AdapterException): class DingAdapterException(AdapterException):
@ -29,15 +29,13 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
* ``errmsg: Optional[str]``: 错误信息 * ``errmsg: Optional[str]``: 错误信息
""" """
def __init__(self, def __init__(self, errcode: Optional[int] = None, errmsg: Optional[str] = None):
errcode: Optional[int] = None,
errmsg: Optional[str] = None):
super().__init__() super().__init__()
self.errcode = errcode self.errcode = errcode
self.errmsg = errmsg self.errmsg = errmsg
def __repr__(self): def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg=\"{self.errmsg}\">" return f'<ApiError errcode={self.errcode} errmsg="{self.errmsg}">'
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()

View File

@ -77,10 +77,9 @@ class MessageSegment(BaseMessageSegment["Message"]):
def code(code_language: str, code: str) -> "Message": def code(code_language: str, code: str) -> "Message":
"""发送 code 消息段""" """发送 code 消息段"""
message = MessageSegment.text(code) message = MessageSegment.text(code)
message += MessageSegment.extension({ message += MessageSegment.extension(
"text_type": "code_snippet", {"text_type": "code_snippet", "code_language": code_language}
"code_language": code_language )
})
return message return message
@staticmethod @staticmethod
@ -95,16 +94,19 @@ class MessageSegment(BaseMessageSegment["Message"]):
) )
@staticmethod @staticmethod
def actionCardSingleBtn(title: str, text: str, singleTitle: str, def actionCardSingleBtn(
singleURL) -> "MessageSegment": title: str, text: str, singleTitle: str, singleURL
) -> "MessageSegment":
"""发送 ``actionCardSingleBtn`` 类型消息""" """发送 ``actionCardSingleBtn`` 类型消息"""
return MessageSegment( return MessageSegment(
"actionCard", { "actionCard",
{
"title": title, "title": title,
"text": text, "text": text,
"singleTitle": singleTitle, "singleTitle": singleTitle,
"singleURL": singleURL "singleURL": singleURL,
}) },
)
@staticmethod @staticmethod
def actionCardMultiBtns( def actionCardMultiBtns(
@ -112,7 +114,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
text: str, text: str,
btns: list, btns: list,
hideAvatar: bool = False, hideAvatar: bool = False,
btnOrientation: str = '1', btnOrientation: str = "1",
) -> "MessageSegment": ) -> "MessageSegment":
""" """
发送 ``actionCardMultiBtn`` 类型消息 发送 ``actionCardMultiBtn`` 类型消息
@ -123,13 +125,15 @@ class MessageSegment(BaseMessageSegment["Message"]):
* ``btns``: ``[{ "title": title, "actionURL": actionURL }, ...]`` * ``btns``: ``[{ "title": title, "actionURL": actionURL }, ...]``
""" """
return MessageSegment( return MessageSegment(
"actionCard", { "actionCard",
{
"title": title, "title": title,
"text": text, "text": text,
"hideAvatar": "1" if hideAvatar else "0", "hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation, "btnOrientation": btnOrientation,
"btns": btns "btns": btns,
}) },
)
@staticmethod @staticmethod
def feedCard(links: list) -> "MessageSegment": def feedCard(links: list) -> "MessageSegment":
@ -144,7 +148,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod @staticmethod
def raw(data) -> "MessageSegment": def raw(data) -> "MessageSegment":
return MessageSegment('raw', data) return MessageSegment("raw", data)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
# 让用户可以直接发送原始的消息格式 # 让用户可以直接发送原始的消息格式
@ -171,8 +175,8 @@ class Message(BaseMessage[MessageSegment]):
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct( def _construct(
msg: Union[str, Mapping, msg: Union[str, Mapping, Iterable[Mapping]]
Iterable[Mapping]]) -> Iterable[MessageSegment]: ) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping): if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg) msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {}) yield MessageSegment(msg["type"], msg.get("data") or {})
@ -187,10 +191,11 @@ class Message(BaseMessage[MessageSegment]):
segment: MessageSegment segment: MessageSegment
for segment in self: for segment in self:
# text 可以和 text 合并 # text 可以和 text 合并
if segment.type == "text" and data.get("msgtype") == 'text': if segment.type == "text" and data.get("msgtype") == "text":
data.setdefault("text", {}) data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault( data["text"]["content"] = (
"content", "") + segment.data["content"] data["text"].setdefault("content", "") + segment.data["content"]
)
else: else:
data.update(segment.to_dict()) data.update(segment.to_dict())
return data return data

View File

@ -8,10 +8,10 @@ log = logger_wrapper("DING")
def calc_hmac_base64(timestamp: str, secret: str): def calc_hmac_base64(timestamp: str, secret: str):
secret_enc = secret.encode('utf-8') secret_enc = secret.encode("utf-8")
string_to_sign = '{}\n{}'.format(timestamp, secret) string_to_sign = "{}\n{}".format(timestamp, secret)
string_to_sign_enc = string_to_sign.encode('utf-8') string_to_sign_enc = string_to_sign.encode("utf-8")
hmac_code = hmac.new(secret_enc, hmac_code = hmac.new(
string_to_sign_enc, secret_enc, string_to_sign_enc, digestmod=hashlib.sha256
digestmod=hashlib.sha256).digest() ).digest()
return base64.b64encode(hmac_code) return base64.b64encode(hmac_code)

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/" # url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true # default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -7,7 +7,7 @@ aiocache_logger.setLevel(logging.DEBUG)
aiocache_logger.handlers.clear() aiocache_logger.handlers.clear()
aiocache_logger.addHandler(LoguruHandler()) aiocache_logger.addHandler(LoguruHandler())
from .bot import Bot as Bot
from .event import * from .event import *
from .bot import Bot as Bot
from .message import Message as Message from .message import Message as Message
from .message import MessageSegment as MessageSegment from .message import MessageSegment as MessageSegment

View File

@ -1,24 +1,39 @@
import re import re
import json import json
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Iterable, Optional, from typing import (
AsyncIterable, cast) TYPE_CHECKING,
Any,
Dict,
Tuple,
Union,
Iterable,
Optional,
AsyncIterable,
cast,
)
import httpx import httpx
from aiocache import Cache, cached from aiocache import Cache, cached
from aiocache.serializers import PickleSerializer from aiocache.serializers import PickleSerializer
from nonebot.log import logger from nonebot.log import logger
from .utils import AESCipher, log
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.message import handle_event from nonebot.message import handle_event
from .config import Config as FeishuConfig
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.drivers import Driver, HTTPRequest, HTTPResponse from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .utils import AESCipher, log
from .config import Config as FeishuConfig
from .message import Message, MessageSegment, MessageSerializer from .message import Message, MessageSegment, MessageSerializer
from .exception import ActionFailed, NetworkError, ApiNotAvailable from .exception import ActionFailed, NetworkError, ApiNotAvailable
from .event import (Event, MessageEvent, GroupMessageEvent, PrivateMessageEvent, from .event import (
get_event_model) Event,
MessageEvent,
GroupMessageEvent,
PrivateMessageEvent,
get_event_model,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
@ -47,8 +62,10 @@ def _check_at_me(bot: "Bot", event: "Event"):
event.to_me = True event.to_me = True
for index, segment in enumerate(message): for index, segment in enumerate(message):
if segment.type == "at" and segment.data.get( if (
"user_name") in bot.config.nickname: segment.type == "at"
and segment.data.get("user_name") in bot.config.nickname
):
event.to_me = True event.to_me = True
del event.event.message.content[index] del event.event.message.content[index]
return return
@ -57,7 +74,8 @@ def _check_at_me(bot: "Bot", event: "Event"):
if mention["name"] in bot.config.nickname: if mention["name"] in bot.config.nickname:
event.to_me = True event.to_me = True
segment.data["text"] = segment.data["text"].replace( segment.data["text"] = segment.data["text"].replace(
f"@{mention['name']}", "") f"@{mention['name']}", ""
)
segment.data["text"] = segment.data["text"].lstrip() segment.data["text"] = segment.data["text"].lstrip()
break break
else: else:
@ -92,18 +110,18 @@ def _check_nickname(bot: "Bot", event: "Event"):
if nicknames: if nicknames:
# check if the user is calling me with my nickname # check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames) nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text, re.IGNORECASE)
re.IGNORECASE)
if m: if m:
nickname = m.group(1) nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}") log("DEBUG", f"User is calling me {nickname}")
event.to_me = True event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():] first_msg_seg.data["text"] = first_text[m.end() :]
def _handle_api_result( def _handle_api_result(
result: Union[Optional[Dict[str, Any]], str, bytes, Iterable[bytes], result: Union[
AsyncIterable[bytes]] Optional[Dict[str, Any]], str, bytes, Iterable[bytes], AsyncIterable[bytes]
]
) -> Any: ) -> Any:
""" """
:说明: :说明:
@ -158,10 +176,10 @@ class Bot(BaseBot):
cls, driver: Driver, request: HTTPRequest cls, driver: Driver, request: HTTPRequest
) -> Tuple[Optional[str], Optional[HTTPResponse]]: ) -> Tuple[Optional[str], Optional[HTTPResponse]]:
if not isinstance(request, HTTPRequest): if not isinstance(request, HTTPRequest):
log("WARNING", log("WARNING", "Unsupported connection type, available type: `http`")
"Unsupported connection type, available type: `http`")
return None, HTTPResponse( return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`") 405, b"Unsupported connection type, available type: `http`"
)
encrypt_key = cls.feishu_config.encrypt_key encrypt_key = cls.feishu_config.encrypt_key
if encrypt_key: if encrypt_key:
@ -174,16 +192,13 @@ class Bot(BaseBot):
challenge = data.get("challenge") challenge = data.get("challenge")
if challenge: if challenge:
return data.get("token"), HTTPResponse( return data.get("token"), HTTPResponse(
200, 200, json.dumps({"challenge": challenge}).encode()
json.dumps({ )
"challenge": challenge
}).encode())
schema = data.get("schema") schema = data.get("schema")
if not schema: if not schema:
return None, HTTPResponse( return None, HTTPResponse(
400, 400, b"Missing `schema` in POST body, only accept event of version 2.0"
b"Missing `schema` in POST body, only accept event of version 2.0"
) )
headers = data.get("header") headers = data.get("header")
@ -196,15 +211,13 @@ class Bot(BaseBot):
if not token: if not token:
log("WARNING", "Missing `verification token` in POST body") log("WARNING", "Missing `verification token` in POST body")
return None, HTTPResponse( return None, HTTPResponse(400, b"Missing `verification token` in POST body")
400, b"Missing `verification token` in POST body")
else: else:
if token != cls.feishu_config.verification_token: if token != cls.feishu_config.verification_token:
log("WARNING", "Verification token check failed") log("WARNING", "Verification token check failed")
return None, HTTPResponse(403, return None, HTTPResponse(403, b"Verification token check failed")
b"Verification token check failed")
return app_id, HTTPResponse(200, b'') return app_id, HTTPResponse(200, b"")
async def handle_message(self, message: bytes): async def handle_message(self, message: bytes):
""" """
@ -245,28 +258,32 @@ class Bot(BaseBot):
def _construct_url(self, path: str) -> str: def _construct_url(self, path: str) -> str:
return self.api_root + path return self.api_root + path
@cached(ttl=60 * 60, @cached(
ttl=60 * 60,
cache=Cache.MEMORY, cache=Cache.MEMORY,
key="_feishu_tenant_access_token", key="_feishu_tenant_access_token",
serializer=PickleSerializer()) serializer=PickleSerializer(),
)
async def _fetch_tenant_access_token(self) -> str: async def _fetch_tenant_access_token(self) -> str:
try: try:
async with httpx.AsyncClient(follow_redirects=True) as client: async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.post( response = await client.post(
self._construct_url( self._construct_url("auth/v3/tenant_access_token/internal/"),
"auth/v3/tenant_access_token/internal/"),
json={ json={
"app_id": self.feishu_config.app_id, "app_id": self.feishu_config.app_id,
"app_secret": self.feishu_config.app_secret "app_secret": self.feishu_config.app_secret,
}, },
timeout=self.config.api_timeout) timeout=self.config.api_timeout,
)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
result = response.json() result = response.json()
return result["tenant_access_token"] return result["tenant_access_token"]
else: else:
raise NetworkError(f"HTTP request received unexpected " raise NetworkError(
f"status code: {response.status_code}") f"HTTP request received unexpected "
f"status code: {response.status_code}"
)
except httpx.InvalidURL: except httpx.InvalidURL:
raise NetworkError("API root url invalid") raise NetworkError("API root url invalid")
except httpx.HTTPError: except httpx.HTTPError:
@ -280,30 +297,37 @@ class Bot(BaseBot):
raise ApiNotAvailable raise ApiNotAvailable
headers = {} headers = {}
self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token( self.feishu_config.tenant_access_token = (
await self._fetch_tenant_access_token()
)
headers["Authorization"] = (
"Bearer " + self.feishu_config.tenant_access_token
) )
headers[
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try: try:
async with httpx.AsyncClient(timeout=self.config.api_timeout, async with httpx.AsyncClient(
follow_redirects=True) as client: timeout=self.config.api_timeout, follow_redirects=True
) as client:
response = await client.send( response = await client.send(
httpx.Request(data["method"], httpx.Request(
data["method"],
self.api_root + api, self.api_root + api,
json=data.get("body", {}), json=data.get("body", {}),
params=data.get("query", {}), params=data.get("query", {}),
headers=headers)) headers=headers,
)
)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
if response.headers["content-type"].startswith( if response.headers["content-type"].startswith("application/json"):
"application/json"):
result = response.json() result = response.json()
else: else:
result = response.content result = response.content
return _handle_api_result(result) return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected " raise NetworkError(
f"HTTP request received unexpected "
f"status code: {response.status_code} " f"status code: {response.status_code} "
f"response body: {response.text}") f"response body: {response.text}"
)
except httpx.InvalidURL: except httpx.InvalidURL:
raise NetworkError("API root url invalid") raise NetworkError("API root url invalid")
except httpx.HTTPError: except httpx.HTTPError:
@ -333,11 +357,13 @@ class Bot(BaseBot):
return await super().call_api(api, **data) return await super().call_api(api, **data)
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(
self,
event: Event, event: Event,
message: Union[str, Message, MessageSegment], message: Union[str, Message, MessageSegment],
at_sender: bool = False, at_sender: bool = False,
**kwargs) -> Any: **kwargs,
) -> Any:
msg = message if isinstance(message, Message) else Message(message) msg = message if isinstance(message, Message) else Message(message)
if isinstance(event, GroupMessageEvent): if isinstance(event, GroupMessageEvent):
@ -346,7 +372,8 @@ class Bot(BaseBot):
receive_id, receive_id_type = event.get_user_id(), "open_id" receive_id, receive_id_type = event.get_user_id(), "open_id"
else: else:
raise ValueError( raise ValueError(
"Cannot guess `receive_id` and `receive_id_type` to reply!") "Cannot guess `receive_id` and `receive_id_type` to reply!"
)
at_sender = at_sender and bool(event.get_user_id()) at_sender = at_sender and bool(event.get_user_id())
@ -357,14 +384,12 @@ class Bot(BaseBot):
params = { params = {
"method": "POST", "method": "POST",
"query": { "query": {"receive_id_type": receive_id_type},
"receive_id_type": receive_id_type
},
"body": { "body": {
"receive_id": receive_id, "receive_id": receive_id,
"content": content, "content": content,
"msg_type": msg_type "msg_type": msg_type,
} },
} }
return await self.call_api(f"im/v1/messages", **params) return await self.call_api(f"im/v1/messages", **params)

View File

@ -17,13 +17,16 @@ class Config(BaseModel):
- ``is_lark`` / ``feishu_is_lark``: 是否使用Lark飞书海外版默认为 false - ``is_lark`` / ``feishu_is_lark``: 是否使用Lark飞书海外版默认为 false
""" """
app_id: Optional[str] = Field(default=None, alias="feishu_app_id") app_id: Optional[str] = Field(default=None, alias="feishu_app_id")
app_secret: Optional[str] = Field(default=None, alias="feishu_app_secret") app_secret: Optional[str] = Field(default=None, alias="feishu_app_secret")
encrypt_key: Optional[str] = Field(default=None, alias="feishu_encrypt_key") encrypt_key: Optional[str] = Field(default=None, alias="feishu_encrypt_key")
verification_token: Optional[str] = Field(default=None, verification_token: Optional[str] = Field(
alias="feishu_verification_token") default=None, alias="feishu_verification_token"
)
tenant_access_token: Optional[str] = Field( tenant_access_token: Optional[str] = Field(
default=None, alias="feishu_tenant_access_token") default=None, alias="feishu_tenant_access_token"
)
is_lark: Optional[str] = Field(default=False, alias="feishu_is_lark") is_lark: Optional[str] = Field(default=False, alias="feishu_is_lark")
class Config: class Config:

View File

@ -1,12 +1,12 @@
import inspect
import json import json
from typing import Any, Dict, List, Literal, Optional, Type import inspect
from typing import Any, Dict, List, Type, Literal, Optional
from pydantic import BaseModel, Field, root_validator
from pygtrie import StringTrie from pygtrie import StringTrie
from pydantic import Field, BaseModel, root_validator
from nonebot.adapters import Event as BaseEvent
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message, MessageDeserializer from .message import Message, MessageDeserializer

View File

@ -1,13 +1,12 @@
from typing import Optional from typing import Optional
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import AdapterException from nonebot.exception import AdapterException
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import NetworkError as BaseNetworkError from nonebot.exception import NetworkError as BaseNetworkError
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class FeishuAdapterException(AdapterException): class FeishuAdapterException(AdapterException):
def __init__(self): def __init__(self):
super().__init__("feishu") super().__init__("feishu")
@ -28,8 +27,11 @@ class ActionFailed(BaseActionFailed, FeishuAdapterException):
self.info = kwargs self.info = kwargs
def __repr__(self): def __repr__(self):
return f"<ActionFailed " + ", ".join( return (
f"{k}={v}" for k, v in self.info.items()) + ">" f"<ActionFailed "
+ ", ".join(f"{k}={v}" for k, v in self.info.items())
+ ">"
)
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()

View File

@ -1,8 +1,18 @@
import json import json
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
from typing import (Any, Dict, List, Type, Tuple, Union, Mapping, Iterable, from typing import (
Optional, cast) Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
Iterable,
Optional,
cast,
)
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage from nonebot.adapters import Message as BaseMessage
@ -34,7 +44,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
"share_user": "[个人名片]", "share_user": "[个人名片]",
"system": "[系统消息]", "system": "[系统消息]",
"location": "[位置]", "location": "[位置]",
"video_chat": "[视频通话]" "video_chat": "[视频通话]",
} }
def __str__(self) -> str: def __str__(self) -> str:
@ -47,24 +57,26 @@ class MessageSegment(BaseMessageSegment["Message"]):
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + (MessageSegment.text(other) if isinstance( return Message(self) + (
other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message": def __radd__(self, other) -> "Message":
return (MessageSegment.text(other) return (
if isinstance(other, str) else Message(other)) + self MessageSegment.text(other) if isinstance(other, str) else Message(other)
) + self
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
#接收消息 # 接收消息
@staticmethod @staticmethod
def at(user_id: str) -> "MessageSegment": def at(user_id: str) -> "MessageSegment":
return MessageSegment("at", {"user_id": user_id}) return MessageSegment("at", {"user_id": user_id})
#发送消息 # 发送消息
@staticmethod @staticmethod
def text(text: str) -> "MessageSegment": def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text}) return MessageSegment("text", {"text": text})
@ -79,10 +91,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod @staticmethod
def interactive(title: str, elements: list) -> "MessageSegment": def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("interactive", { return MessageSegment("interactive", {"title": title, "elements": elements})
"title": title,
"elements": elements
})
@staticmethod @staticmethod
def share_chat(chat_id: str) -> "MessageSegment": def share_chat(chat_id: str) -> "MessageSegment":
@ -94,28 +103,25 @@ class MessageSegment(BaseMessageSegment["Message"]):
@staticmethod @staticmethod
def audio(file_key: str, duration: int) -> "MessageSegment": def audio(file_key: str, duration: int) -> "MessageSegment":
return MessageSegment("audio", { return MessageSegment("audio", {"file_key": file_key, "duration": duration})
"file_key": file_key,
"duration": duration
})
@staticmethod @staticmethod
def media(file_key: str, image_key: str, file_name: str, def media(
duration: int) -> "MessageSegment": file_key: str, image_key: str, file_name: str, duration: int
) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"media", { "media",
{
"file_key": file_key, "file_key": file_key,
"image_key": image_key, "image_key": image_key,
"file_name": file_name, "file_name": file_name,
"duration": duration "duration": duration,
}) },
)
@staticmethod @staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment": def file(file_key: str, file_name: str) -> "MessageSegment":
return MessageSegment("file", { return MessageSegment("file", {"file_key": file_key, "file_name": file_name})
"file_key": file_key,
"file_name": file_name
})
@staticmethod @staticmethod
def sticker(file_key) -> "MessageSegment": def sticker(file_key) -> "MessageSegment":
@ -133,22 +139,22 @@ class Message(BaseMessage[MessageSegment]):
return MessageSegment return MessageSegment
@overrides(BaseMessage) @overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping, def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__( return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@overrides(BaseMessage) @overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping, def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message":
Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__( return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other
)
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct( def _construct(
msg: Union[str, Mapping, msg: Union[str, Mapping, Iterable[Mapping]]
Iterable[Mapping]]) -> Iterable[MessageSegment]: ) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping): if isinstance(msg, Mapping):
msg = cast(Mapping[str, Any], msg) msg = cast(Mapping[str, Any], msg)
yield MessageSegment(msg["type"], msg.get("data") or {}) yield MessageSegment(msg["type"], msg.get("data") or {})
@ -169,7 +175,8 @@ class Message(BaseMessage[MessageSegment]):
for i, seg in enumerate(self): for i, seg in enumerate(self):
if seg.type == "text" and i != 0 and msg[-1].type == "text": if seg.type == "text" and i != 0 and msg[-1].type == "text":
msg[-1] = MessageSegment( msg[-1] = MessageSegment(
"text", {"text": msg[-1].data["text"] + seg.data["text"]}) "text", {"text": msg[-1].data["text"] + seg.data["text"]}
)
else: else:
msg.append(seg) msg.append(seg)
return Message(msg) return Message(msg)
@ -184,6 +191,7 @@ class MessageSerializer:
""" """
飞书 协议 Message 序列化器 飞书 协议 Message 序列化器
""" """
message: Message message: Message
def serialize(self) -> Tuple[str, str]: def serialize(self) -> Tuple[str, str]:
@ -198,10 +206,12 @@ class MessageSerializer:
else: else:
if last_segment_type == "image": if last_segment_type == "image":
msg["content"].append([]) msg["content"].append([])
msg["content"][-1].append({ msg["content"][-1].append(
{
"tag": segment.type if segment.type != "image" else "img", "tag": segment.type if segment.type != "image" else "img",
**segment.data **segment.data,
}) }
)
last_segment_type = segment.type last_segment_type = segment.type
return "post", json.dumps({"zh_cn": {**msg}}) return "post", json.dumps({"zh_cn": {**msg}})
@ -214,6 +224,7 @@ class MessageDeserializer:
""" """
飞书 协议 Message 反序列化器 飞书 协议 Message 反序列化器
""" """
type: str type: str
data: Dict[str, Any] data: Dict[str, Any]
mentions: Optional[List[dict]] mentions: Optional[List[dict]]
@ -227,14 +238,13 @@ class MessageDeserializer:
if self.type == "post": if self.type == "post":
msg = Message() msg = Message()
if self.data["title"] != "": if self.data["title"] != "":
msg += MessageSegment("text", {'text': self.data["title"]}) msg += MessageSegment("text", {"text": self.data["title"]})
for seg in itertools.chain(*self.data["content"]): for seg in itertools.chain(*self.data["content"]):
tag = seg.pop("tag") tag = seg.pop("tag")
if tag == "at": if tag == "at":
seg["user_name"] = dict_mention[seg["user_id"]]["name"] seg["user_name"] = dict_mention[seg["user_id"]]["name"]
seg["user_id"] = dict_mention[ seg["user_id"] = dict_mention[seg["user_id"]]["id"]["open_id"]
seg["user_id"]]["id"]["open_id"]
msg += MessageSegment(tag if tag != "img" else "image", seg) msg += MessageSegment(tag if tag != "img" else "image", seg)
@ -242,7 +252,8 @@ class MessageDeserializer:
elif self.type == "text": elif self.type == "text":
for key, mention in dict_mention.items(): for key, mention in dict_mention.items():
self.data["text"] = self.data["text"].replace( self.data["text"] = self.data["text"].replace(
key, f"@{mention['name']}") key, f"@{mention['name']}"
)
self.data["mentions"] = dict_mention self.data["mentions"] = dict_mention
return Message(MessageSegment(self.type, self.data)) return Message(MessageSegment(self.type, self.data))

View File

@ -9,27 +9,26 @@ log = logger_wrapper("FEISHU")
class AESCipher(object): class AESCipher(object):
def __init__(self, key): def __init__(self, key):
self.block_size = AES.block_size self.block_size = AES.block_size
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest() self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
@staticmethod @staticmethod
def str_to_bytes(data): def str_to_bytes(data):
u_type = type(b"".decode('utf8')) u_type = type(b"".decode("utf8"))
if isinstance(data, u_type): if isinstance(data, u_type):
return data.encode('utf8') return data.encode("utf8")
return data return data
@staticmethod @staticmethod
def _unpad(s): def _unpad(s):
return s[:-ord(s[len(s) - 1:])] return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc): def decrypt(self, enc):
iv = enc[:AES.block_size] iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv) cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:])) return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc): def decrypt_string(self, enc):
enc = base64.b64decode(enc) enc = base64.b64decode(enc)
return self.decrypt(enc).decode('utf8') return self.decrypt(enc).decode("utf8")

View File

@ -36,6 +36,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/" # url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true # default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -12,8 +12,15 @@ from nonebot.config import Config
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable from nonebot.exception import ApiNotAvailable
from nonebot.drivers import (Driver, WebSocket, HTTPResponse, ForwardDriver, from nonebot.drivers import (
ReverseDriver, HTTPConnection, WebSocketSetup) Driver,
WebSocket,
HTTPResponse,
ForwardDriver,
ReverseDriver,
HTTPConnection,
WebSocketSetup,
)
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment
@ -23,16 +30,14 @@ from .utils import Log, process_event, argument_validation, catch_network_error
class SessionManager: class SessionManager:
"""Bot会话管理器, 提供API主动调用接口""" """Bot会话管理器, 提供API主动调用接口"""
sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {} sessions: Dict[int, Tuple[str, httpx.AsyncClient]] = {}
def __init__(self, session_key: str, client: httpx.AsyncClient): def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client self.session_key, self.client = session_key, client
@catch_network_error @catch_network_error
async def post(self, async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Any:
""" """
:说明: :说明:
@ -51,7 +56,7 @@ class SessionManager:
path, path,
json={ json={
**(params or {}), **(params or {}),
'sessionKey': self.session_key, "sessionKey": self.session_key,
}, },
timeout=3, timeout=3,
) )
@ -59,10 +64,9 @@ class SessionManager:
return response.json() return response.json()
@catch_network_error @catch_network_error
async def request(self, async def request(
path: str, self, path: str, *, params: Optional[Dict[str, Any]] = None
*, ) -> Any:
params: Optional[Dict[str, Any]] = None) -> Any:
""" """
:说明: :说明:
@ -77,7 +81,7 @@ class SessionManager:
path, path,
params={ params={
**(params or {}), **(params or {}),
'sessionKey': self.session_key, "sessionKey": self.session_key,
}, },
timeout=3, timeout=3,
) )
@ -98,7 +102,7 @@ class SessionManager:
""" """
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)} files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files} form = {k: v for k, v in params.items() if k not in files}
form['sessionKey'] = self.session_key form["sessionKey"] = self.session_key
response = await self.client.post( response = await self.client.post(
path, path,
data=form, data=form,
@ -109,25 +113,25 @@ class SessionManager:
return response.json() return response.json()
@classmethod @classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int, async def new(
auth_key: str) -> "SessionManager": cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str
) -> "SessionManager":
session = cls.get(self_id) session = cls.get(self_id)
if session is not None: if session is not None:
return session return session
client = httpx.AsyncClient(base_url=f'http://{host}:{port}', client = httpx.AsyncClient(
follow_redirects=True) base_url=f"http://{host}:{port}", follow_redirects=True
response = await client.post('/auth', json={'authKey': auth_key}) )
response = await client.post("/auth", json={"authKey": auth_key})
response.raise_for_status() response.raise_for_status()
auth = response.json() auth = response.json()
assert auth['code'] == 0 assert auth["code"] == 0
session_key = auth['session'] session_key = auth["session"]
response = await client.post('/verify', response = await client.post(
json={ "/verify", json={"sessionKey": session_key, "qq": self_id}
'sessionKey': session_key, )
'qq': self_id assert response.json()["code"] == 0
})
assert response.json()['code'] == 0
cls.sessions[self_id] = session_key, client cls.sessions[self_id] = session_key, client
return cls(session_key, client) return cls(session_key, client)
@ -152,7 +156,7 @@ class Bot(BaseBot):
""" """
_type = 'mirai' _type = "mirai"
@property @property
@overrides(BaseBot) @overrides(BaseBot)
@ -166,37 +170,42 @@ class Bot(BaseBot):
if api is None: if api is None:
if isinstance(self.request, WebSocket): if isinstance(self.request, WebSocket):
asyncio.create_task(self.request.close(1000)) asyncio.create_task(self.request.close(1000))
assert api is not None, 'SessionManager has not been initialized' assert api is not None, "SessionManager has not been initialized"
return api return api
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission( async def check_permission(
cls, driver: Driver, cls, driver: Driver, request: HTTPConnection
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: ) -> Tuple[Optional[str], HTTPResponse]:
if isinstance(request, WebSocket): if isinstance(request, WebSocket):
return None, HTTPResponse( return None, HTTPResponse(501, b"Websocket connection is not implemented")
501, b'Websocket connection is not implemented') self_id: Optional[str] = request.headers.get("bot")
self_id: Optional[str] = request.headers.get('bot')
if self_id is None: if self_id is None:
return None, HTTPResponse(400, b'Header `Bot` is required.') return None, HTTPResponse(400, b"Header `Bot` is required.")
self_id = str(self_id).strip() self_id = str(self_id).strip()
await SessionManager.new( await SessionManager.new(
int(self_id), int(self_id),
host=cls.mirai_config.host, # type: ignore host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore auth_key=cls.mirai_config.auth_key, # type: ignore
return self_id, HTTPResponse(204, b'') )
return self_id, HTTPResponse(204, b"")
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
def register(cls, def register(
cls,
driver: Driver, driver: Driver,
config: "Config", config: "Config",
qq: Optional[Union[int, List[int]]] = None): qq: Optional[Union[int, List[int]]] = None,
):
cls.mirai_config = MiraiConfig(**config.dict()) cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and if (
cls.mirai_config.port) is None: cls.mirai_config.auth_key
and cls.mirai_config.host
and cls.mirai_config.port
) is None:
raise ApiNotAvailable(cls._type) raise ApiNotAvailable(cls._type)
super().register(driver, config) super().register(driver, config)
@ -209,17 +218,25 @@ class Bot(BaseBot):
self_ids = [qq] if isinstance(qq, int) else qq self_ids = [qq] if isinstance(qq, int) else qq
async def url_factory(qq: int): async def url_factory(qq: int):
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key assert (
cls.mirai_config.host
and cls.mirai_config.port
and cls.mirai_config.auth_key
)
session = await SessionManager.new( session = await SessionManager.new(
qq, qq,
host=cls.mirai_config.host, host=cls.mirai_config.host,
port=cls.mirai_config.port, port=cls.mirai_config.port,
auth_key=cls.mirai_config.auth_key) auth_key=cls.mirai_config.auth_key,
)
return WebSocketSetup( return WebSocketSetup(
adapter=cls._type, adapter=cls._type,
self_id=str(qq), self_id=str(qq),
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}' url=(
f'/all?sessionKey={session.session_key}')) f"ws://{cls.mirai_config.host}:{cls.mirai_config.port}"
f"/all?sessionKey={session.session_key}"
),
)
for self_id in self_ids: for self_id in self_ids:
driver.setup_websocket(partial(url_factory, qq=self_id)) driver.setup_websocket(partial(url_factory, qq=self_id))
@ -234,13 +251,15 @@ class Bot(BaseBot):
try: try:
await process_event( await process_event(
bot=self, bot=self,
event=Event.new({ event=Event.new(
{
**json.loads(message), **json.loads(message),
'self_id': self.self_id, "self_id": self.self_id,
}), }
),
) )
except Exception as e: except Exception as e:
Log.error(f'Failed to handle message: {message}', e) Log.error(f"Failed to handle message: {message}", e)
@overrides(BaseBot) @overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn: async def _call_api(self, api: str, **data) -> NoReturn:
@ -266,10 +285,12 @@ class Bot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
@argument_validation @argument_validation
async def send(self, async def send(
self,
event: Event, event: Event,
message: Union[MessageChain, MessageSegment, str], message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False): at_sender: bool = False,
):
""" """
:说明: :说明:
@ -284,23 +305,24 @@ class Bot(BaseBot):
if not isinstance(message, MessageChain): if not isinstance(message, MessageChain):
message = MessageChain(message) message = MessageChain(message)
if isinstance(event, FriendMessage): if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id, return await self.send_friend_message(
message_chain=message) target=event.sender.id, message_chain=message
)
elif isinstance(event, GroupMessage): elif isinstance(event, GroupMessage):
if at_sender: if at_sender:
message = MessageSegment.at(event.sender.id) + message message = MessageSegment.at(event.sender.id) + message
return await self.send_group_message(group=event.sender.group.id, return await self.send_group_message(
message_chain=message) group=event.sender.group.id, message_chain=message
)
elif isinstance(event, TempMessage): elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id, return await self.send_temp_message(
group=event.sender.group.id, qq=event.sender.id, group=event.sender.group.id, message_chain=message
message_chain=message) )
else: else:
raise ValueError(f'Unsupported event type {event!r}.') raise ValueError(f"Unsupported event type {event!r}.")
@argument_validation @argument_validation
async def send_friend_message(self, target: int, async def send_friend_message(self, target: int, message_chain: MessageChain):
message_chain: MessageChain):
""" """
:说明: :说明:
@ -311,15 +333,13 @@ class Bot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ * ``target: int``: 发送消息目标好友的 QQ
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
""" """
return await self.api.post('sendFriendMessage', return await self.api.post(
params={ "sendFriendMessage",
'target': target, params={"target": target, "messageChain": message_chain.export()},
'messageChain': message_chain.export() )
})
@argument_validation @argument_validation
async def send_temp_message(self, qq: int, group: int, async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain):
message_chain: MessageChain):
""" """
:说明: :说明:
@ -331,18 +351,15 @@ class Bot(BaseBot):
* ``group: int``: 临时会话群号 * ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
""" """
return await self.api.post('sendTempMessage', return await self.api.post(
params={ "sendTempMessage",
'qq': qq, params={"qq": qq, "group": group, "messageChain": message_chain.export()},
'group': group, )
'messageChain': message_chain.export()
})
@argument_validation @argument_validation
async def send_group_message(self, async def send_group_message(
group: int, self, group: int, message_chain: MessageChain, quote: Optional[int] = None
message_chain: MessageChain, ):
quote: Optional[int] = None):
""" """
:说明: :说明:
@ -354,12 +371,14 @@ class Bot(BaseBot):
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复 * ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
""" """
return await self.api.post('sendGroupMessage', return await self.api.post(
"sendGroupMessage",
params={ params={
'group': group, "group": group,
'messageChain': message_chain.export(), "messageChain": message_chain.export(),
'quote': quote "quote": quote,
}) },
)
@argument_validation @argument_validation
async def recall(self, target: int): async def recall(self, target: int):
@ -372,11 +391,12 @@ class Bot(BaseBot):
* ``target: int``: 需要撤回的消息的message_id * ``target: int``: 需要撤回的消息的message_id
""" """
return await self.api.post('recall', params={'target': target}) return await self.api.post("recall", params={"target": target})
@argument_validation @argument_validation
async def send_image_message(self, target: int, qq: int, group: int, async def send_image_message(
urls: List[str]) -> List[str]: self, target: int, qq: int, group: int, urls: List[str]
) -> List[str]:
""" """
:说明: :说明:
@ -396,13 +416,10 @@ class Bot(BaseBot):
- ``List[str]``: 一个包含图片imageId的数组 - ``List[str]``: 一个包含图片imageId的数组
""" """
return await self.api.post('sendImageMessage', return await self.api.post(
params={ "sendImageMessage",
'target': target, params={"target": target, "qq": qq, "group": group, "urls": urls},
'qq': qq, )
'group': group,
'urls': urls
})
@argument_validation @argument_validation
async def upload_image(self, type: str, img: BytesIO): async def upload_image(self, type: str, img: BytesIO):
@ -416,11 +433,7 @@ class Bot(BaseBot):
* ``type: str``: "friend" "group" "temp" * ``type: str``: "friend" "group" "temp"
* ``img: BytesIO``: 图片的BytesIO对象 * ``img: BytesIO``: 图片的BytesIO对象
""" """
return await self.api.upload('uploadImage', return await self.api.upload("uploadImage", params={"type": type, "img": img})
params={
'type': type,
'img': img
})
@argument_validation @argument_validation
async def upload_voice(self, type: str, voice: BytesIO): async def upload_voice(self, type: str, voice: BytesIO):
@ -434,11 +447,9 @@ class Bot(BaseBot):
* ``type: str``: 当前仅支持 "group" * ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象 * ``voice: BytesIO``: 语音的BytesIO对象
""" """
return await self.api.upload('uploadVoice', return await self.api.upload(
params={ "uploadVoice", params={"type": type, "voice": voice}
'type': type, )
'voice': voice
})
@argument_validation @argument_validation
async def fetch_message(self, count: int = 10): async def fetch_message(self, count: int = 10):
@ -452,7 +463,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量 * ``count: int``: 获取消息和事件的数量
""" """
return await self.api.request('fetchMessage', params={'count': count}) return await self.api.request("fetchMessage", params={"count": count})
@argument_validation @argument_validation
async def fetch_latest_message(self, count: int = 10): async def fetch_latest_message(self, count: int = 10):
@ -466,8 +477,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量 * ``count: int``: 获取消息和事件的数量
""" """
return await self.api.request('fetchLatestMessage', return await self.api.request("fetchLatestMessage", params={"count": count})
params={'count': count})
@argument_validation @argument_validation
async def peek_message(self, count: int = 10): async def peek_message(self, count: int = 10):
@ -481,7 +491,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量 * ``count: int``: 获取消息和事件的数量
""" """
return await self.api.request('peekMessage', params={'count': count}) return await self.api.request("peekMessage", params={"count": count})
@argument_validation @argument_validation
async def peek_latest_message(self, count: int = 10): async def peek_latest_message(self, count: int = 10):
@ -495,8 +505,7 @@ class Bot(BaseBot):
* ``count: int``: 获取消息和事件的数量 * ``count: int``: 获取消息和事件的数量
""" """
return await self.api.request('peekLatestMessage', return await self.api.request("peekLatestMessage", params={"count": count})
params={'count': count})
@argument_validation @argument_validation
async def messsage_from_id(self, id: int): async def messsage_from_id(self, id: int):
@ -510,7 +519,7 @@ class Bot(BaseBot):
* ``id: int``: 获取消息的message_id * ``id: int``: 获取消息的message_id
""" """
return await self.api.request('messageFromId', params={'id': id}) return await self.api.request("messageFromId", params={"id": id})
@argument_validation @argument_validation
async def count_message(self): async def count_message(self):
@ -519,7 +528,7 @@ class Bot(BaseBot):
使用此方法获取bot接收并缓存的消息总数注意不包含被删除的 使用此方法获取bot接收并缓存的消息总数注意不包含被删除的
""" """
return await self.api.request('countMessage') return await self.api.request("countMessage")
@argument_validation @argument_validation
async def friend_list(self) -> List[Dict[str, Any]]: async def friend_list(self) -> List[Dict[str, Any]]:
@ -532,7 +541,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据 - ``List[Dict[str, Any]]``: 返回的好友列表数据
""" """
return await self.api.request('friendList') return await self.api.request("friendList")
@argument_validation @argument_validation
async def group_list(self) -> List[Dict[str, Any]]: async def group_list(self) -> List[Dict[str, Any]]:
@ -545,7 +554,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据 - ``List[Dict[str, Any]]``: 返回的群列表数据
""" """
return await self.api.request('groupList') return await self.api.request("groupList")
@argument_validation @argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]: async def member_list(self, target: int) -> List[Dict[str, Any]]:
@ -562,7 +571,7 @@ class Bot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据 - ``List[Dict[str, Any]]``: 返回的群成员列表数据
""" """
return await self.api.request('memberList', params={'target': target}) return await self.api.request("memberList", params={"target": target})
@argument_validation @argument_validation
async def mute(self, target: int, member_id: int, time: int): async def mute(self, target: int, member_id: int, time: int):
@ -577,12 +586,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号 * ``member_id: int``: 指定群员QQ号
* ``time: int``: 禁言时长单位为秒最多30天 * ``time: int``: 禁言时长单位为秒最多30天
""" """
return await self.api.post('mute', return await self.api.post(
params={ "mute", params={"target": target, "memberId": member_id, "time": time}
'target': target, )
'memberId': member_id,
'time': time
})
@argument_validation @argument_validation
async def unmute(self, target: int, member_id: int): async def unmute(self, target: int, member_id: int):
@ -596,11 +602,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号 * ``target: int``: 指定群的群号
* ``member_id: int``: 指定群员QQ号 * ``member_id: int``: 指定群员QQ号
""" """
return await self.api.post('unmute', return await self.api.post(
params={ "unmute", params={"target": target, "memberId": member_id}
'target': target, )
'memberId': member_id
})
@argument_validation @argument_validation
async def kick(self, target: int, member_id: int, msg: str): async def kick(self, target: int, member_id: int, msg: str):
@ -615,12 +619,9 @@ class Bot(BaseBot):
* ``member_id: int``: 指定群员QQ号 * ``member_id: int``: 指定群员QQ号
* ``msg: str``: 信息 * ``msg: str``: 信息
""" """
return await self.api.post('kick', return await self.api.post(
params={ "kick", params={"target": target, "memberId": member_id, "msg": msg}
'target': target, )
'memberId': member_id,
'msg': msg
})
@argument_validation @argument_validation
async def quit(self, target: int): async def quit(self, target: int):
@ -633,7 +634,7 @@ class Bot(BaseBot):
* ``target: int``: 退出的群号 * ``target: int``: 退出的群号
""" """
return await self.api.post('quit', params={'target': target}) return await self.api.post("quit", params={"target": target})
@argument_validation @argument_validation
async def mute_all(self, target: int): async def mute_all(self, target: int):
@ -646,7 +647,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号 * ``target: int``: 指定群的群号
""" """
return await self.api.post('muteAll', params={'target': target}) return await self.api.post("muteAll", params={"target": target})
@argument_validation @argument_validation
async def unmute_all(self, target: int): async def unmute_all(self, target: int):
@ -659,7 +660,7 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号 * ``target: int``: 指定群的群号
""" """
return await self.api.post('unmuteAll', params={'target': target}) return await self.api.post("unmuteAll", params={"target": target})
@argument_validation @argument_validation
async def group_config(self, target: int): async def group_config(self, target: int):
@ -685,7 +686,7 @@ class Bot(BaseBot):
"anonymousChat": true "anonymousChat": true
} }
""" """
return await self.api.request('groupConfig', params={'target': target}) return await self.api.request("groupConfig", params={"target": target})
@argument_validation @argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]): async def modify_group_config(self, target: int, config: Dict[str, Any]):
@ -699,11 +700,9 @@ class Bot(BaseBot):
* ``target: int``: 指定群的群号 * ``target: int``: 指定群的群号
* ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值 * ``config: Dict[str, Any]``: 群设置, 格式见 ``group_config`` 的返回值
""" """
return await self.api.post('groupConfig', return await self.api.post(
params={ "groupConfig", params={"target": target, "config": config}
'target': target, )
'config': config
})
@argument_validation @argument_validation
async def member_info(self, target: int, member_id: int): async def member_info(self, target: int, member_id: int):
@ -726,15 +725,14 @@ class Bot(BaseBot):
"specialTitle": "群头衔" "specialTitle": "群头衔"
} }
""" """
return await self.api.request('memberInfo', return await self.api.request(
params={ "memberInfo", params={"target": target, "memberId": member_id}
'target': target, )
'memberId': member_id
})
@argument_validation @argument_validation
async def modify_member_info(self, target: int, member_id: int, async def modify_member_info(
info: Dict[str, Any]): self, target: int, member_id: int, info: Dict[str, Any]
):
""" """
:说明: :说明:
@ -746,9 +744,6 @@ class Bot(BaseBot):
* ``member_id: int``: 群员QQ号 * ``member_id: int``: 群员QQ号
* ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值 * ``info: Dict[str, Any]``: 群员资料, 格式见 ``member_info`` 的返回值
""" """
return await self.api.post('memberInfo', return await self.api.post(
params={ "memberInfo", params={"target": target, "memberId": member_id, "info": info}
'target': target, )
'memberId': member_id,
'info': info
})

View File

@ -1,7 +1,7 @@
from ipaddress import IPv4Address
from typing import Optional from typing import Optional
from ipaddress import IPv4Address
from pydantic import BaseModel, Extra, Field from pydantic import Extra, Field, BaseModel
class Config(BaseModel): class Config(BaseModel):
@ -14,9 +14,10 @@ class Config(BaseModel):
- ``mirai_host``: mirai-api-http 的地址 - ``mirai_host``: mirai-api-http 的地址
- ``mirai_port``: mirai-api-http 的端口 - ``mirai_port``: mirai-api-http 的端口
""" """
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
host: Optional[IPv4Address] = Field(None, alias='mirai_host') auth_key: Optional[str] = Field(None, alias="mirai_auth_key")
port: Optional[int] = Field(None, alias='mirai_port') host: Optional[IPv4Address] = Field(None, alias="mirai_host")
port: Optional[int] = Field(None, alias="mirai_port")
class Config: class Config:
extra = Extra.ignore extra = Extra.ignore

View File

@ -5,25 +5,56 @@ r"""
部分字段可能与文档在符号上不一致 部分字段可能与文档在符号上不一致
\:\:\: \:\:\:
""" """
from .base import (Event, GroupChatInfo, GroupInfo, PrivateChatInfo,
UserPermission)
from .message import *
from .notice import * from .notice import *
from .message import *
from .request import * from .request import *
from .base import (
Event,
GroupInfo,
GroupChatInfo,
UserPermission,
PrivateChatInfo,
)
__all__ = [ __all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission', "Event",
'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage', "GroupChatInfo",
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent', "GroupInfo",
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent', "PrivateChatInfo",
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent', "UserPermission",
'MemberLeaveEventKick', 'MemberLeaveEventQuit', 'FriendRecallEvent', "MessageSource",
'GroupRecallEvent', 'GroupStateChangeEvent', 'GroupNameChangeEvent', "MessageEvent",
'GroupEntranceAnnouncementChangeEvent', 'GroupMuteAllEvent', "GroupMessage",
'GroupAllowAnonymousChatEvent', 'GroupAllowConfessTalkEvent', "FriendMessage",
'GroupAllowMemberInviteEvent', 'MemberStateChangeEvent', "TempMessage",
'MemberCardChangeEvent', 'MemberSpecialTitleChangeEvent', "NoticeEvent",
'BotGroupPermissionChangeEvent', 'MemberPermissionChangeEvent', "MuteEvent",
'RequestEvent', 'NewFriendRequestEvent', 'MemberJoinRequestEvent', "BotMuteEvent",
'BotInvitedJoinGroupRequestEvent' "BotUnmuteEvent",
"MemberMuteEvent",
"MemberUnmuteEvent",
"BotJoinGroupEvent",
"BotLeaveEventActive",
"BotLeaveEventKick",
"MemberJoinEvent",
"MemberLeaveEventKick",
"MemberLeaveEventQuit",
"FriendRecallEvent",
"GroupRecallEvent",
"GroupStateChangeEvent",
"GroupNameChangeEvent",
"GroupEntranceAnnouncementChangeEvent",
"GroupMuteAllEvent",
"GroupAllowAnonymousChatEvent",
"GroupAllowConfessTalkEvent",
"GroupAllowMemberInviteEvent",
"MemberStateChangeEvent",
"MemberCardChangeEvent",
"MemberSpecialTitleChangeEvent",
"BotGroupPermissionChangeEvent",
"MemberPermissionChangeEvent",
"RequestEvent",
"NewFriendRequestEvent",
"MemberJoinRequestEvent",
"BotInvitedJoinGroupRequestEvent",
] ]

View File

@ -22,9 +22,10 @@ class UserPermission(str, Enum):
* ``ADMINISTRATOR``: 群管理 * ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员 * ``MEMBER``: 普通群成员
""" """
OWNER = 'OWNER'
ADMINISTRATOR = 'ADMINISTRATOR' OWNER = "OWNER"
MEMBER = 'MEMBER' ADMINISTRATOR = "ADMINISTRATOR"
MEMBER = "MEMBER"
class NudgeSubjectKind(str, Enum): class NudgeSubjectKind(str, Enum):
@ -36,8 +37,9 @@ class NudgeSubjectKind(str, Enum):
* ``Group``: * ``Group``:
* ``Friend``: 好友 * ``Friend``: 好友
""" """
Group = 'Group'
Friend = 'Friend' Group = "Group"
Friend = "Friend"
class GroupInfo(BaseModel): class GroupInfo(BaseModel):
@ -48,7 +50,7 @@ class GroupInfo(BaseModel):
class GroupChatInfo(BaseModel): class GroupChatInfo(BaseModel):
id: int id: int
name: str = Field(alias='memberName') name: str = Field(alias="memberName")
permission: UserPermission permission: UserPermission
group: GroupInfo group: GroupInfo
@ -71,6 +73,7 @@ class Event(BaseEvent):
.. _mirai-api-http 事件类型: .. _mirai-api-http 事件类型:
https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md
""" """
self_id: int self_id: int
type: str type: str
@ -79,11 +82,12 @@ class Event(BaseEvent):
""" """
此事件类的工厂函数, 能够通过事件数据选择合适的子类进行序列化 此事件类的工厂函数, 能够通过事件数据选择合适的子类进行序列化
""" """
type = data['type'] type = data["type"]
def all_subclasses(cls: Type[Event]): def all_subclasses(cls: Type[Event]):
return set(cls.__subclasses__()).union( return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)]) [s for c in cls.__subclasses__() for s in all_subclasses(c)]
)
event_class: Optional[Type[Event]] = None event_class: Optional[Type[Event]] = None
for subclass in all_subclasses(cls): for subclass in all_subclasses(cls):
@ -99,23 +103,25 @@ class Event(BaseEvent):
return event_class.parse_obj(data) return event_class.parse_obj(data)
except ValidationError as e: except ValidationError as e:
logger.info( logger.info(
f'Failed to parse {data} to class {event_class.__name__}: ' f"Failed to parse {data} to class {event_class.__name__}: "
f'{e.errors()!r}. Fallback to parent class.') f"{e.errors()!r}. Fallback to parent class."
)
event_class = event_class.__base__ # type: ignore event_class = event_class.__base__ # type: ignore
raise ValueError(f'Failed to serialize {data}.') raise ValueError(f"Failed to serialize {data}.")
@overrides(BaseEvent) @overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]: def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
from . import meta, notice, message, request from . import meta, notice, message, request
if isinstance(self, message.MessageEvent): if isinstance(self, message.MessageEvent):
return 'message' return "message"
elif isinstance(self, notice.NoticeEvent): elif isinstance(self, notice.NoticeEvent):
return 'notice' return "notice"
elif isinstance(self, request.RequestEvent): elif isinstance(self, request.RequestEvent):
return 'request' return "request"
else: else:
return 'meta_event' return "meta_event"
@overrides(BaseEvent) @overrides(BaseEvent)
def get_event_name(self) -> str: def get_event_name(self) -> str:

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import Field, BaseModel
from nonebot.typing import overrides from nonebot.typing import overrides
@ -16,7 +16,8 @@ class MessageSource(BaseModel):
class MessageEvent(Event): class MessageEvent(Event):
"""消息事件基类""" """消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain')
message_chain: MessageChain = Field(alias="messageChain")
source: Optional[MessageSource] = None source: Optional[MessageSource] = None
sender: Any sender: Any
@ -39,12 +40,13 @@ class MessageEvent(Event):
class GroupMessage(MessageEvent): class GroupMessage(MessageEvent):
"""群消息事件""" """群消息事件"""
sender: GroupChatInfo sender: GroupChatInfo
to_me: bool = False to_me: bool = False
@overrides(MessageEvent) @overrides(MessageEvent)
def get_session_id(self) -> str: def get_session_id(self) -> str:
return f'group_{self.sender.group.id}_' + self.get_user_id() return f"group_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent) @overrides(MessageEvent)
def get_user_id(self) -> str: def get_user_id(self) -> str:
@ -57,6 +59,7 @@ class GroupMessage(MessageEvent):
class FriendMessage(MessageEvent): class FriendMessage(MessageEvent):
"""好友消息事件""" """好友消息事件"""
sender: PrivateChatInfo sender: PrivateChatInfo
@overrides(MessageEvent) @overrides(MessageEvent)
@ -65,7 +68,7 @@ class FriendMessage(MessageEvent):
@overrides(MessageEvent) @overrides(MessageEvent)
def get_session_id(self) -> str: def get_session_id(self) -> str:
return 'friend_' + self.get_user_id() return "friend_" + self.get_user_id()
@overrides(MessageEvent) @overrides(MessageEvent)
def is_tome(self) -> bool: def is_tome(self) -> bool:
@ -74,11 +77,12 @@ class FriendMessage(MessageEvent):
class TempMessage(MessageEvent): class TempMessage(MessageEvent):
"""临时会话消息事件""" """临时会话消息事件"""
sender: GroupChatInfo sender: GroupChatInfo
@overrides(MessageEvent) @overrides(MessageEvent)
def get_session_id(self) -> str: def get_session_id(self) -> str:
return f'temp_{self.sender.group.id}_' + self.get_user_id() return f"temp_{self.sender.group.id}_" + self.get_user_id()
@overrides(MessageEvent) @overrides(MessageEvent)
def is_tome(self) -> bool: def is_tome(self) -> bool:

View File

@ -3,29 +3,35 @@ from .base import Event
class MetaEvent(Event): class MetaEvent(Event):
"""元事件基类""" """元事件基类"""
qq: int qq: int
class BotOnlineEvent(MetaEvent): class BotOnlineEvent(MetaEvent):
"""Bot登录成功""" """Bot登录成功"""
pass pass
class BotOfflineEventActive(MetaEvent): class BotOfflineEventActive(MetaEvent):
"""Bot主动离线""" """Bot主动离线"""
pass pass
class BotOfflineEventForce(MetaEvent): class BotOfflineEventForce(MetaEvent):
"""Bot被挤下线""" """Bot被挤下线"""
pass pass
class BotOfflineEventDropped(MetaEvent): class BotOfflineEventDropped(MetaEvent):
"""Bot被服务器断开或因网络问题而掉线""" """Bot被服务器断开或因网络问题而掉线"""
pass pass
class BotReloginEvent(MetaEvent): class BotReloginEvent(MetaEvent):
"""Bot主动重新登录""" """Bot主动重新登录"""
pass pass

View File

@ -2,88 +2,103 @@ from typing import Any, Optional
from pydantic import Field from pydantic import Field
from .base import Event, GroupChatInfo, GroupInfo, NudgeSubject, UserPermission from .base import Event, GroupInfo, NudgeSubject, GroupChatInfo, UserPermission
class NoticeEvent(Event): class NoticeEvent(Event):
"""通知事件基类""" """通知事件基类"""
pass pass
class MuteEvent(NoticeEvent): class MuteEvent(NoticeEvent):
"""禁言类事件基类""" """禁言类事件基类"""
operator: GroupChatInfo operator: GroupChatInfo
class BotMuteEvent(MuteEvent): class BotMuteEvent(MuteEvent):
"""Bot被禁言""" """Bot被禁言"""
pass pass
class BotUnmuteEvent(MuteEvent): class BotUnmuteEvent(MuteEvent):
"""Bot被取消禁言""" """Bot被取消禁言"""
pass pass
class MemberMuteEvent(MuteEvent): class MemberMuteEvent(MuteEvent):
"""群成员被禁言事件该成员不是Bot""" """群成员被禁言事件该成员不是Bot"""
duration_seconds: int = Field(alias='durationSeconds')
duration_seconds: int = Field(alias="durationSeconds")
member: GroupChatInfo member: GroupChatInfo
operator: Optional[GroupChatInfo] = None operator: Optional[GroupChatInfo] = None
class MemberUnmuteEvent(MuteEvent): class MemberUnmuteEvent(MuteEvent):
"""群成员被取消禁言事件该成员不是Bot""" """群成员被取消禁言事件该成员不是Bot"""
member: GroupChatInfo member: GroupChatInfo
operator: Optional[GroupChatInfo] = None operator: Optional[GroupChatInfo] = None
class BotJoinGroupEvent(NoticeEvent): class BotJoinGroupEvent(NoticeEvent):
"""Bot加入了一个新群""" """Bot加入了一个新群"""
group: GroupInfo group: GroupInfo
class BotLeaveEventActive(BotJoinGroupEvent): class BotLeaveEventActive(BotJoinGroupEvent):
"""Bot主动退出一个群""" """Bot主动退出一个群"""
pass pass
class BotLeaveEventKick(BotJoinGroupEvent): class BotLeaveEventKick(BotJoinGroupEvent):
"""Bot被踢出一个群""" """Bot被踢出一个群"""
pass pass
class MemberJoinEvent(NoticeEvent): class MemberJoinEvent(NoticeEvent):
"""新人入群的事件""" """新人入群的事件"""
member: GroupChatInfo member: GroupChatInfo
class MemberLeaveEventKick(MemberJoinEvent): class MemberLeaveEventKick(MemberJoinEvent):
"""成员被踢出群该成员不是Bot""" """成员被踢出群该成员不是Bot"""
operator: Optional[GroupChatInfo] = None operator: Optional[GroupChatInfo] = None
class MemberLeaveEventQuit(MemberJoinEvent): class MemberLeaveEventQuit(MemberJoinEvent):
"""成员主动离群该成员不是Bot""" """成员主动离群该成员不是Bot"""
pass pass
class FriendRecallEvent(NoticeEvent): class FriendRecallEvent(NoticeEvent):
"""好友消息撤回""" """好友消息撤回"""
author_id: int = Field(alias='authorId')
message_id: int = Field(alias='messageId') author_id: int = Field(alias="authorId")
message_id: int = Field(alias="messageId")
time: int time: int
operator: int operator: int
class GroupRecallEvent(FriendRecallEvent): class GroupRecallEvent(FriendRecallEvent):
"""群消息撤回""" """群消息撤回"""
group: GroupInfo group: GroupInfo
operator: Optional[GroupChatInfo] = None operator: Optional[GroupChatInfo] = None
class GroupStateChangeEvent(NoticeEvent): class GroupStateChangeEvent(NoticeEvent):
"""群变化事件基类""" """群变化事件基类"""
origin: Any origin: Any
current: Any current: Any
group: GroupInfo group: GroupInfo
@ -92,73 +107,85 @@ class GroupStateChangeEvent(NoticeEvent):
class GroupNameChangeEvent(GroupStateChangeEvent): class GroupNameChangeEvent(GroupStateChangeEvent):
"""某个群名改变""" """某个群名改变"""
origin: str origin: str
current: str current: str
class GroupEntranceAnnouncementChangeEvent(GroupStateChangeEvent): class GroupEntranceAnnouncementChangeEvent(GroupStateChangeEvent):
"""某群入群公告改变""" """某群入群公告改变"""
origin: str origin: str
current: str current: str
class GroupMuteAllEvent(GroupStateChangeEvent): class GroupMuteAllEvent(GroupStateChangeEvent):
"""全员禁言""" """全员禁言"""
origin: bool origin: bool
current: bool current: bool
class GroupAllowAnonymousChatEvent(GroupStateChangeEvent): class GroupAllowAnonymousChatEvent(GroupStateChangeEvent):
"""匿名聊天""" """匿名聊天"""
origin: bool origin: bool
current: bool current: bool
class GroupAllowConfessTalkEvent(GroupStateChangeEvent): class GroupAllowConfessTalkEvent(GroupStateChangeEvent):
"""坦白说""" """坦白说"""
origin: bool origin: bool
current: bool current: bool
class GroupAllowMemberInviteEvent(GroupStateChangeEvent): class GroupAllowMemberInviteEvent(GroupStateChangeEvent):
"""允许群员邀请好友加群""" """允许群员邀请好友加群"""
origin: bool origin: bool
current: bool current: bool
class MemberStateChangeEvent(NoticeEvent): class MemberStateChangeEvent(NoticeEvent):
"""群成员变化事件基类""" """群成员变化事件基类"""
member: GroupChatInfo member: GroupChatInfo
operator: Optional[GroupChatInfo] = None operator: Optional[GroupChatInfo] = None
class MemberCardChangeEvent(MemberStateChangeEvent): class MemberCardChangeEvent(MemberStateChangeEvent):
"""群名片改动""" """群名片改动"""
origin: str origin: str
current: str current: str
class MemberSpecialTitleChangeEvent(MemberStateChangeEvent): class MemberSpecialTitleChangeEvent(MemberStateChangeEvent):
"""群头衔改动(只有群主有操作限权)""" """群头衔改动(只有群主有操作限权)"""
origin: str origin: str
current: str current: str
class BotGroupPermissionChangeEvent(MemberStateChangeEvent): class BotGroupPermissionChangeEvent(MemberStateChangeEvent):
"""Bot在群里的权限被改变""" """Bot在群里的权限被改变"""
origin: UserPermission origin: UserPermission
current: UserPermission current: UserPermission
class MemberPermissionChangeEvent(MemberStateChangeEvent): class MemberPermissionChangeEvent(MemberStateChangeEvent):
"""成员权限改变的事件该成员不是Bot""" """成员权限改变的事件该成员不是Bot"""
origin: UserPermission origin: UserPermission
current: UserPermission current: UserPermission
class NudgeEvent(NoticeEvent): class NudgeEvent(NoticeEvent):
"""戳一戳触发事件""" """戳一戳触发事件"""
from_id: int = Field(alias='fromId')
from_id: int = Field(alias="fromId")
target: int target: int
subject: NudgeSubject subject: NudgeSubject
action: str action: str

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing_extensions import Literal
from pydantic import Field from pydantic import Field
from typing_extensions import Literal
from .base import Event from .base import Event
@ -11,15 +11,17 @@ if TYPE_CHECKING:
class RequestEvent(Event): class RequestEvent(Event):
"""请求事件基类""" """请求事件基类"""
event_id: int = Field(alias='eventId')
event_id: int = Field(alias="eventId")
message: str message: str
nick: str nick: str
class NewFriendRequestEvent(RequestEvent): class NewFriendRequestEvent(RequestEvent):
"""添加好友申请""" """添加好友申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(0, alias='groupId') from_id: int = Field(alias="fromId")
group_id: int = Field(0, alias="groupId")
async def approve(self, bot: "Bot"): async def approve(self, bot: "Bot"):
""" """
@ -31,19 +33,18 @@ class NewFriendRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象 * ``bot: Bot``: 当前的 ``Bot`` 对象
""" """
return await bot.api.post('/resp/newFriendRequestEvent', return await bot.api.post(
"/resp/newFriendRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': 0, "operate": 0,
'message': '' "message": "",
}) },
)
async def reject(self, async def reject(self, bot: "Bot", operate: Literal[1, 2] = 1, message: str = ""):
bot: "Bot",
operate: Literal[1, 2] = 1,
message: str = ''):
""" """
:说明: :说明:
@ -60,21 +61,24 @@ class NewFriendRequestEvent(RequestEvent):
* ``message: str``: 回复的信息 * ``message: str``: 回复的信息
""" """
assert operate > 0 assert operate > 0
return await bot.api.post('/resp/newFriendRequestEvent', return await bot.api.post(
"/resp/newFriendRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': operate, "operate": operate,
'message': message "message": message,
}) },
)
class MemberJoinRequestEvent(RequestEvent): class MemberJoinRequestEvent(RequestEvent):
"""用户入群申请Bot需要有管理员权限""" """用户入群申请Bot需要有管理员权限"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId') from_id: int = Field(alias="fromId")
group_name: str = Field(alias='groupName') group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"): async def approve(self, bot: "Bot"):
""" """
@ -86,19 +90,20 @@ class MemberJoinRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象 * ``bot: Bot``: 当前的 ``Bot`` 对象
""" """
return await bot.api.post('/resp/memberJoinRequestEvent', return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': 0, "operate": 0,
'message': '' "message": "",
}) },
)
async def reject(self, async def reject(
bot: "Bot", self, bot: "Bot", operate: Literal[1, 2, 3, 4] = 1, message: str = ""
operate: Literal[1, 2, 3, 4] = 1, ):
message: str = ''):
""" """
:说明: :说明:
@ -117,21 +122,24 @@ class MemberJoinRequestEvent(RequestEvent):
* ``message: str``: 回复的信息 * ``message: str``: 回复的信息
""" """
assert operate > 0 assert operate > 0
return await bot.api.post('/resp/memberJoinRequestEvent', return await bot.api.post(
"/resp/memberJoinRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': operate, "operate": operate,
'message': message "message": message,
}) },
)
class BotInvitedJoinGroupRequestEvent(RequestEvent): class BotInvitedJoinGroupRequestEvent(RequestEvent):
"""Bot被邀请入群申请""" """Bot被邀请入群申请"""
from_id: int = Field(alias='fromId')
group_id: int = Field(alias='groupId') from_id: int = Field(alias="fromId")
group_name: str = Field(alias='groupName') group_id: int = Field(alias="groupId")
group_name: str = Field(alias="groupName")
async def approve(self, bot: "Bot"): async def approve(self, bot: "Bot"):
""" """
@ -143,14 +151,16 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象 * ``bot: Bot``: 当前的 ``Bot`` 对象
""" """
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent', return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': 0, "operate": 0,
'message': '' "message": "",
}) },
)
async def reject(self, bot: "Bot", message: str = ""): async def reject(self, bot: "Bot", message: str = ""):
""" """
@ -163,11 +173,13 @@ class BotInvitedJoinGroupRequestEvent(RequestEvent):
* ``bot: Bot``: 当前的 ``Bot`` 对象 * ``bot: Bot``: 当前的 ``Bot`` 对象
* ``message: str``: 邀请消息 * ``message: str``: 邀请消息
""" """
return await bot.api.post('/resp/botInvitedJoinGroupRequestEvent', return await bot.api.post(
"/resp/botInvitedJoinGroupRequestEvent",
params={ params={
'eventId': self.event_id, "eventId": self.event_id,
'groupId': self.group_id, "groupId": self.group_id,
'fromId': self.from_id, "fromId": self.from_id,
'operate': 1, "operate": 1,
'message': message "message": message,
}) },
)

View File

@ -1,28 +1,29 @@
from enum import Enum from enum import Enum
from typing import Any, List, Dict, Type, Iterable, Optional, Union from typing import Any, Dict, List, Type, Union, Iterable, Optional
from pydantic import validate_arguments from pydantic import validate_arguments
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment from nonebot.adapters import MessageSegment as BaseMessageSegment
from nonebot.typing import overrides
class MessageType(str, Enum): class MessageType(str, Enum):
"""消息类型枚举类""" """消息类型枚举类"""
SOURCE = 'Source'
QUOTE = 'Quote' SOURCE = "Source"
AT = 'At' QUOTE = "Quote"
AT_ALL = 'AtAll' AT = "At"
FACE = 'Face' AT_ALL = "AtAll"
PLAIN = 'Plain' FACE = "Face"
IMAGE = 'Image' PLAIN = "Plain"
FLASH_IMAGE = 'FlashImage' IMAGE = "Image"
VOICE = 'Voice' FLASH_IMAGE = "FlashImage"
XML = 'Xml' VOICE = "Voice"
JSON = 'Json' XML = "Xml"
APP = 'App' JSON = "Json"
POKE = 'Poke' APP = "App"
POKE = "Poke"
class MessageSegment(BaseMessageSegment["MessageChain"]): class MessageSegment(BaseMessageSegment["MessageChain"]):
@ -43,21 +44,24 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
@validate_arguments @validate_arguments
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any): def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type, super().__init__(
data={k: v for k, v in data.items() if v is not None}) type=type, data={k: v for k, v in data.items() if v is not None}
)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self) -> str: def __str__(self) -> str:
return self.data['text'] if self.is_text() else repr(self) return self.data["text"] if self.is_text() else repr(self)
def __repr__(self) -> str: def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([ return "[mirai:%s]" % ",".join(
[
self.type.value, self.type.value,
*map( *map(
lambda s: '%s=%r' % s, lambda s: "%s=%r" % s,
self.data.items(), self.data.items(),
), ),
]) ]
)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
@ -65,15 +69,21 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
"""导出可以被正常json序列化的结构体""" """导出可以被正常json序列化的结构体"""
return {'type': self.type.value, **self.data} return {"type": self.type.value, **self.data}
@classmethod @classmethod
def source(cls, id: int, time: int): def source(cls, id: int, time: int):
return cls(type=MessageType.SOURCE, id=id, time=time) return cls(type=MessageType.SOURCE, id=id, time=time)
@classmethod @classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int, def quote(
origin: "MessageChain"): cls,
id: int,
group_id: int,
sender_id: int,
target_id: int,
origin: "MessageChain",
):
""" """
:说明: :说明:
@ -87,12 +97,14 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``target_id: int``: 被引用回复的原消息的接收者者的QQ号或群号 * ``target_id: int``: 被引用回复的原消息的接收者者的QQ号或群号
* ``origin: MessageChain``: 被引用回复的原消息的消息链对象 * ``origin: MessageChain``: 被引用回复的原消息的消息链对象
""" """
return cls(type=MessageType.QUOTE, return cls(
type=MessageType.QUOTE,
id=id, id=id,
groupId=group_id, groupId=group_id,
senderId=sender_id, senderId=sender_id,
targetId=target_id, targetId=target_id,
origin=origin.export()) origin=origin.export(),
)
@classmethod @classmethod
def at(cls, target: int): def at(cls, target: int):
@ -144,10 +156,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.PLAIN, text=text) return cls(type=MessageType.PLAIN, text=text)
@classmethod @classmethod
def image(cls, def image(
cls,
image_id: Optional[str] = None, image_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None,
):
""" """
:说明: :说明:
@ -162,10 +176,12 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path) return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod @classmethod
def flash_image(cls, def flash_image(
cls,
image_id: Optional[str] = None, image_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None,
):
""" """
:说明: :说明:
@ -175,16 +191,15 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
``image`` ``image``
""" """
return cls(type=MessageType.FLASH_IMAGE, return cls(type=MessageType.FLASH_IMAGE, imageId=image_id, url=url, path=path)
imageId=image_id,
url=url,
path=path)
@classmethod @classmethod
def voice(cls, def voice(
cls,
voice_id: Optional[str] = None, voice_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None,
):
""" """
:说明: :说明:
@ -196,10 +211,7 @@ class MessageSegment(BaseMessageSegment["MessageChain"]):
* ``url: Optional[str]``: 语音的URL发送时可作网络语音的链接 * ``url: Optional[str]``: 语音的URL发送时可作网络语音的链接
* ``path: Optional[str]``: 语音的路径发送本地语音 * ``path: Optional[str]``: 语音的路径发送本地语音
""" """
return cls(type=MessageType.FLASH_IMAGE, return cls(type=MessageType.FLASH_IMAGE, imageId=voice_id, url=url, path=path)
imageId=voice_id,
url=url,
path=path)
@classmethod @classmethod
def xml(cls, xml: str): def xml(cls, xml: str):
@ -282,16 +294,14 @@ class MessageChain(BaseMessage[MessageSegment]):
return [MessageSegment.plain(text=message)] return [MessageSegment.plain(text=message)]
return [ return [
*map( *map(
lambda x: x lambda x: x if isinstance(x, MessageSegment) else MessageSegment(**x),
if isinstance(x, MessageSegment) else MessageSegment(**x), message,
message) )
] ]
def export(self) -> List[Dict[str, Any]]: def export(self) -> List[Dict[str, Any]]:
"""导出为可以被正常json序列化的数组""" """导出为可以被正常json序列化的数组"""
return [ return [*map(lambda segment: segment.as_dict(), self.copy())] # type: ignore
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]: def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
""" """
@ -311,4 +321,4 @@ class MessageChain(BaseMessage[MessageSegment]):
return None return None
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>' return f"<{self.__class__.__name__} {[*self.copy()]}>"

View File

@ -1,17 +1,17 @@
import re import re
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Coroutine
import httpx import httpx
from pydantic import Extra, ValidationError, validate_arguments from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception
from nonebot.log import logger from nonebot.log import logger
import nonebot.exception as exception
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageType, MessageSegment from .message import MessageType, MessageSegment
from .event import Event, GroupMessage, MessageEvent, MessageSource
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot from .bot import Bot
@ -21,28 +21,27 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log: class Log:
@staticmethod @staticmethod
def log(level: str, message: str, exception: Optional[Exception] = None): def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI') logger = logger_wrapper("MIRAI")
message = '<e>' + escape_tag(message) + '</e>' message = "<e>" + escape_tag(message) + "</e>"
logger(level=level.upper(), message=message, exception=exception) logger(level=level.upper(), message=message, exception=exception)
@classmethod @classmethod
def info(cls, message: Any): def info(cls, message: Any):
cls.log('INFO', str(message)) cls.log("INFO", str(message))
@classmethod @classmethod
def debug(cls, message: Any): def debug(cls, message: Any):
cls.log('DEBUG', str(message)) cls.log("DEBUG", str(message))
@classmethod @classmethod
def warn(cls, message: Any): def warn(cls, message: Any):
cls.log('WARNING', str(message)) cls.log("WARNING", str(message))
@classmethod @classmethod
def error(cls, message: Any, exception: Optional[Exception] = None): def error(cls, message: Any, exception: Optional[Exception] = None):
cls.log('ERROR', str(message), exception=exception) cls.log("ERROR", str(message), exception=exception)
class ActionFailed(exception.ActionFailed): class ActionFailed(exception.ActionFailed):
@ -53,12 +52,13 @@ class ActionFailed(exception.ActionFailed):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__('mirai') super().__init__("mirai")
self.data = kwargs.copy() self.data = kwargs.copy()
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join( return self.__class__.__name__ + "(%s)" % ", ".join(
map(lambda m: '%s=%r' % m, self.data.items())) map(lambda m: "%s=%r" % m, self.data.items())
)
class InvalidArgument(exception.AdapterException): class InvalidArgument(exception.AdapterException):
@ -69,7 +69,7 @@ class InvalidArgument(exception.AdapterException):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__('mirai') super().__init__("mirai")
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable: def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
@ -90,11 +90,12 @@ def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
try: try:
data = await function(*args, **kwargs) data = await function(*args, **kwargs)
except httpx.HTTPError: except httpx.HTTPError:
raise exception.NetworkError('mirai') raise exception.NetworkError("mirai")
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> ' logger.opt(colors=True).debug(
f'<y>{escape_tag(str(data))}</y>') "<b>Mirai API returned data:</b> " f"<y>{escape_tag(str(data))}</y>"
)
if isinstance(data, dict): if isinstance(data, dict):
if data.get('code', 0) != 0: if data.get("code", 0) != 0:
raise ActionFailed(**data) raise ActionFailed(**data)
return data return data
@ -109,10 +110,9 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
会在参数出错时释放 ``InvalidArgument`` 异常 会在参数出错时释放 ``InvalidArgument`` 异常
""" """
function = validate_arguments(config={ function = validate_arguments(
'arbitrary_types_allowed': True, config={"arbitrary_types_allowed": True, "extra": Extra.forbid}
'extra': Extra.forbid )(function)
})(function)
@wraps(function) @wraps(function)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -134,12 +134,12 @@ def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage: def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT) at = event.message_chain.extract_first(MessageType.AT)
if at is not None: if at is not None:
if at.data['target'] == event.self_id: if at.data["target"] == event.self_id:
event.to_me = True event.to_me = True
else: else:
event.message_chain.insert(0, at) event.message_chain.insert(0, at)
if not event.message_chain: if not event.message_chain:
event.message_chain.append(MessageSegment.plain('')) event.message_chain.append(MessageSegment.plain(""))
return event return event
@ -147,13 +147,13 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN) plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None: if plain is not None:
text = str(plain) text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname)) nick_regex = "|".join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE) matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None: if matched is not None:
event.to_me = True event.to_me = True
nickname = matched.group(1) nickname = matched.group(1)
Log.info(f'User is calling me {nickname}') Log.info(f"User is calling me {nickname}")
plain.data['text'] = text[matched.end():] plain.data["text"] = text[matched.end() :]
event.message_chain.insert(0, plain) event.message_chain.insert(0, plain)
return event return event
@ -161,7 +161,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE) reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None: if reply is not None:
if reply.data['senderId'] == event.self_id: if reply.data["senderId"] == event.self_id:
event.to_me = True event.to_me = True
else: else:
event.message_chain.insert(0, reply) event.message_chain.insert(0, reply)

View File

@ -34,6 +34,21 @@ nonebot2 = { path = "../../", develop = true }
# url = "https://mirrors.aliyun.com/pypi/simple/" # url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true # default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -7,8 +7,7 @@ from nonebot.log import logger
def init(): def init():
driver = nonebot.get_driver() driver = nonebot.get_driver()
try: try:
_module = importlib.import_module( _module = importlib.import_module(f"nonebot_plugin_docs.drivers.{driver.type}")
f"nonebot_plugin_docs.drivers.{driver.type}")
except ImportError: except ImportError:
logger.warning(f"Driver {driver.type} not supported") logger.warning(f"Driver {driver.type} not supported")
return return
@ -18,8 +17,9 @@ def init():
port = driver.config.port port = driver.config.port
if host in ["0.0.0.0", "127.0.0.1"]: if host in ["0.0.0.0", "127.0.0.1"]:
host = "localhost" host = "localhost"
logger.opt(colors=True).info(f"Nonebot docs will be running at: " logger.opt(colors=True).info(
f"<b><u>http://{host}:{port}/docs/</u></b>") f"Nonebot docs will be running at: " f"<b><u>http://{host}:{port}/docs/</u></b>"
)
init() init()

View File

@ -9,6 +9,4 @@ def register_route(driver: Driver):
static_path = str((Path(__file__).parent / ".." / "dist").resolve()) static_path = str((Path(__file__).parent / ".." / "dist").resolve())
app.mount("/docs", app.mount("/docs", StaticFiles(directory=static_path, html=True), name="docs")
StaticFiles(directory=static_path, html=True),
name="docs")

View File

@ -18,6 +18,21 @@ nonebot2 = "^2.0.0-alpha.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
line_length = 80
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

217
poetry.lock generated
View File

@ -151,6 +151,34 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[package.dependencies] [package.dependencies]
pytz = ">=2015.7" pytz = ">=2015.7"
[[package]]
name = "black"
version = "21.11b1"
description = "The uncompromising code formatter."
category = "dev"
optional = false
python-versions = ">=3.6.2"
[package.dependencies]
click = ">=7.1.2"
mypy-extensions = ">=0.4.3"
pathspec = ">=0.9.0,<1"
platformdirs = ">=2"
regex = ">=2021.4.4"
tomli = ">=0.2.6,<2.0.0"
typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""}
typing-extensions = [
{version = ">=3.10.0.0", markers = "python_version < \"3.10\""},
{version = "!=3.10.0.1", markers = "python_version >= \"3.10\""},
]
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
python2 = ["typed-ast (>=1.4.3)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]] [[package]]
name = "blinker" name = "blinker"
version = "1.4" version = "1.4"
@ -297,7 +325,7 @@ python-versions = ">=3.5"
[[package]] [[package]]
name = "httpcore" name = "httpcore"
version = "0.14.2" version = "0.14.3"
description = "A minimal low-level HTTP client." description = "A minimal low-level HTTP client."
category = "main" category = "main"
optional = false optional = false
@ -406,6 +434,20 @@ docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
perf = ["ipython"] perf = ["ipython"]
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"]
[[package]]
name = "isort"
version = "5.10.1"
description = "A Python utility / library to sort Python imports."
category = "dev"
optional = false
python-versions = ">=3.6.1,<4.0"
[package.extras]
pipfile_deprecated_finder = ["pipreqs", "requirementslib"]
requirements_deprecated_finder = ["pipreqs", "pip-api"]
colors = ["colorama (>=0.4.3,<0.5.0)"]
plugins = ["setuptools"]
[[package]] [[package]]
name = "itsdangerous" name = "itsdangerous"
version = "2.0.1" version = "2.0.1"
@ -459,6 +501,14 @@ category = "main"
optional = true optional = true
python-versions = ">=3.6" python-versions = ">=3.6"
[[package]]
name = "mypy-extensions"
version = "0.4.3"
description = "Experimental type system extensions for programs checked with the mypy typechecker."
category = "dev"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "nonebot-adapter-cqhttp" name = "nonebot-adapter-cqhttp"
version = "2.0.0-alpha.16" version = "2.0.0-alpha.16"
@ -544,14 +594,34 @@ python-socketio = ">=4.6.1,<5.0.0"
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "21.2" version = "21.3"
description = "Core utilities for Python packages" description = "Core utilities for Python packages"
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[package.dependencies] [package.dependencies]
pyparsing = ">=2.0.2,<3" pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]]
name = "pathspec"
version = "0.9.0"
description = "Utility library for gitignore style pattern matching of file paths."
category = "dev"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]]
name = "platformdirs"
version = "2.4.0"
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "dev"
optional = false
python-versions = ">=3.6"
[package.extras]
docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"]
test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"]
[[package]] [[package]]
name = "priority" name = "priority"
@ -636,11 +706,14 @@ python-versions = "*"
[[package]] [[package]]
name = "pyparsing" name = "pyparsing"
version = "2.4.7" version = "3.0.6"
description = "Python parsing module" description = "Python parsing module"
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" python-versions = ">=3.6"
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]] [[package]]
name = "python-dotenv" name = "python-dotenv"
@ -722,6 +795,14 @@ werkzeug = ">=2.0.0"
[package.extras] [package.extras]
dotenv = ["python-dotenv"] dotenv = ["python-dotenv"]
[[package]]
name = "regex"
version = "2021.11.10"
description = "Alternative regular expression module, to replace re."
category = "dev"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "requests" name = "requests"
version = "2.26.0" version = "2.26.0"
@ -925,6 +1006,14 @@ category = "main"
optional = true optional = true
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "tomli"
version = "1.2.2"
description = "A lil' TOML parser"
category = "dev"
optional = false
python-versions = ">=3.6"
[[package]] [[package]]
name = "tomlkit" name = "tomlkit"
version = "0.7.2" version = "0.7.2"
@ -933,6 +1022,14 @@ category = "main"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "typed-ast"
version = "1.5.0"
description = "a fork of Python 2 and 3 ast modules with type comment support"
category = "dev"
optional = false
python-versions = ">=3.6"
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.0.0" version = "4.0.0"
@ -1100,7 +1197,7 @@ quart = ["Quart"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7.3" python-versions = "^3.7.3"
content-hash = "537c91f98fd6598dbce8c2942530f18dee0858a896b6f393a684252a77dc76c6" content-hash = "9a64b2ba25ea3367e1636545241122d5b4d454ab3042865416193f84bf358fc3"
[metadata.files] [metadata.files]
aiocache = [ aiocache = [
@ -1221,6 +1318,10 @@ babel = [
{file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"}, {file = "Babel-2.9.1-py2.py3-none-any.whl", hash = "sha256:ab49e12b91d937cd11f0b67cb259a57ab4ad2b59ac7a3b41d6c06c0ac5b0def9"},
{file = "Babel-2.9.1.tar.gz", hash = "sha256:bc0c176f9f6a994582230df350aa6e05ba2ebe4b3ac317eab29d9be5d2768da0"}, {file = "Babel-2.9.1.tar.gz", hash = "sha256:bc0c176f9f6a994582230df350aa6e05ba2ebe4b3ac317eab29d9be5d2768da0"},
] ]
black = [
{file = "black-21.11b1-py3-none-any.whl", hash = "sha256:802c6c30b637b28645b7fde282ed2569c0cd777dbe493a41b6a03c1d903f99ac"},
{file = "black-21.11b1.tar.gz", hash = "sha256:a042adbb18b3262faad5aff4e834ff186bb893f95ba3a8013f09de1e5569def2"},
]
blinker = [ blinker = [
{file = "blinker-1.4.tar.gz", hash = "sha256:471aee25f3992bd325afa3772f1063dbdbbca947a041b8b89466dc00d606f8b6"}, {file = "blinker-1.4.tar.gz", hash = "sha256:471aee25f3992bd325afa3772f1063dbdbbca947a041b8b89466dc00d606f8b6"},
] ]
@ -1471,8 +1572,8 @@ html2text = [
{file = "html2text-2020.1.16.tar.gz", hash = "sha256:e296318e16b059ddb97f7a8a1d6a5c1d7af4544049a01e261731d2d5cc277bbb"}, {file = "html2text-2020.1.16.tar.gz", hash = "sha256:e296318e16b059ddb97f7a8a1d6a5c1d7af4544049a01e261731d2d5cc277bbb"},
] ]
httpcore = [ httpcore = [
{file = "httpcore-0.14.2-py3-none-any.whl", hash = "sha256:47d7c8f755719d4a57be0b6e022897e9e963bf9ce4b15b9cc006a38a1cfa2932"}, {file = "httpcore-0.14.3-py3-none-any.whl", hash = "sha256:9a98d2416b78976fc5396ff1f6b26ae9885efbb3105d24eed490f20ab4c95ec1"},
{file = "httpcore-0.14.2.tar.gz", hash = "sha256:ff8f8b9434ec4823f95a30596fbe78039913e706d3e598b0b8955b1e1828e093"}, {file = "httpcore-0.14.3.tar.gz", hash = "sha256:d10162a63265a0228d5807964bd964478cbdb5178f9a2eedfebb2faba27eef5d"},
] ]
httptools = [ httptools = [
{file = "httptools-0.2.0-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:79dbc21f3612a78b28384e989b21872e2e3cf3968532601544696e4ed0007ce5"}, {file = "httptools-0.2.0-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:79dbc21f3612a78b28384e989b21872e2e3cf3968532601544696e4ed0007ce5"},
@ -1515,6 +1616,10 @@ importlib-metadata = [
{file = "importlib_metadata-4.8.2-py3-none-any.whl", hash = "sha256:53ccfd5c134223e497627b9815d5030edf77d2ed573922f7a0b8f8bb81a1c100"}, {file = "importlib_metadata-4.8.2-py3-none-any.whl", hash = "sha256:53ccfd5c134223e497627b9815d5030edf77d2ed573922f7a0b8f8bb81a1c100"},
{file = "importlib_metadata-4.8.2.tar.gz", hash = "sha256:75bdec14c397f528724c1bfd9709d660b33a4d2e77387a3358f20b848bb5e5fb"}, {file = "importlib_metadata-4.8.2.tar.gz", hash = "sha256:75bdec14c397f528724c1bfd9709d660b33a4d2e77387a3358f20b848bb5e5fb"},
] ]
isort = [
{file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
{file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
]
itsdangerous = [ itsdangerous = [
{file = "itsdangerous-2.0.1-py3-none-any.whl", hash = "sha256:5174094b9637652bdb841a3029700391451bd092ba3db90600dea710ba28e97c"}, {file = "itsdangerous-2.0.1-py3-none-any.whl", hash = "sha256:5174094b9637652bdb841a3029700391451bd092ba3db90600dea710ba28e97c"},
{file = "itsdangerous-2.0.1.tar.gz", hash = "sha256:9e724d68fc22902a1435351f84c3fb8623f303fffcc566a4cb952df8c572cff0"}, {file = "itsdangerous-2.0.1.tar.gz", hash = "sha256:9e724d68fc22902a1435351f84c3fb8623f303fffcc566a4cb952df8c572cff0"},
@ -1672,6 +1777,10 @@ multidict = [
{file = "multidict-5.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:c9631c642e08b9fff1c6255487e62971d8b8e821808ddd013d8ac058087591ac"}, {file = "multidict-5.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:c9631c642e08b9fff1c6255487e62971d8b8e821808ddd013d8ac058087591ac"},
{file = "multidict-5.2.0.tar.gz", hash = "sha256:0dd1c93edb444b33ba2274b66f63def8a327d607c6c790772f448a53b6ea59ce"}, {file = "multidict-5.2.0.tar.gz", hash = "sha256:0dd1c93edb444b33ba2274b66f63def8a327d607c6c790772f448a53b6ea59ce"},
] ]
mypy-extensions = [
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
{file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
]
nonebot-adapter-cqhttp = [] nonebot-adapter-cqhttp = []
nonebot-adapter-ding = [] nonebot-adapter-ding = []
nonebot-adapter-feishu = [] nonebot-adapter-feishu = []
@ -1681,8 +1790,16 @@ nonebot-plugin-test = [
{file = "nonebot_plugin_test-0.3.0-py3-none-any.whl", hash = "sha256:edb880340436323ccd0a13b31d48975136b6bdc71daa178601c4b05b068cc73e"}, {file = "nonebot_plugin_test-0.3.0-py3-none-any.whl", hash = "sha256:edb880340436323ccd0a13b31d48975136b6bdc71daa178601c4b05b068cc73e"},
] ]
packaging = [ packaging = [
{file = "packaging-21.2-py3-none-any.whl", hash = "sha256:14317396d1e8cdb122989b916fa2c7e9ca8e2be9e8060a6eff75b6b7b4d8a7e0"}, {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
{file = "packaging-21.2.tar.gz", hash = "sha256:096d689d78ca690e4cd8a89568ba06d07ca097e3306a4381635073ca91479966"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
]
pathspec = [
{file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"},
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
]
platformdirs = [
{file = "platformdirs-2.4.0-py3-none-any.whl", hash = "sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d"},
{file = "platformdirs-2.4.0.tar.gz", hash = "sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2"},
] ]
priority = [ priority = [
{file = "priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa"}, {file = "priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa"},
@ -1793,8 +1910,8 @@ pygtrie = [
{file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"}, {file = "pygtrie-2.4.2.tar.gz", hash = "sha256:43205559d28863358dbbf25045029f58e2ab357317a59b11f11ade278ac64692"},
] ]
pyparsing = [ pyparsing = [
{file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, {file = "pyparsing-3.0.6-py3-none-any.whl", hash = "sha256:04ff808a5b90911829c55c4e26f75fa5ca8a2f5f36aa3a51f68e27033341d3e4"},
{file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, {file = "pyparsing-3.0.6.tar.gz", hash = "sha256:d9bdec0013ef1eb5a84ab39a3b3868911598afa494f5faa038647101504e2b81"},
] ]
python-dotenv = [ python-dotenv = [
{file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"}, {file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"},
@ -1851,6 +1968,57 @@ quart = [
{file = "Quart-0.15.1-py3-none-any.whl", hash = "sha256:f35134fb1d81af61624e6d89bca33cd611dcedce2dc4e291f527ab04395f4e1a"}, {file = "Quart-0.15.1-py3-none-any.whl", hash = "sha256:f35134fb1d81af61624e6d89bca33cd611dcedce2dc4e291f527ab04395f4e1a"},
{file = "Quart-0.15.1.tar.gz", hash = "sha256:f80c91d1e0588662483e22dd9c368a5778886b62e128c5399d2cc1b1898482cf"}, {file = "Quart-0.15.1.tar.gz", hash = "sha256:f80c91d1e0588662483e22dd9c368a5778886b62e128c5399d2cc1b1898482cf"},
] ]
regex = [
{file = "regex-2021.11.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9345b6f7ee578bad8e475129ed40123d265464c4cfead6c261fd60fc9de00bcf"},
{file = "regex-2021.11.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:416c5f1a188c91e3eb41e9c8787288e707f7d2ebe66e0a6563af280d9b68478f"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0538c43565ee6e703d3a7c3bdfe4037a5209250e8502c98f20fea6f5fdf2965"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ee1227cf08b6716c85504aebc49ac827eb88fcc6e51564f010f11a406c0a667"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6650f16365f1924d6014d2ea770bde8555b4a39dc9576abb95e3cd1ff0263b36"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ab804ea73972049b7a2a5c62d97687d69b5a60a67adca07eb73a0ddbc9e29f"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:68a067c11463de2a37157930d8b153005085e42bcb7ad9ca562d77ba7d1404e0"},
{file = "regex-2021.11.10-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:162abfd74e88001d20cb73ceaffbfe601469923e875caf9118333b1a4aaafdc4"},
{file = "regex-2021.11.10-cp310-cp310-win32.whl", hash = "sha256:98ba568e8ae26beb726aeea2273053c717641933836568c2a0278a84987b2a1a"},
{file = "regex-2021.11.10-cp310-cp310-win_amd64.whl", hash = "sha256:780b48456a0f0ba4d390e8b5f7c661fdd218934388cde1a974010a965e200e12"},
{file = "regex-2021.11.10-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dba70f30fd81f8ce6d32ddeef37d91c8948e5d5a4c63242d16a2b2df8143aafc"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e1f54b9b4b6c53369f40028d2dd07a8c374583417ee6ec0ea304e710a20f80a0"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fbb9dc00e39f3e6c0ef48edee202f9520dafb233e8b51b06b8428cfcb92abd30"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666abff54e474d28ff42756d94544cdfd42e2ee97065857413b72e8a2d6a6345"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5537f71b6d646f7f5f340562ec4c77b6e1c915f8baae822ea0b7e46c1f09b733"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2e07c6a26ed4bea91b897ee2b0835c21716d9a469a96c3e878dc5f8c55bb23"},
{file = "regex-2021.11.10-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca5f18a75e1256ce07494e245cdb146f5a9267d3c702ebf9b65c7f8bd843431e"},
{file = "regex-2021.11.10-cp36-cp36m-win32.whl", hash = "sha256:93a5051fcf5fad72de73b96f07d30bc29665697fb8ecdfbc474f3452c78adcf4"},
{file = "regex-2021.11.10-cp36-cp36m-win_amd64.whl", hash = "sha256:b483c9d00a565633c87abd0aaf27eb5016de23fed952e054ecc19ce32f6a9e7e"},
{file = "regex-2021.11.10-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fff55f3ce50a3ff63ec8e2a8d3dd924f1941b250b0aac3d3d42b687eeff07a8e"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e32d2a2b02ccbef10145df9135751abea1f9f076e67a4e261b05f24b94219e36"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53db2c6be8a2710b359bfd3d3aa17ba38f8aa72a82309a12ae99d3c0c3dcd74d"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2207ae4f64ad3af399e2d30dde66f0b36ae5c3129b52885f1bffc2f05ec505c8"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5ca078bb666c4a9d1287a379fe617a6dccd18c3e8a7e6c7e1eb8974330c626a"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd33eb9bdcfbabab3459c9ee651d94c842bc8a05fabc95edf4ee0c15a072495e"},
{file = "regex-2021.11.10-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05b7d6d7e64efe309972adab77fc2af8907bb93217ec60aa9fe12a0dad35874f"},
{file = "regex-2021.11.10-cp37-cp37m-win32.whl", hash = "sha256:e71255ba42567d34a13c03968736c5d39bb4a97ce98188fafb27ce981115beec"},
{file = "regex-2021.11.10-cp37-cp37m-win_amd64.whl", hash = "sha256:07856afef5ffcc052e7eccf3213317fbb94e4a5cd8177a2caa69c980657b3cb4"},
{file = "regex-2021.11.10-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba05430e819e58544e840a68b03b28b6d328aff2e41579037e8bab7653b37d83"},
{file = "regex-2021.11.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7f301b11b9d214f83ddaf689181051e7f48905568b0c7017c04c06dfd065e244"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aaa4e0705ef2b73dd8e36eeb4c868f80f8393f5f4d855e94025ce7ad8525f50"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:788aef3549f1924d5c38263104dae7395bf020a42776d5ec5ea2b0d3d85d6646"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f8af619e3be812a2059b212064ea7a640aff0568d972cd1b9e920837469eb3cb"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85bfa6a5413be0ee6c5c4a663668a2cad2cbecdee367630d097d7823041bdeec"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f23222527b307970e383433daec128d769ff778d9b29343fb3496472dc20dabe"},
{file = "regex-2021.11.10-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:da1a90c1ddb7531b1d5ff1e171b4ee61f6345119be7351104b67ff413843fe94"},
{file = "regex-2021.11.10-cp38-cp38-win32.whl", hash = "sha256:0617383e2fe465732af4509e61648b77cbe3aee68b6ac8c0b6fe934db90be5cc"},
{file = "regex-2021.11.10-cp38-cp38-win_amd64.whl", hash = "sha256:a3feefd5e95871872673b08636f96b61ebef62971eab044f5124fb4dea39919d"},
{file = "regex-2021.11.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7f325be2804246a75a4f45c72d4ce80d2443ab815063cdf70ee8fb2ca59ee1b"},
{file = "regex-2021.11.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:537ca6a3586931b16a85ac38c08cc48f10fc870a5b25e51794c74df843e9966d"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eef2afb0fd1747f33f1ee3e209bce1ed582d1896b240ccc5e2697e3275f037c7"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:432bd15d40ed835a51617521d60d0125867f7b88acf653e4ed994a1f8e4995dc"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b43c2b8a330a490daaef5a47ab114935002b13b3f9dc5da56d5322ff218eeadb"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:962b9a917dd7ceacbe5cd424556914cb0d636001e393b43dc886ba31d2a1e449"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa8c626d6441e2d04b6ee703ef2d1e17608ad44c7cb75258c09dd42bacdfc64b"},
{file = "regex-2021.11.10-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3c5fb32cc6077abad3bbf0323067636d93307c9fa93e072771cf9a64d1c0f3ef"},
{file = "regex-2021.11.10-cp39-cp39-win32.whl", hash = "sha256:3b5df18db1fccd66de15aa59c41e4f853b5df7550723d26aa6cb7f40e5d9da5a"},
{file = "regex-2021.11.10-cp39-cp39-win_amd64.whl", hash = "sha256:83ee89483672b11f8952b158640d0c0ff02dc43d9cb1b70c1564b49abe92ce29"},
{file = "regex-2021.11.10.tar.gz", hash = "sha256:f341ee2df0999bfdf7a95e448075effe0db212a59387de1a70690e4acb03d4c6"},
]
requests = [ requests = [
{file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"}, {file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"},
{file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"}, {file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"},
@ -1908,10 +2076,35 @@ toml = [
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
] ]
tomli = [
{file = "tomli-1.2.2-py3-none-any.whl", hash = "sha256:f04066f68f5554911363063a30b108d2b5a5b1a010aa8b6132af78489fe3aade"},
{file = "tomli-1.2.2.tar.gz", hash = "sha256:c6ce0015eb38820eaf32b5db832dbc26deb3dd427bd5f6556cf0acac2c214fee"},
]
tomlkit = [ tomlkit = [
{file = "tomlkit-0.7.2-py2.py3-none-any.whl", hash = "sha256:173ad840fa5d2aac140528ca1933c29791b79a374a0861a80347f42ec9328117"}, {file = "tomlkit-0.7.2-py2.py3-none-any.whl", hash = "sha256:173ad840fa5d2aac140528ca1933c29791b79a374a0861a80347f42ec9328117"},
{file = "tomlkit-0.7.2.tar.gz", hash = "sha256:d7a454f319a7e9bd2e249f239168729327e4dd2d27b17dc68be264ad1ce36754"}, {file = "tomlkit-0.7.2.tar.gz", hash = "sha256:d7a454f319a7e9bd2e249f239168729327e4dd2d27b17dc68be264ad1ce36754"},
] ]
typed-ast = [
{file = "typed_ast-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b310a207ee9fde3f46ba327989e6cba4195bc0c8c70a158456e7b10233e6bed"},
{file = "typed_ast-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52ca2b2b524d770bed7a393371a38e91943f9160a190141e0df911586066ecda"},
{file = "typed_ast-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:14fed8820114a389a2b7e91624db5f85f3f6682fda09fe0268a59aabd28fe5f5"},
{file = "typed_ast-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:65c81abbabda7d760df7304d843cc9dbe7ef5d485504ca59a46ae2d1731d2428"},
{file = "typed_ast-1.5.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:37ba2ab65a0028b1a4f2b61a8fe77f12d242731977d274a03d68ebb751271508"},
{file = "typed_ast-1.5.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:49af5b8f6f03ed1eb89ee06c1d7c2e7c8e743d720c3746a5857609a1abc94c94"},
{file = "typed_ast-1.5.0-cp36-cp36m-win_amd64.whl", hash = "sha256:e4374a76e61399a173137e7984a1d7e356038cf844f24fd8aea46c8029a2f712"},
{file = "typed_ast-1.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ea517c2bb11c5e4ba7a83a91482a2837041181d57d3ed0749a6c382a2b6b7086"},
{file = "typed_ast-1.5.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:51040bf45aacefa44fa67fb9ebcd1f2bec73182b99a532c2394eea7dabd18e24"},
{file = "typed_ast-1.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:806e0c7346b9b4af8c62d9a29053f484599921a4448c37fbbcbbf15c25138570"},
{file = "typed_ast-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a67fd5914603e2165e075f1b12f5a8356bfb9557e8bfb74511108cfbab0f51ed"},
{file = "typed_ast-1.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:224afecb8b39739f5c9562794a7c98325cb9d972712e1a98b6989a4720219541"},
{file = "typed_ast-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:155b74b078be842d2eb630dd30a280025eca0a5383c7d45853c27afee65f278f"},
{file = "typed_ast-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:361b9e5d27bd8e3ccb6ea6ad6c4f3c0be322a1a0f8177db6d56264fa0ae40410"},
{file = "typed_ast-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:618912cbc7e17b4aeba86ffe071698c6e2d292acbd6d1d5ec1ee724b8c4ae450"},
{file = "typed_ast-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7e6731044f748340ef68dcadb5172a4b1f40847a2983fe3983b2a66445fbc8e6"},
{file = "typed_ast-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e8a9b9c87801cecaad3b4c2b8876387115d1a14caa602c1618cedbb0cb2a14e6"},
{file = "typed_ast-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ec184dfb5d3d11e82841dbb973e7092b75f306b625fad7b2e665b64c5d60ab3f"},
{file = "typed_ast-1.5.0.tar.gz", hash = "sha256:ff4ad88271aa7a55f19b6a161ed44e088c393846d954729549e3cde8257747bb"},
]
typing-extensions = [ typing-extensions = [
{file = "typing_extensions-4.0.0-py3-none-any.whl", hash = "sha256:829704698b22e13ec9eaf959122315eabb370b0884400e9818334d8b677023d9"}, {file = "typing_extensions-4.0.0-py3-none-any.whl", hash = "sha256:829704698b22e13ec9eaf959122315eabb370b0884400e9818334d8b677023d9"},
{file = "typing_extensions-4.0.0.tar.gz", hash = "sha256:2cdf80e4e04866a9b3689a51869016d36db0814d84b8d8a568d22781d45d27ed"}, {file = "typing_extensions-4.0.0.tar.gz", hash = "sha256:2cdf80e4e04866a9b3689a51869016d36db0814d84b8d8a568d22781d45d27ed"},

View File

@ -36,8 +36,9 @@ uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
yapf = "^0.31.0"
sphinx = "^4.1.1" sphinx = "^4.1.1"
isort = "^5.10.1"
black = "^21.11b1"
nonebot-plugin-test = "^0.3.0" nonebot-plugin-test = "^0.3.0"
nonebot-adapter-cqhttp = { path = "./packages/nonebot-adapter-cqhttp", develop = true } nonebot-adapter-cqhttp = { path = "./packages/nonebot-adapter-cqhttp", develop = true }
nonebot-adapter-ding = { path = "./packages/nonebot-adapter-ding", develop = true } nonebot-adapter-ding = { path = "./packages/nonebot-adapter-ding", develop = true }
@ -55,13 +56,21 @@ all = ["quart", "aiohttp"]
# url = "https://mirrors.aliyun.com/pypi/simple/" # url = "https://mirrors.aliyun.com/pypi/simple/"
# default = true # default = true
[tool.black]
line-length = 88
target-version = ["py37", "py38", "py39", "py310"]
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort] [tool.isort]
profile = "black"
line_length = 80 line_length = 80
length_sort = true length_sort = true
skip_gitignore = true skip_gitignore = true
force_sort_within_sections = true force_sort_within_sections = true
known_local_folder = "nonebot" known_local_folder = ["nonebot"]
extra_standard_library = "typing_extensions" extra_standard_library = ["typing_extensions"]
[build-system] [build-system]
requires = ["poetry_core>=1.0.0"] requires = ["poetry_core>=1.0.0"]

View File

@ -11,11 +11,9 @@ from nonebot.adapters.mirai import Bot as MiraiBot
from nonebot.adapters.feishu import Bot as FeishuBot from nonebot.adapters.feishu import Bot as FeishuBot
# test custom log # test custom log
logger.add("error.log", logger.add(
rotation="00:00", "error.log", rotation="00:00", diagnose=False, level="ERROR", format=default_format
diagnose=False, )
level="ERROR",
format=default_format)
nonebot.init(custom_config2="config on init") nonebot.init(custom_config2="config on init")
app = nonebot.get_asgi() app = nonebot.get_asgi()

View File

@ -1,7 +1,8 @@
from nonebot.adapters.ding.event import GroupMessageEvent, PrivateMessageEvent
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.plugin import on_command from nonebot.plugin import on_command
from nonebot.adapters.ding import Bot as DingBot, MessageSegment, MessageEvent from nonebot.adapters.ding import Bot as DingBot
from nonebot.adapters.ding import MessageEvent, MessageSegment
from nonebot.adapters.ding.event import GroupMessageEvent, PrivateMessageEvent
helper = on_command("ding_helper", to_me()) helper = on_command("ding_helper", to_me())
@ -34,7 +35,7 @@ markdown = on_command("markdown", to_me())
async def markdown_handler(bot: DingBot): async def markdown_handler(bot: DingBot):
message = MessageSegment.markdown( message = MessageSegment.markdown(
"Hello, This is NoneBot", "Hello, This is NoneBot",
"#### NoneBot \n> Nonebot 是一款高性能的 Python 机器人框架\n> ![screenshot](https://v2.nonebot.dev/logo.png)\n> [GitHub 仓库地址](https://github.com/nonebot/nonebot2) \n" "#### NoneBot \n> Nonebot 是一款高性能的 Python 机器人框架\n> ![screenshot](https://v2.nonebot.dev/logo.png)\n> [GitHub 仓库地址](https://github.com/nonebot/nonebot2) \n",
) )
await markdown.finish(message) await markdown.finish(message)
@ -46,10 +47,10 @@ actionCardSingleBtn = on_command("actionCardSingleBtn", to_me())
async def actionCardSingleBtn_handler(bot: DingBot): async def actionCardSingleBtn_handler(bot: DingBot):
message = MessageSegment.actionCardSingleBtn( message = MessageSegment.actionCardSingleBtn(
title="打造一间咖啡厅", title="打造一间咖啡厅",
text= text="![screenshot](https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png) \n #### 乔布斯 20 年前想打造的苹果咖啡厅 \n\n Apple Store 的设计正从原来满满的科技感走向生活化,而其生活化的走向其实可以追溯到 20 年前苹果一个建立咖啡馆的计划",
"![screenshot](https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png) \n #### 乔布斯 20 年前想打造的苹果咖啡厅 \n\n Apple Store 的设计正从原来满满的科技感走向生活化,而其生活化的走向其实可以追溯到 20 年前苹果一个建立咖啡馆的计划",
singleTitle="阅读全文", singleTitle="阅读全文",
singleURL="https://www.dingtalk.com/") singleURL="https://www.dingtalk.com/",
)
await actionCardSingleBtn.finish(message) await actionCardSingleBtn.finish(message)
@ -58,26 +59,21 @@ actionCard = on_command("actionCard", to_me())
@actionCard.handle() @actionCard.handle()
async def actionCard_handler(bot: DingBot): async def actionCard_handler(bot: DingBot):
message = MessageSegment.raw({ message = MessageSegment.raw(
{
"msgtype": "actionCard", "msgtype": "actionCard",
"actionCard": { "actionCard": {
"title": "title": "乔布斯 20 年前想打造一间苹果咖啡厅,而它正是 Apple Store 的前身",
"乔布斯 20 年前想打造一间苹果咖啡厅,而它正是 Apple Store 的前身", "text": "![screenshot](https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png) \n\n #### 乔布斯 20 年前想打造的苹果咖啡厅 \n\n Apple Store 的设计正从原来满满的科技感走向生活化,而其生活化的走向其实可以追溯到 20 年前苹果一个建立咖啡馆的计划",
"text": "hideAvatar": "0",
"![screenshot](https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png) \n\n #### 乔布斯 20 年前想打造的苹果咖啡厅 \n\n Apple Store 的设计正从原来满满的科技感走向生活化,而其生活化的走向其实可以追溯到 20 年前苹果一个建立咖啡馆的计划", "btnOrientation": "0",
"hideAvatar": "btns": [
"0", {"title": "内容不错", "actionURL": "https://www.dingtalk.com/"},
"btnOrientation": {"title": "不感兴趣", "actionURL": "https://www.dingtalk.com/"},
"0", ],
"btns": [{ },
"title": "内容不错",
"actionURL": "https://www.dingtalk.com/"
}, {
"title": "不感兴趣",
"actionURL": "https://www.dingtalk.com/"
}]
} }
}) )
await actionCard.finish(message, at_sender=True) await actionCard.finish(message, at_sender=True)
@ -86,26 +82,25 @@ feedCard = on_command("feedCard", to_me())
@feedCard.handle() @feedCard.handle()
async def feedCard_handler(bot: DingBot): async def feedCard_handler(bot: DingBot):
message = MessageSegment.raw({ message = MessageSegment.raw(
{
"msgtype": "feedCard", "msgtype": "feedCard",
"feedCard": { "feedCard": {
"links": [{ "links": [
"title": {
"时代的火车向前开1", "title": "时代的火车向前开1",
"messageURL": "messageURL": "https://www.dingtalk.com/",
"https://www.dingtalk.com/", "picURL": "https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png",
"picURL": },
"https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png" {
}, { "title": "时代的火车向前开2",
"title": "messageURL": "https://www.dingtalk.com/",
"时代的火车向前开2", "picURL": "https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png",
"messageURL": },
"https://www.dingtalk.com/", ]
"picURL": },
"https://img.alicdn.com/tfs/TB1NwmBEL9TBuNjy1zbXXXpepXa-2400-1218.png"
}]
} }
}) )
await feedCard.finish(message) await feedCard.finish(message)
@ -115,7 +110,8 @@ atme = on_command("atme", to_me())
@atme.handle() @atme.handle()
async def atme_handler(bot: DingBot, event: MessageEvent): async def atme_handler(bot: DingBot, event: MessageEvent):
message = f"@{event.senderId} manually at you" + MessageSegment.atDingtalkIds( message = f"@{event.senderId} manually at you" + MessageSegment.atDingtalkIds(
event.senderId) event.senderId
)
await atme.send("matcher send auto at you", at_sender=True) await atme.send("matcher send auto at you", at_sender=True)
await bot.send(event, "bot send auto at you", at_sender=True) await bot.send(event, "bot send auto at you", at_sender=True)
await atme.finish(message) await atme.finish(message)
@ -143,12 +139,12 @@ async def textAdd_handler(bot: DingBot, event: MessageEvent):
message = message + MessageSegment.text("第二段消息\n") message = message + MessageSegment.text("第二段消息\n")
await textAdd.send(message) await textAdd.send(message)
message = message + MessageSegment.text( message = (
"\n第三段消息\n") + "adfkasfkhsdkfahskdjasdashdkjasdf" message + MessageSegment.text("\n第三段消息\n") + "adfkasfkhsdkfahskdjasdashdkjasdf"
message = message + MessageSegment.extension({ )
"text_type": "code_snippet", message = message + MessageSegment.extension(
"code_language": "C#" {"text_type": "code_snippet", "code_language": "C#"}
}) )
await textAdd.send(message) await textAdd.send(message)
@ -159,7 +155,8 @@ code = on_command("code", to_me())
async def code_handler(bot: DingBot, event: MessageEvent): async def code_handler(bot: DingBot, event: MessageEvent):
raw = MessageSegment.code("Python", 'print("hello world")') raw = MessageSegment.code("Python", 'print("hello world")')
await code.send(raw) await code.send(raw)
message = MessageSegment.text("""using System; message = MessageSegment.text(
"""using System;
namespace HelloWorld namespace HelloWorld
{ {
@ -170,11 +167,11 @@ namespace HelloWorld
Console.WriteLine("Hello World!"); Console.WriteLine("Hello World!");
} }
} }
}""") }"""
message += MessageSegment.extension({ )
"text_type": "code_snippet", message += MessageSegment.extension(
"code_language": "C#" {"text_type": "code_snippet", "code_language": "C#"}
}) )
await code.finish(message) await code.finish(message)
@ -196,12 +193,12 @@ hello = on_command("hello", to_me())
@hello.handle() @hello.handle()
async def hello_handler(bot: DingBot, event: MessageEvent): async def hello_handler(bot: DingBot, event: MessageEvent):
message = MessageSegment.raw({ message = MessageSegment.raw(
{
"msgtype": "text", "msgtype": "text",
"text": { "text": {"content": "hello "},
"content": 'hello ' }
}, )
})
message += MessageSegment.atDingtalkIds(event.senderId) message += MessageSegment.atDingtalkIds(event.senderId)
await hello.send(message) await hello.send(message)
@ -216,22 +213,21 @@ hello = on_command("webhook", to_me())
@hello.handle() @hello.handle()
async def webhook_handler(bot: DingBot, event: MessageEvent): async def webhook_handler(bot: DingBot, event: MessageEvent):
print(event) print(event)
message = MessageSegment.raw({ message = MessageSegment.raw(
{
"msgtype": "text", "msgtype": "text",
"text": { "text": {"content": "hello from webhook,一定要注意安全方式的鉴权哦,否则可能发送失败的"},
"content": 'hello from webhook,一定要注意安全方式的鉴权哦,否则可能发送失败的' }
}, )
})
message += MessageSegment.atDingtalkIds(event.senderId) message += MessageSegment.atDingtalkIds(event.senderId)
await hello.send( await hello.send(
message, message,
webhook= webhook="https://oapi.dingtalk.com/robot/send?access_token=XXXXXXXXXXXXXX",
"https://oapi.dingtalk.com/robot/send?access_token=XXXXXXXXXXXXXX", secret="SECXXXXXXXXXXXXXXXXXXXXXXXXX",
secret="SECXXXXXXXXXXXXXXXXXXXXXXXXX") )
message = MessageSegment.text("TEST 123123 S") message = MessageSegment.text("TEST 123123 S")
await hello.send( await hello.send(
message, message,
webhook= webhook="https://oapi.dingtalk.com/robot/send?access_token=XXXXXXXXXXXXXX",
"https://oapi.dingtalk.com/robot/send?access_token=XXXXXXXXXXXXXX",
) )

View File

@ -1,6 +1,7 @@
from nonebot.plugin import on_command
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.adapters.feishu import Bot as FeishuBot, MessageEvent from nonebot.plugin import on_command
from nonebot.adapters.feishu import MessageEvent
from nonebot.adapters.feishu import Bot as FeishuBot
helper = on_command("say") helper = on_command("say")

View File

@ -1,5 +1,4 @@
import nonebot import nonebot
from .test_export import export from .test_export import export
print(export, nonebot.require("test_export")) print(export, nonebot.require("test_export"))

View File

@ -4,4 +4,4 @@ from nonebot import CommandGroup, MatcherGroup
cmd = CommandGroup("test", rule=to_me()) cmd = CommandGroup("test", rule=to_me())
match = MatcherGroup(priority=2) match = MatcherGroup(priority=2)
from . import commands, matches from . import matches, commands

View File

@ -1,6 +1,5 @@
from nonebot.adapters import Bot, Event
from . import cmd from . import cmd
from nonebot.adapters import Bot, Event
test_1 = cmd.command("1", aliases={"test"}) test_1 = cmd.command("1", aliases={"test"})

View File

@ -1,9 +1,8 @@
from . import match
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.adapters.cqhttp import HeartbeatMetaEvent from nonebot.adapters.cqhttp import HeartbeatMetaEvent
from . import match
async def heartbeat(bot: Bot, event: Event, state: T_State) -> bool: async def heartbeat(bot: Bot, event: Event, state: T_State) -> bool:
return isinstance(event, HeartbeatMetaEvent) return isinstance(event, HeartbeatMetaEvent)

View File

@ -1,6 +1,6 @@
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.plugin import on_metaevent
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.plugin import on_metaevent
from nonebot.adapters.cqhttp import HeartbeatMetaEvent from nonebot.adapters.cqhttp import HeartbeatMetaEvent

View File

@ -1,8 +1,8 @@
from nonebot.plugin import on_keyword, on_command
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.plugin import on_command, on_keyword
from nonebot.adapters.mirai import Bot, MessageEvent from nonebot.adapters.mirai import Bot, MessageEvent
message_test = on_keyword({'reply'}, rule=to_me()) message_test = on_keyword({"reply"}, rule=to_me())
@message_test.handle() @message_test.handle()
@ -11,7 +11,7 @@ async def _message(bot: Bot, event: MessageEvent):
await bot.send(event, text, at_sender=True) await bot.send(event, text, at_sender=True)
command_test = on_command('miecho') command_test = on_command("miecho")
@command_test.handle() @command_test.handle()

View File

@ -1,5 +1,5 @@
from nonebot import on_command from nonebot import on_command
from nonebot.adapters.cqhttp import Bot, PrivateMessageEvent, GroupMessageEvent from nonebot.adapters.cqhttp import Bot, GroupMessageEvent, PrivateMessageEvent
overload = on_command("overload") overload = on_command("overload")

View File

@ -1,7 +1,7 @@
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.plugin import on_startswith
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.plugin import on_startswith
from nonebot.adapters.ding import Bot as DingBot from nonebot.adapters.ding import Bot as DingBot
from nonebot.adapters.cqhttp import Bot as CQHTTPBot from nonebot.adapters.cqhttp import Bot as CQHTTPBot

View File

@ -1,7 +1,7 @@
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot import on_shell_command from nonebot import on_shell_command
from nonebot.rule import to_me, ArgumentParser from nonebot.rule import ArgumentParser, to_me
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("-a", action="store_true") parser.add_argument("-a", action="store_true")