mirror of
https://github.com/LiteyukiStudio/magicoca.git
synced 2025-02-22 02:25:50 +08:00
⚡ 优化select函数,减少管道重复映射,改进recv方法返回值类型
This commit is contained in:
parent
33db0ffc02
commit
7c638615ba
@ -1,6 +1,6 @@
|
||||
from multiprocessing import set_start_method
|
||||
from multiprocessing.connection import wait
|
||||
from typing import Any, Callable, Generator
|
||||
from multiprocessing.connection import wait, Connection
|
||||
from typing import Generator
|
||||
|
||||
from magicoca.chan import Chan, T, NoRecvValue
|
||||
|
||||
@ -13,18 +13,29 @@ set_start_method("spawn", force=True)
|
||||
|
||||
def select(*args: Chan[T]) -> Generator[T, None, None]:
|
||||
"""
|
||||
Return a yield, when a value is received from one of the channels.
|
||||
Args:
|
||||
args: channels
|
||||
当其中一个通道接收到数据时,yield 该数据。
|
||||
|
||||
参数:
|
||||
args: 多个 Chan 对象
|
||||
"""
|
||||
pipes = [ch.recv_conn for ch in args if not ch.is_closed]
|
||||
# 构造管道到通道列表的映射,避免重复的 recv_conn 对象
|
||||
pipe_to_chs: dict[Connection, list[Chan[T]]] = {}
|
||||
for ch in args:
|
||||
if not ch.is_closed:
|
||||
pipe: Connection = ch.recv_conn
|
||||
pipe_to_chs.setdefault(pipe, []).append(ch)
|
||||
pipes: list[Connection] = list(pipe_to_chs.keys())
|
||||
|
||||
while pipes:
|
||||
ready_pipes = wait(pipes)
|
||||
ready_pipes: list[Connection] = wait(pipes) # type: ignore
|
||||
for pipe in ready_pipes:
|
||||
for ch in args:
|
||||
if ch.recv_conn == pipe:
|
||||
if not isinstance(value := ch.recv(0), NoRecvValue):
|
||||
yield value
|
||||
if ch.is_closed:
|
||||
pipes.remove(pipe)
|
||||
# 遍历所有使用该管道的通道
|
||||
channels: list[Chan[T]] = list(pipe_to_chs.get(pipe, []))
|
||||
for ch in channels:
|
||||
if not isinstance(value := ch.recv(0), NoRecvValue):
|
||||
yield value
|
||||
if ch.is_closed:
|
||||
pipe_to_chs[pipe].remove(ch)
|
||||
# 如果该管道已没有活跃的通道,则移除
|
||||
if not pipe_to_chs[pipe]:
|
||||
pipes.remove(pipe)
|
@ -47,13 +47,13 @@ class Chan(Generic[T]):
|
||||
"""
|
||||
self.send_conn.send(value)
|
||||
|
||||
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
|
||||
def recv(self, timeout: float | None = None) -> T | NoRecvValue:
|
||||
"""Receive a value from the channel.
|
||||
If the timeout is None, it will block until a value is received.
|
||||
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
||||
接收通道中的值。
|
||||
如果超时为None,则它将阻塞,直到接收到值。
|
||||
如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回None。
|
||||
如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回NoRecvValue。
|
||||
Args:
|
||||
timeout:
|
||||
The maximum time to wait for a value.
|
||||
@ -82,7 +82,7 @@ class Chan(Generic[T]):
|
||||
"""
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
def __next__(self) -> T | NoRecvValue:
|
||||
return self.recv()
|
||||
|
||||
def __lshift__(self, other: T):
|
||||
@ -95,7 +95,7 @@ class Chan(Generic[T]):
|
||||
self.send(other)
|
||||
return self
|
||||
|
||||
def __rlshift__(self, other: Any) -> T:
|
||||
def __rlshift__(self, other: Any) -> T | NoRecvValue:
|
||||
"""
|
||||
<< chan
|
||||
Returns: The value received from the channel.
|
||||
|
@ -4,41 +4,35 @@ from multiprocessing import Process
|
||||
from magicoca import Chan, select
|
||||
|
||||
|
||||
def sp1(chan: Chan[int]):
|
||||
for i in range(10):
|
||||
chan << i << i * 2
|
||||
|
||||
|
||||
def sp2(chan: Chan[int]):
|
||||
for i in range(10):
|
||||
chan << i << i * 3
|
||||
|
||||
|
||||
def rp(chans: list[Chan[int]]):
|
||||
rl = []
|
||||
for t in select(*chans):
|
||||
rl.append(t)
|
||||
if len(rl) == 40:
|
||||
break
|
||||
print(rl)
|
||||
|
||||
def send_process(chan: Chan[int], _id: int):
|
||||
while True:
|
||||
chan << _id
|
||||
time.sleep(2)
|
||||
for i in range(10):
|
||||
chan << i
|
||||
time.sleep(0.1 * _id)
|
||||
|
||||
def recv_process(chan_list: list[Chan[int]]):
|
||||
c = []
|
||||
for t in select(*chan_list):
|
||||
print(t)
|
||||
|
||||
|
||||
c.append(t)
|
||||
print("Select", t)
|
||||
if len(c) == 30:
|
||||
break
|
||||
class TestSelect:
|
||||
def test_select(self):
|
||||
chan_list = []
|
||||
for i in range(10):
|
||||
chan = Chan[int]()
|
||||
chan_list.append(chan)
|
||||
p = Process(target=send_process, args=(chan, i))
|
||||
p.start()
|
||||
p = Process(target=recv_process, args=(chan_list,))
|
||||
p.start()
|
||||
ch1 = Chan[int]()
|
||||
ch2 = Chan[int]()
|
||||
ch3 = Chan[int]()
|
||||
|
||||
p1 = Process(target=send_process, args=(ch1, 1))
|
||||
p2 = Process(target=send_process, args=(ch2, 2))
|
||||
p3 = Process(target=send_process, args=(ch3, 3))
|
||||
p4 = Process(target=recv_process, args=([ch1, ch2, ch3],))
|
||||
|
||||
p1.start()
|
||||
p2.start()
|
||||
p3.start()
|
||||
p4.start()
|
||||
|
||||
p1.join()
|
||||
p2.join()
|
||||
p3.join()
|
||||
p4.join()
|
||||
|
Loading…
x
Reference in New Issue
Block a user