diff --git a/magicoca/__init__.py b/magicoca/__init__.py index 8af2fa5..35a5681 100644 --- a/magicoca/__init__.py +++ b/magicoca/__init__.py @@ -1,5 +1,25 @@ from multiprocessing import set_start_method +from typing import Any, Callable, Generator -from magicoca.chan import Chan, T +from magicoca.chan import Chan, T, NoRecvValue -set_start_method("spawn", force=True) \ No newline at end of file +__all__ = [ + "Chan", + "select" +] + +set_start_method("spawn", force=True) + +def select(*args: Chan[T]) -> Generator[T, None, None]: + """ + Return a yield, when a value is received from one of the channels. + Args: + args: channels + """ + while True: + for ch in args: + if ch.is_closed: + continue + + if not isinstance(value := ch.recv(0), NoRecvValue): + yield value diff --git a/magicoca/chan.py b/magicoca/chan.py index cfcc273..b5b9307 100644 --- a/magicoca/chan.py +++ b/magicoca/chan.py @@ -6,6 +6,12 @@ from typing import TypeVar, Generic, Any T = TypeVar("T") +class NoRecvValue(Exception): + """ + Exception raised when there is no value to receive. + """ + pass + class Chan(Generic[T]): """ @@ -31,6 +37,8 @@ class Chan(Generic[T]): """ self.send_conn, self.recv_conn = Pipe() + self.is_closed = False + def send(self, value: T): """ Send a value to the channel. @@ -39,7 +47,7 @@ class Chan(Generic[T]): """ self.send_conn.send(value) - def recv(self, timeout: float | None = None) -> T | None: + def recv(self, timeout: float | None = None) -> T | None | NoRecvValue: """Receive a value from the channel. If the timeout is None, it will block until a value is received. If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None. @@ -56,15 +64,17 @@ class Chan(Generic[T]): """ if timeout is not None: if not self.recv_conn.poll(timeout): - return None + return NoRecvValue("No value to receive.") return self.recv_conn.recv() + def close(self): """ Close the channel. destructor """ self.send_conn.close() self.recv_conn.close() + self.is_closed = True def __iter__(self) -> "Chan[T]": """ diff --git a/tests/test_chan.py b/tests/test_chan.py index 96fe8e8..fd30cc2 100644 --- a/tests/test_chan.py +++ b/tests/test_chan.py @@ -1,4 +1,5 @@ import time +from select import select from magicoca.chan import Chan from multiprocessing import Process, set_start_method @@ -21,6 +22,7 @@ def p2f(chan: Chan[int]): if recv_ans != list(range(10)) + [-1]: raise ValueError("Chan Shift Test Failed") + class TestChan: def test_test(self): @@ -63,3 +65,6 @@ class TestChan: p2.start() p1.join() p2.join() + + + diff --git a/tests/test_select.py b/tests/test_select.py new file mode 100644 index 0000000..3fbb0e6 --- /dev/null +++ b/tests/test_select.py @@ -0,0 +1,42 @@ +from multiprocessing import Process + +from magicoca import Chan, select + + +def sp1(chan: Chan[int]): + for i in range(10): + chan << i + + +def sp2(chan: Chan[int]): + for i in range(10): + chan << i + + +def rp(chans: list[Chan[int]]): + rl = [] + for t in select(*chans): + rl.append(t) + if len(rl) == 20: + break + print(rl) + assert len(rl) == 20 + + +class TestSelect: + def test_select(self): + chan1 = Chan[int]() + chan2 = Chan[int]() + + print("Test Chan Select") + + p1 = Process(target=sp1, args=(chan1,)) + p2 = Process(target=sp2, args=(chan2,)) + p3 = Process(target=rp, args=([chan1, chan2],)) + p3.start() + p1.start() + p2.start() + + p1.join() + p2.join() + p3.join()