mirror of
https://github.com/LiteyukiStudio/magicoca.git
synced 2024-11-26 10:35:04 +08:00
✨ 新增select语句
This commit is contained in:
parent
8e94cd996f
commit
4975b19a24
@ -1,5 +1,25 @@
|
|||||||
from multiprocessing import set_start_method
|
from multiprocessing import set_start_method
|
||||||
|
from typing import Any, Callable, Generator
|
||||||
|
|
||||||
from magicoca.chan import Chan, T
|
from magicoca.chan import Chan, T, NoRecvValue
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Chan",
|
||||||
|
"select"
|
||||||
|
]
|
||||||
|
|
||||||
set_start_method("spawn", force=True)
|
set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
def select(*args: Chan[T]) -> Generator[T, None, None]:
|
||||||
|
"""
|
||||||
|
Return a yield, when a value is received from one of the channels.
|
||||||
|
Args:
|
||||||
|
args: channels
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
for ch in args:
|
||||||
|
if ch.is_closed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(value := ch.recv(0), NoRecvValue):
|
||||||
|
yield value
|
||||||
|
@ -6,6 +6,12 @@ from typing import TypeVar, Generic, Any
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
class NoRecvValue(Exception):
|
||||||
|
"""
|
||||||
|
Exception raised when there is no value to receive.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Chan(Generic[T]):
|
class Chan(Generic[T]):
|
||||||
"""
|
"""
|
||||||
@ -31,6 +37,8 @@ class Chan(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
self.send_conn, self.recv_conn = Pipe()
|
self.send_conn, self.recv_conn = Pipe()
|
||||||
|
|
||||||
|
self.is_closed = False
|
||||||
|
|
||||||
def send(self, value: T):
|
def send(self, value: T):
|
||||||
"""
|
"""
|
||||||
Send a value to the channel.
|
Send a value to the channel.
|
||||||
@ -39,7 +47,7 @@ class Chan(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
self.send_conn.send(value)
|
self.send_conn.send(value)
|
||||||
|
|
||||||
def recv(self, timeout: float | None = None) -> T | None:
|
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
|
||||||
"""Receive a value from the channel.
|
"""Receive a value from the channel.
|
||||||
If the timeout is None, it will block until a value is received.
|
If the timeout is None, it will block until a value is received.
|
||||||
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
||||||
@ -56,15 +64,17 @@ class Chan(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
if not self.recv_conn.poll(timeout):
|
if not self.recv_conn.poll(timeout):
|
||||||
return None
|
return NoRecvValue("No value to receive.")
|
||||||
return self.recv_conn.recv()
|
return self.recv_conn.recv()
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
Close the channel. destructor
|
Close the channel. destructor
|
||||||
"""
|
"""
|
||||||
self.send_conn.close()
|
self.send_conn.close()
|
||||||
self.recv_conn.close()
|
self.recv_conn.close()
|
||||||
|
self.is_closed = True
|
||||||
|
|
||||||
def __iter__(self) -> "Chan[T]":
|
def __iter__(self) -> "Chan[T]":
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
from select import select
|
||||||
|
|
||||||
from magicoca.chan import Chan
|
from magicoca.chan import Chan
|
||||||
from multiprocessing import Process, set_start_method
|
from multiprocessing import Process, set_start_method
|
||||||
@ -21,6 +22,7 @@ def p2f(chan: Chan[int]):
|
|||||||
if recv_ans != list(range(10)) + [-1]:
|
if recv_ans != list(range(10)) + [-1]:
|
||||||
raise ValueError("Chan Shift Test Failed")
|
raise ValueError("Chan Shift Test Failed")
|
||||||
|
|
||||||
|
|
||||||
class TestChan:
|
class TestChan:
|
||||||
|
|
||||||
def test_test(self):
|
def test_test(self):
|
||||||
@ -63,3 +65,6 @@ class TestChan:
|
|||||||
p2.start()
|
p2.start()
|
||||||
p1.join()
|
p1.join()
|
||||||
p2.join()
|
p2.join()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
42
tests/test_select.py
Normal file
42
tests/test_select.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
from magicoca import Chan, select
|
||||||
|
|
||||||
|
|
||||||
|
def sp1(chan: Chan[int]):
|
||||||
|
for i in range(10):
|
||||||
|
chan << i
|
||||||
|
|
||||||
|
|
||||||
|
def sp2(chan: Chan[int]):
|
||||||
|
for i in range(10):
|
||||||
|
chan << i
|
||||||
|
|
||||||
|
|
||||||
|
def rp(chans: list[Chan[int]]):
|
||||||
|
rl = []
|
||||||
|
for t in select(*chans):
|
||||||
|
rl.append(t)
|
||||||
|
if len(rl) == 20:
|
||||||
|
break
|
||||||
|
print(rl)
|
||||||
|
assert len(rl) == 20
|
||||||
|
|
||||||
|
|
||||||
|
class TestSelect:
|
||||||
|
def test_select(self):
|
||||||
|
chan1 = Chan[int]()
|
||||||
|
chan2 = Chan[int]()
|
||||||
|
|
||||||
|
print("Test Chan Select")
|
||||||
|
|
||||||
|
p1 = Process(target=sp1, args=(chan1,))
|
||||||
|
p2 = Process(target=sp2, args=(chan2,))
|
||||||
|
p3 = Process(target=rp, args=([chan1, chan2],))
|
||||||
|
p3.start()
|
||||||
|
p1.start()
|
||||||
|
p2.start()
|
||||||
|
|
||||||
|
p1.join()
|
||||||
|
p2.join()
|
||||||
|
p3.join()
|
Loading…
Reference in New Issue
Block a user