add get adapter (#1747)

This commit is contained in:
Ju4tCode 2023-02-26 14:15:10 +08:00 committed by GitHub
parent dd04190ca2
commit 04a7c3bc13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 5 deletions

View File

@ -39,14 +39,14 @@ FrontMatter:
import os import os
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Dict, Type, Optional from typing import Any, Dict, Type, Union, Optional
import loguru import loguru
from pydantic.env_settings import DotenvType from pydantic.env_settings import DotenvType
from nonebot.adapters import Bot
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.log import logger as logger from nonebot.log import logger as logger
from nonebot.adapters import Bot, Adapter
from nonebot.utils import escape_tag, resolve_dot_notation from nonebot.utils import escape_tag, resolve_dot_notation
from nonebot.drivers import Driver, ReverseDriver, combine_driver from nonebot.drivers import Driver, ReverseDriver, combine_driver
@ -79,6 +79,46 @@ def get_driver() -> Driver:
return _driver return _driver
def get_adapter(name: Union[str, Type[Adapter]]) -> Adapter:
"""获取已注册的 {ref}`nonebot.adapters.Adapter` 实例。
返回:
指定名称或类型的 {ref}`nonebot.adapters.Adapter` 对象
异常:
ValueError: 指定的 {ref}`nonebot.adapters.Adapter` 未注册
ValueError: 全局 {ref}`nonebot.drivers.Driver` 对象尚未初始化 ({ref}`nonebot.init <nonebot.init>` 尚未调用)
用法:
```python
from nonebot.adapters.console import Adapter
adapter = nonebot.get_adapter(Adapter)
```
"""
adapters = get_adapters()
target = name if isinstance(name, str) else name.get_name()
if target not in adapters:
raise ValueError(f"Adapter {target} not registered.")
return adapters[target]
def get_adapters() -> Dict[str, Adapter]:
"""获取所有已注册的 {ref}`nonebot.adapters.Adapter` 实例。
返回:
所有 {ref}`nonebot.adapters.Adapter` 实例字典
异常:
ValueError: 全局 {ref}`nonebot.drivers.Driver` 对象尚未初始化 ({ref}`nonebot.init <nonebot.init>` 尚未调用)
用法:
```python
adapters = nonebot.get_adapters()
```
"""
return get_driver()._adapters.copy()
def get_app() -> Any: def get_app() -> Any:
"""获取全局 {ref}`nonebot.drivers.ReverseDriver` 对应的 Server App 对象。 """获取全局 {ref}`nonebot.drivers.ReverseDriver` 对应的 Server App 对象。

View File

@ -1,8 +1,17 @@
import pytest import pytest
from nonebug import App
import nonebot import nonebot
from nonebot.drivers import ReverseDriver from nonebot.drivers import Driver, ReverseDriver
from nonebot import get_app, get_bot, get_asgi, get_bots, get_driver from nonebot import (
get_app,
get_bot,
get_asgi,
get_bots,
get_driver,
get_adapter,
get_adapters,
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -22,7 +31,7 @@ async def test_init():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get(monkeypatch: pytest.MonkeyPatch): async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(nonebot, "_driver", None) m.setattr(nonebot, "_driver", None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -33,6 +42,18 @@ async def test_get(monkeypatch: pytest.MonkeyPatch):
assert get_asgi() == driver.asgi assert get_asgi() == driver.asgi
assert get_app() == driver.server_app assert get_app() == driver.server_app
async with app.test_api() as ctx:
adapter = ctx.create_adapter()
adapter_name = adapter.get_name()
with monkeypatch.context() as m:
m.setattr(Driver, "_adapters", {adapter_name: adapter})
assert get_adapters() == {adapter_name: adapter}
assert get_adapter(adapter_name) is adapter
assert get_adapter(adapter.__class__) is adapter
with pytest.raises(ValueError):
get_adapter("not exist")
runned = False runned = False
def mock_run(*args, **kwargs): def mock_run(*args, **kwargs):