优化select函数,减少管道重复映射,改进recv方法返回值类型

This commit is contained in:
远野千束(神羽) 2025-02-18 06:33:33 +08:00
parent 33db0ffc02
commit 7c638615ba
3 changed files with 54 additions and 49 deletions

View File

@ -1,6 +1,6 @@
from multiprocessing import set_start_method from multiprocessing import set_start_method
from multiprocessing.connection import wait from multiprocessing.connection import wait, Connection
from typing import Any, Callable, Generator from typing import Generator
from magicoca.chan import Chan, T, NoRecvValue from magicoca.chan import Chan, T, NoRecvValue
@ -13,18 +13,29 @@ set_start_method("spawn", force=True)
def select(*args: Chan[T]) -> Generator[T, None, None]: def select(*args: Chan[T]) -> Generator[T, None, None]:
""" """
Return a yield, when a value is received from one of the channels. 当其中一个通道接收到数据时yield 该数据
Args:
args: channels 参数:
args: 多个 Chan 对象
""" """
pipes = [ch.recv_conn for ch in args if not ch.is_closed] # 构造管道到通道列表的映射,避免重复的 recv_conn 对象
pipe_to_chs: dict[Connection, list[Chan[T]]] = {}
for ch in args:
if not ch.is_closed:
pipe: Connection = ch.recv_conn
pipe_to_chs.setdefault(pipe, []).append(ch)
pipes: list[Connection] = list(pipe_to_chs.keys())
while pipes: while pipes:
ready_pipes = wait(pipes) ready_pipes: list[Connection] = wait(pipes) # type: ignore
for pipe in ready_pipes: for pipe in ready_pipes:
for ch in args: # 遍历所有使用该管道的通道
if ch.recv_conn == pipe: channels: list[Chan[T]] = list(pipe_to_chs.get(pipe, []))
if not isinstance(value := ch.recv(0), NoRecvValue): for ch in channels:
yield value if not isinstance(value := ch.recv(0), NoRecvValue):
if ch.is_closed: yield value
pipes.remove(pipe) if ch.is_closed:
pipe_to_chs[pipe].remove(ch)
# 如果该管道已没有活跃的通道,则移除
if not pipe_to_chs[pipe]:
pipes.remove(pipe)

View File

@ -47,13 +47,13 @@ class Chan(Generic[T]):
""" """
self.send_conn.send(value) self.send_conn.send(value)
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue: def recv(self, timeout: float | None = None) -> T | NoRecvValue:
"""Receive a value from the channel. """Receive a value from the channel.
If the timeout is None, it will block until a value is received. If the timeout is None, it will block until a value is received.
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None. If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
接收通道中的值 接收通道中的值
如果超时为None则它将阻塞直到接收到值 如果超时为None则它将阻塞直到接收到值
如果超时是正数则它将等待指定的时间如果没有接收到值则返回None 如果超时是正数则它将等待指定的时间如果没有接收到值则返回NoRecvValue
Args: Args:
timeout: timeout:
The maximum time to wait for a value. The maximum time to wait for a value.
@ -82,7 +82,7 @@ class Chan(Generic[T]):
""" """
return self return self
def __next__(self) -> T: def __next__(self) -> T | NoRecvValue:
return self.recv() return self.recv()
def __lshift__(self, other: T): def __lshift__(self, other: T):
@ -95,7 +95,7 @@ class Chan(Generic[T]):
self.send(other) self.send(other)
return self return self
def __rlshift__(self, other: Any) -> T: def __rlshift__(self, other: Any) -> T | NoRecvValue:
""" """
<< chan << chan
Returns: The value received from the channel. Returns: The value received from the channel.

View File

@ -4,41 +4,35 @@ from multiprocessing import Process
from magicoca import Chan, select from magicoca import Chan, select
def sp1(chan: Chan[int]):
for i in range(10):
chan << i << i * 2
def sp2(chan: Chan[int]):
for i in range(10):
chan << i << i * 3
def rp(chans: list[Chan[int]]):
rl = []
for t in select(*chans):
rl.append(t)
if len(rl) == 40:
break
print(rl)
def send_process(chan: Chan[int], _id: int): def send_process(chan: Chan[int], _id: int):
while True: for i in range(10):
chan << _id chan << i
time.sleep(2) time.sleep(0.1 * _id)
def recv_process(chan_list: list[Chan[int]]): def recv_process(chan_list: list[Chan[int]]):
c = []
for t in select(*chan_list): for t in select(*chan_list):
print(t) c.append(t)
print("Select", t)
if len(c) == 30:
break
class TestSelect: class TestSelect:
def test_select(self): def test_select(self):
chan_list = [] ch1 = Chan[int]()
for i in range(10): ch2 = Chan[int]()
chan = Chan[int]() ch3 = Chan[int]()
chan_list.append(chan)
p = Process(target=send_process, args=(chan, i)) p1 = Process(target=send_process, args=(ch1, 1))
p.start() p2 = Process(target=send_process, args=(ch2, 2))
p = Process(target=recv_process, args=(chan_list,)) p3 = Process(target=send_process, args=(ch3, 3))
p.start() p4 = Process(target=recv_process, args=([ch1, ch2, ch3],))
p1.start()
p2.start()
p3.start()
p4.start()
p1.join()
p2.join()
p3.join()
p4.join()