mirror of
https://github.com/LiteyukiStudio/magicoca.git
synced 2025-02-22 18:45:20 +08:00
⚡ 优化select函数,减少管道重复映射,改进recv方法返回值类型
This commit is contained in:
parent
33db0ffc02
commit
7c638615ba
@ -1,6 +1,6 @@
|
|||||||
from multiprocessing import set_start_method
|
from multiprocessing import set_start_method
|
||||||
from multiprocessing.connection import wait
|
from multiprocessing.connection import wait, Connection
|
||||||
from typing import Any, Callable, Generator
|
from typing import Generator
|
||||||
|
|
||||||
from magicoca.chan import Chan, T, NoRecvValue
|
from magicoca.chan import Chan, T, NoRecvValue
|
||||||
|
|
||||||
@ -13,18 +13,29 @@ set_start_method("spawn", force=True)
|
|||||||
|
|
||||||
def select(*args: Chan[T]) -> Generator[T, None, None]:
|
def select(*args: Chan[T]) -> Generator[T, None, None]:
|
||||||
"""
|
"""
|
||||||
Return a yield, when a value is received from one of the channels.
|
当其中一个通道接收到数据时,yield 该数据。
|
||||||
Args:
|
|
||||||
args: channels
|
参数:
|
||||||
|
args: 多个 Chan 对象
|
||||||
"""
|
"""
|
||||||
pipes = [ch.recv_conn for ch in args if not ch.is_closed]
|
# 构造管道到通道列表的映射,避免重复的 recv_conn 对象
|
||||||
|
pipe_to_chs: dict[Connection, list[Chan[T]]] = {}
|
||||||
|
for ch in args:
|
||||||
|
if not ch.is_closed:
|
||||||
|
pipe: Connection = ch.recv_conn
|
||||||
|
pipe_to_chs.setdefault(pipe, []).append(ch)
|
||||||
|
pipes: list[Connection] = list(pipe_to_chs.keys())
|
||||||
|
|
||||||
while pipes:
|
while pipes:
|
||||||
ready_pipes = wait(pipes)
|
ready_pipes: list[Connection] = wait(pipes) # type: ignore
|
||||||
for pipe in ready_pipes:
|
for pipe in ready_pipes:
|
||||||
for ch in args:
|
# 遍历所有使用该管道的通道
|
||||||
if ch.recv_conn == pipe:
|
channels: list[Chan[T]] = list(pipe_to_chs.get(pipe, []))
|
||||||
if not isinstance(value := ch.recv(0), NoRecvValue):
|
for ch in channels:
|
||||||
yield value
|
if not isinstance(value := ch.recv(0), NoRecvValue):
|
||||||
if ch.is_closed:
|
yield value
|
||||||
pipes.remove(pipe)
|
if ch.is_closed:
|
||||||
|
pipe_to_chs[pipe].remove(ch)
|
||||||
|
# 如果该管道已没有活跃的通道,则移除
|
||||||
|
if not pipe_to_chs[pipe]:
|
||||||
|
pipes.remove(pipe)
|
@ -47,13 +47,13 @@ class Chan(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
self.send_conn.send(value)
|
self.send_conn.send(value)
|
||||||
|
|
||||||
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
|
def recv(self, timeout: float | None = None) -> T | NoRecvValue:
|
||||||
"""Receive a value from the channel.
|
"""Receive a value from the channel.
|
||||||
If the timeout is None, it will block until a value is received.
|
If the timeout is None, it will block until a value is received.
|
||||||
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
||||||
接收通道中的值。
|
接收通道中的值。
|
||||||
如果超时为None,则它将阻塞,直到接收到值。
|
如果超时为None,则它将阻塞,直到接收到值。
|
||||||
如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回None。
|
如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回NoRecvValue。
|
||||||
Args:
|
Args:
|
||||||
timeout:
|
timeout:
|
||||||
The maximum time to wait for a value.
|
The maximum time to wait for a value.
|
||||||
@ -82,7 +82,7 @@ class Chan(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self) -> T:
|
def __next__(self) -> T | NoRecvValue:
|
||||||
return self.recv()
|
return self.recv()
|
||||||
|
|
||||||
def __lshift__(self, other: T):
|
def __lshift__(self, other: T):
|
||||||
@ -95,7 +95,7 @@ class Chan(Generic[T]):
|
|||||||
self.send(other)
|
self.send(other)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __rlshift__(self, other: Any) -> T:
|
def __rlshift__(self, other: Any) -> T | NoRecvValue:
|
||||||
"""
|
"""
|
||||||
<< chan
|
<< chan
|
||||||
Returns: The value received from the channel.
|
Returns: The value received from the channel.
|
||||||
|
@ -4,41 +4,35 @@ from multiprocessing import Process
|
|||||||
from magicoca import Chan, select
|
from magicoca import Chan, select
|
||||||
|
|
||||||
|
|
||||||
def sp1(chan: Chan[int]):
|
|
||||||
for i in range(10):
|
|
||||||
chan << i << i * 2
|
|
||||||
|
|
||||||
|
|
||||||
def sp2(chan: Chan[int]):
|
|
||||||
for i in range(10):
|
|
||||||
chan << i << i * 3
|
|
||||||
|
|
||||||
|
|
||||||
def rp(chans: list[Chan[int]]):
|
|
||||||
rl = []
|
|
||||||
for t in select(*chans):
|
|
||||||
rl.append(t)
|
|
||||||
if len(rl) == 40:
|
|
||||||
break
|
|
||||||
print(rl)
|
|
||||||
|
|
||||||
def send_process(chan: Chan[int], _id: int):
|
def send_process(chan: Chan[int], _id: int):
|
||||||
while True:
|
for i in range(10):
|
||||||
chan << _id
|
chan << i
|
||||||
time.sleep(2)
|
time.sleep(0.1 * _id)
|
||||||
|
|
||||||
def recv_process(chan_list: list[Chan[int]]):
|
def recv_process(chan_list: list[Chan[int]]):
|
||||||
|
c = []
|
||||||
for t in select(*chan_list):
|
for t in select(*chan_list):
|
||||||
print(t)
|
c.append(t)
|
||||||
|
print("Select", t)
|
||||||
|
if len(c) == 30:
|
||||||
|
break
|
||||||
class TestSelect:
|
class TestSelect:
|
||||||
def test_select(self):
|
def test_select(self):
|
||||||
chan_list = []
|
ch1 = Chan[int]()
|
||||||
for i in range(10):
|
ch2 = Chan[int]()
|
||||||
chan = Chan[int]()
|
ch3 = Chan[int]()
|
||||||
chan_list.append(chan)
|
|
||||||
p = Process(target=send_process, args=(chan, i))
|
p1 = Process(target=send_process, args=(ch1, 1))
|
||||||
p.start()
|
p2 = Process(target=send_process, args=(ch2, 2))
|
||||||
p = Process(target=recv_process, args=(chan_list,))
|
p3 = Process(target=send_process, args=(ch3, 3))
|
||||||
p.start()
|
p4 = Process(target=recv_process, args=([ch1, ch2, ch3],))
|
||||||
|
|
||||||
|
p1.start()
|
||||||
|
p2.start()
|
||||||
|
p3.start()
|
||||||
|
p4.start()
|
||||||
|
|
||||||
|
p1.join()
|
||||||
|
p2.join()
|
||||||
|
p3.join()
|
||||||
|
p4.join()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user