rustyclaw_core/gateway/
concurrent.rs1use crate::threads::ThreadId;
14use futures_util::Sink;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::sync::mpsc;
18use tokio_tungstenite::tungstenite::Message;
19
20#[derive(Debug, Clone)]
22pub enum ModelTaskMessage {
23 RawMessage(Message),
25
26 Done {
29 thread_id: ThreadId,
30 response: Option<String>,
32 },
33
34 Error {
36 thread_id: ThreadId,
37 message: String,
38 },
39}
40
41pub type ModelTaskTx = mpsc::Sender<ModelTaskMessage>;
43
44pub type ModelTaskRx = mpsc::Receiver<ModelTaskMessage>;
46
47pub fn channel() -> (ModelTaskTx, ModelTaskRx) {
49 mpsc::channel(256)
50}
51
52pub struct ChannelSink {
57 tx: ModelTaskTx,
58 thread_id: ThreadId,
59}
60
61impl ChannelSink {
62 pub fn new(tx: ModelTaskTx, thread_id: ThreadId) -> Self {
63 Self { tx, thread_id }
64 }
65
66 pub async fn done(&self, response: Option<String>) {
68 let _ = self.tx.send(ModelTaskMessage::Done {
69 thread_id: self.thread_id,
70 response,
71 }).await;
72 }
73
74 pub async fn error(&self, message: String) {
76 let _ = self.tx.send(ModelTaskMessage::Error {
77 thread_id: self.thread_id,
78 message,
79 }).await;
80 }
81}
82
83impl Sink<Message> for ChannelSink {
84 type Error = mpsc::error::SendError<ModelTaskMessage>;
85
86 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 Poll::Ready(Ok(()))
89 }
90
91 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
92 self.tx.try_send(ModelTaskMessage::RawMessage(item))
94 .map_err(|e| mpsc::error::SendError(match e {
95 mpsc::error::TrySendError::Full(m) => m,
96 mpsc::error::TrySendError::Closed(m) => m,
97 }))
98 }
99
100 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 Poll::Ready(Ok(()))
102 }
103
104 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 Poll::Ready(Ok(()))
106 }
107}
108
109#[derive(Debug, Default)]
111pub struct ActiveTasks {
112 tasks: std::collections::HashMap<ThreadId, tokio::task::JoinHandle<()>>,
114}
115
116impl ActiveTasks {
117 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn register(&mut self, thread_id: ThreadId, handle: tokio::task::JoinHandle<()>) {
124 if let Some(old_handle) = self.tasks.insert(thread_id, handle) {
125 old_handle.abort();
126 }
127 }
128
129 pub fn remove(&mut self, thread_id: &ThreadId) {
131 self.tasks.remove(thread_id);
132 }
133
134 pub fn cancel(&mut self, thread_id: &ThreadId) -> bool {
136 if let Some(handle) = self.tasks.remove(thread_id) {
137 handle.abort();
138 true
139 } else {
140 false
141 }
142 }
143
144 pub fn is_running(&self, thread_id: &ThreadId) -> bool {
146 self.tasks.contains_key(thread_id)
147 }
148
149 pub fn running_threads(&self) -> Vec<ThreadId> {
151 self.tasks.keys().copied().collect()
152 }
153}