优化select函数,减少管道重复映射,改进recv方法返回值类型

This commit is contained in:
远野千束(神羽) 2025-02-18 06:33:33 +08:00
parent 33db0ffc02
commit 7c638615ba
3 changed files with 54 additions and 49 deletions

View File

@ -1,6 +1,6 @@
from multiprocessing import set_start_method
from multiprocessing.connection import wait
from typing import Any, Callable, Generator
from multiprocessing.connection import wait, Connection
from typing import Generator
from magicoca.chan import Chan, T, NoRecvValue
@ -13,18 +13,29 @@ set_start_method("spawn", force=True)
def select(*args: Chan[T]) -> Generator[T, None, None]:
"""
Return a yield, when a value is received from one of the channels.
Args:
args: channels
当其中一个通道接收到数据时yield 该数据
参数:
args: 多个 Chan 对象
"""
pipes = [ch.recv_conn for ch in args if not ch.is_closed]
# 构造管道到通道列表的映射,避免重复的 recv_conn 对象
pipe_to_chs: dict[Connection, list[Chan[T]]] = {}
for ch in args:
if not ch.is_closed:
pipe: Connection = ch.recv_conn
pipe_to_chs.setdefault(pipe, []).append(ch)
pipes: list[Connection] = list(pipe_to_chs.keys())
while pipes:
ready_pipes = wait(pipes)
ready_pipes: list[Connection] = wait(pipes) # type: ignore
for pipe in ready_pipes:
for ch in args:
if ch.recv_conn == pipe:
if not isinstance(value := ch.recv(0), NoRecvValue):
yield value
if ch.is_closed:
pipes.remove(pipe)
# 遍历所有使用该管道的通道
channels: list[Chan[T]] = list(pipe_to_chs.get(pipe, []))
for ch in channels:
if not isinstance(value := ch.recv(0), NoRecvValue):
yield value
if ch.is_closed:
pipe_to_chs[pipe].remove(ch)
# 如果该管道已没有活跃的通道,则移除
if not pipe_to_chs[pipe]:
pipes.remove(pipe)

View File

@ -47,13 +47,13 @@ class Chan(Generic[T]):
"""
self.send_conn.send(value)
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
def recv(self, timeout: float | None = None) -> T | NoRecvValue:
"""Receive a value from the channel.
If the timeout is None, it will block until a value is received.
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
接收通道中的值
如果超时为None则它将阻塞直到接收到值
如果超时是正数则它将等待指定的时间如果没有接收到值则返回None
如果超时是正数则它将等待指定的时间如果没有接收到值则返回NoRecvValue
Args:
timeout:
The maximum time to wait for a value.
@ -82,7 +82,7 @@ class Chan(Generic[T]):
"""
return self
def __next__(self) -> T:
def __next__(self) -> T | NoRecvValue:
return self.recv()
def __lshift__(self, other: T):
@ -95,7 +95,7 @@ class Chan(Generic[T]):
self.send(other)
return self
def __rlshift__(self, other: Any) -> T:
def __rlshift__(self, other: Any) -> T | NoRecvValue:
"""
<< chan
Returns: The value received from the channel.

View File

@ -4,41 +4,35 @@ from multiprocessing import Process
from magicoca import Chan, select
def sp1(chan: Chan[int]):
for i in range(10):
chan << i << i * 2
def sp2(chan: Chan[int]):
for i in range(10):
chan << i << i * 3
def rp(chans: list[Chan[int]]):
rl = []
for t in select(*chans):
rl.append(t)
if len(rl) == 40:
break
print(rl)
def send_process(chan: Chan[int], _id: int):
while True:
chan << _id
time.sleep(2)
for i in range(10):
chan << i
time.sleep(0.1 * _id)
def recv_process(chan_list: list[Chan[int]]):
c = []
for t in select(*chan_list):
print(t)
c.append(t)
print("Select", t)
if len(c) == 30:
break
class TestSelect:
def test_select(self):
chan_list = []
for i in range(10):
chan = Chan[int]()
chan_list.append(chan)
p = Process(target=send_process, args=(chan, i))
p.start()
p = Process(target=recv_process, args=(chan_list,))
p.start()
ch1 = Chan[int]()
ch2 = Chan[int]()
ch3 = Chan[int]()
p1 = Process(target=send_process, args=(ch1, 1))
p2 = Process(target=send_process, args=(ch2, 2))
p3 = Process(target=send_process, args=(ch3, 3))
p4 = Process(target=recv_process, args=([ch1, ch2, ch3],))
p1.start()
p2.start()
p3.start()
p4.start()
p1.join()
p2.join()
p3.join()
p4.join()