diff --git a/magicoca/__init__.py b/magicoca/__init__.py index ac10b42..d0690cf 100644 --- a/magicoca/__init__.py +++ b/magicoca/__init__.py @@ -1,6 +1,6 @@ from multiprocessing import set_start_method -from multiprocessing.connection import wait -from typing import Any, Callable, Generator +from multiprocessing.connection import wait, Connection +from typing import Generator from magicoca.chan import Chan, T, NoRecvValue @@ -13,18 +13,29 @@ set_start_method("spawn", force=True) def select(*args: Chan[T]) -> Generator[T, None, None]: """ - Return a yield, when a value is received from one of the channels. - Args: - args: channels + 当其中一个通道接收到数据时,yield 该数据。 + + 参数: + args: 多个 Chan 对象 """ - pipes = [ch.recv_conn for ch in args if not ch.is_closed] + # 构造管道到通道列表的映射,避免重复的 recv_conn 对象 + pipe_to_chs: dict[Connection, list[Chan[T]]] = {} + for ch in args: + if not ch.is_closed: + pipe: Connection = ch.recv_conn + pipe_to_chs.setdefault(pipe, []).append(ch) + pipes: list[Connection] = list(pipe_to_chs.keys()) while pipes: - ready_pipes = wait(pipes) + ready_pipes: list[Connection] = wait(pipes) # type: ignore for pipe in ready_pipes: - for ch in args: - if ch.recv_conn == pipe: - if not isinstance(value := ch.recv(0), NoRecvValue): - yield value - if ch.is_closed: - pipes.remove(pipe) \ No newline at end of file + # 遍历所有使用该管道的通道 + channels: list[Chan[T]] = list(pipe_to_chs.get(pipe, [])) + for ch in channels: + if not isinstance(value := ch.recv(0), NoRecvValue): + yield value + if ch.is_closed: + pipe_to_chs[pipe].remove(ch) + # 如果该管道已没有活跃的通道,则移除 + if not pipe_to_chs[pipe]: + pipes.remove(pipe) \ No newline at end of file diff --git a/magicoca/chan.py b/magicoca/chan.py index bb80e78..be7a0b5 100644 --- a/magicoca/chan.py +++ b/magicoca/chan.py @@ -47,13 +47,13 @@ class Chan(Generic[T]): """ self.send_conn.send(value) - def recv(self, timeout: float | None = None) -> T | None | NoRecvValue: + def recv(self, timeout: float | None = None) -> T | NoRecvValue: """Receive a value from the channel. If the timeout is None, it will block until a value is received. If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None. 接收通道中的值。 如果超时为None,则它将阻塞,直到接收到值。 - 如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回None。 + 如果超时是正数,则它将等待指定的时间,如果没有接收到值,则返回NoRecvValue。 Args: timeout: The maximum time to wait for a value. @@ -82,7 +82,7 @@ class Chan(Generic[T]): """ return self - def __next__(self) -> T: + def __next__(self) -> T | NoRecvValue: return self.recv() def __lshift__(self, other: T): @@ -95,7 +95,7 @@ class Chan(Generic[T]): self.send(other) return self - def __rlshift__(self, other: Any) -> T: + def __rlshift__(self, other: Any) -> T | NoRecvValue: """ << chan Returns: The value received from the channel. diff --git a/tests/test_select.py b/tests/test_select.py index 8c0513c..7f65d06 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -4,41 +4,35 @@ from multiprocessing import Process from magicoca import Chan, select -def sp1(chan: Chan[int]): - for i in range(10): - chan << i << i * 2 - - -def sp2(chan: Chan[int]): - for i in range(10): - chan << i << i * 3 - - -def rp(chans: list[Chan[int]]): - rl = [] - for t in select(*chans): - rl.append(t) - if len(rl) == 40: - break - print(rl) - def send_process(chan: Chan[int], _id: int): - while True: - chan << _id - time.sleep(2) + for i in range(10): + chan << i + time.sleep(0.1 * _id) def recv_process(chan_list: list[Chan[int]]): + c = [] for t in select(*chan_list): - print(t) - - + c.append(t) + print("Select", t) + if len(c) == 30: + break class TestSelect: def test_select(self): - chan_list = [] - for i in range(10): - chan = Chan[int]() - chan_list.append(chan) - p = Process(target=send_process, args=(chan, i)) - p.start() - p = Process(target=recv_process, args=(chan_list,)) - p.start() + ch1 = Chan[int]() + ch2 = Chan[int]() + ch3 = Chan[int]() + + p1 = Process(target=send_process, args=(ch1, 1)) + p2 = Process(target=send_process, args=(ch2, 2)) + p3 = Process(target=send_process, args=(ch3, 3)) + p4 = Process(target=recv_process, args=([ch1, ch2, ch3],)) + + p1.start() + p2.start() + p3.start() + p4.start() + + p1.join() + p2.join() + p3.join() + p4.join()