mirror of
https://github.com/LiteyukiStudio/magicoca.git
synced 2024-11-22 16:47:37 +08:00
✨ 新增select语句
This commit is contained in:
parent
8e94cd996f
commit
4975b19a24
@ -1,5 +1,25 @@
|
||||
from multiprocessing import set_start_method
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
from magicoca.chan import Chan, T
|
||||
from magicoca.chan import Chan, T, NoRecvValue
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
__all__ = [
|
||||
"Chan",
|
||||
"select"
|
||||
]
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
def select(*args: Chan[T]) -> Generator[T, None, None]:
|
||||
"""
|
||||
Return a yield, when a value is received from one of the channels.
|
||||
Args:
|
||||
args: channels
|
||||
"""
|
||||
while True:
|
||||
for ch in args:
|
||||
if ch.is_closed:
|
||||
continue
|
||||
|
||||
if not isinstance(value := ch.recv(0), NoRecvValue):
|
||||
yield value
|
||||
|
@ -6,6 +6,12 @@ from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class NoRecvValue(Exception):
|
||||
"""
|
||||
Exception raised when there is no value to receive.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Chan(Generic[T]):
|
||||
"""
|
||||
@ -31,6 +37,8 @@ class Chan(Generic[T]):
|
||||
"""
|
||||
self.send_conn, self.recv_conn = Pipe()
|
||||
|
||||
self.is_closed = False
|
||||
|
||||
def send(self, value: T):
|
||||
"""
|
||||
Send a value to the channel.
|
||||
@ -39,7 +47,7 @@ class Chan(Generic[T]):
|
||||
"""
|
||||
self.send_conn.send(value)
|
||||
|
||||
def recv(self, timeout: float | None = None) -> T | None:
|
||||
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
|
||||
"""Receive a value from the channel.
|
||||
If the timeout is None, it will block until a value is received.
|
||||
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
|
||||
@ -56,15 +64,17 @@ class Chan(Generic[T]):
|
||||
"""
|
||||
if timeout is not None:
|
||||
if not self.recv_conn.poll(timeout):
|
||||
return None
|
||||
return NoRecvValue("No value to receive.")
|
||||
return self.recv_conn.recv()
|
||||
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the channel. destructor
|
||||
"""
|
||||
self.send_conn.close()
|
||||
self.recv_conn.close()
|
||||
self.is_closed = True
|
||||
|
||||
def __iter__(self) -> "Chan[T]":
|
||||
"""
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from select import select
|
||||
|
||||
from magicoca.chan import Chan
|
||||
from multiprocessing import Process, set_start_method
|
||||
@ -21,6 +22,7 @@ def p2f(chan: Chan[int]):
|
||||
if recv_ans != list(range(10)) + [-1]:
|
||||
raise ValueError("Chan Shift Test Failed")
|
||||
|
||||
|
||||
class TestChan:
|
||||
|
||||
def test_test(self):
|
||||
@ -63,3 +65,6 @@ class TestChan:
|
||||
p2.start()
|
||||
p1.join()
|
||||
p2.join()
|
||||
|
||||
|
||||
|
||||
|
42
tests/test_select.py
Normal file
42
tests/test_select.py
Normal file
@ -0,0 +1,42 @@
|
||||
from multiprocessing import Process
|
||||
|
||||
from magicoca import Chan, select
|
||||
|
||||
|
||||
def sp1(chan: Chan[int]):
|
||||
for i in range(10):
|
||||
chan << i
|
||||
|
||||
|
||||
def sp2(chan: Chan[int]):
|
||||
for i in range(10):
|
||||
chan << i
|
||||
|
||||
|
||||
def rp(chans: list[Chan[int]]):
|
||||
rl = []
|
||||
for t in select(*chans):
|
||||
rl.append(t)
|
||||
if len(rl) == 20:
|
||||
break
|
||||
print(rl)
|
||||
assert len(rl) == 20
|
||||
|
||||
|
||||
class TestSelect:
|
||||
def test_select(self):
|
||||
chan1 = Chan[int]()
|
||||
chan2 = Chan[int]()
|
||||
|
||||
print("Test Chan Select")
|
||||
|
||||
p1 = Process(target=sp1, args=(chan1,))
|
||||
p2 = Process(target=sp2, args=(chan2,))
|
||||
p3 = Process(target=rp, args=([chan1, chan2],))
|
||||
p3.start()
|
||||
p1.start()
|
||||
p2.start()
|
||||
|
||||
p1.join()
|
||||
p2.join()
|
||||
p3.join()
|
Loading…
Reference in New Issue
Block a user