Skip to main content

rustyclaw_core/gateway/
concurrent.rs

//! Concurrent model execution support.
//!
//! This module provides infrastructure for running multiple model requests
//! concurrently across different conversation threads, allowing the TUI to
//! remain responsive while models are generating responses.
//!
//! Architecture:
//! - Each model request runs in its own spawned task
//! - Tasks send frames back via an mpsc channel
//! - The main loop selects between client messages and model responses
//! - Thread switching is allowed while models are running
12
13use crate::threads::ThreadId;
14use futures_util::Sink;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::sync::mpsc;
18use tokio_tungstenite::tungstenite::Message;
19
20/// A message from a spawned model task back to the main connection handler.
21#[derive(Debug, Clone)]
22pub enum ModelTaskMessage {
23    /// A raw WebSocket message to send to the client
24    RawMessage(Message),
25    
26    /// The model task completed successfully.
27    /// The main loop should update thread state.
28    Done {
29        thread_id: ThreadId,
30        /// Final assistant response text to add to thread history
31        response: Option<String>,
32    },
33    
34    /// The model task failed with an error
35    Error {
36        thread_id: ThreadId,
37        message: String,
38    },
39}
40
41/// Sender for model task messages.
42pub type ModelTaskTx = mpsc::Sender<ModelTaskMessage>;
43
44/// Receiver for model task messages.
45pub type ModelTaskRx = mpsc::Receiver<ModelTaskMessage>;
46
47/// Create a new model task channel.
48pub fn channel() -> (ModelTaskTx, ModelTaskRx) {
49    mpsc::channel(256)
50}
51
52/// A sink that sends WebSocket messages through a channel.
53/// 
54/// This implements `Sink<Message>` so it can be used with `send_frame` and other
55/// functions that expect a WebSocket writer.
56pub struct ChannelSink {
57    tx: ModelTaskTx,
58    thread_id: ThreadId,
59}
60
61impl ChannelSink {
62    pub fn new(tx: ModelTaskTx, thread_id: ThreadId) -> Self {
63        Self { tx, thread_id }
64    }
65    
66    /// Signal that the task completed successfully.
67    pub async fn done(&self, response: Option<String>) {
68        let _ = self.tx.send(ModelTaskMessage::Done {
69            thread_id: self.thread_id,
70            response,
71        }).await;
72    }
73    
74    /// Signal that the task failed.
75    pub async fn error(&self, message: String) {
76        let _ = self.tx.send(ModelTaskMessage::Error {
77            thread_id: self.thread_id,
78            message,
79        }).await;
80    }
81}
82
83impl Sink<Message> for ChannelSink {
84    type Error = mpsc::error::SendError<ModelTaskMessage>;
85
86    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        // Channel is always ready (bounded but non-blocking poll)
88        Poll::Ready(Ok(()))
89    }
90
91    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
92        // Use try_send to avoid blocking
93        self.tx.try_send(ModelTaskMessage::RawMessage(item))
94            .map_err(|e| mpsc::error::SendError(match e {
95                mpsc::error::TrySendError::Full(m) => m,
96                mpsc::error::TrySendError::Closed(m) => m,
97            }))
98    }
99
100    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101        Poll::Ready(Ok(()))
102    }
103
104    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        Poll::Ready(Ok(()))
106    }
107}
108
109/// Tracks active model tasks per thread.
110#[derive(Debug, Default)]
111pub struct ActiveTasks {
112    /// Map of thread ID to task handle (for cancellation)
113    tasks: std::collections::HashMap<ThreadId, tokio::task::JoinHandle<()>>,
114}
115
116impl ActiveTasks {
117    pub fn new() -> Self {
118        Self::default()
119    }
120    
121    /// Register a new task for a thread.
122    /// If there's already a task for this thread, it will be aborted.
123    pub fn register(&mut self, thread_id: ThreadId, handle: tokio::task::JoinHandle<()>) {
124        if let Some(old_handle) = self.tasks.insert(thread_id, handle) {
125            old_handle.abort();
126        }
127    }
128    
129    /// Remove a task when it completes.
130    pub fn remove(&mut self, thread_id: &ThreadId) {
131        self.tasks.remove(thread_id);
132    }
133    
134    /// Cancel a task for a specific thread.
135    pub fn cancel(&mut self, thread_id: &ThreadId) -> bool {
136        if let Some(handle) = self.tasks.remove(thread_id) {
137            handle.abort();
138            true
139        } else {
140            false
141        }
142    }
143    
144    /// Check if a thread has an active task.
145    pub fn is_running(&self, thread_id: &ThreadId) -> bool {
146        self.tasks.contains_key(thread_id)
147    }
148    
149    /// Get IDs of all threads with active tasks.
150    pub fn running_threads(&self) -> Vec<ThreadId> {
151        self.tasks.keys().copied().collect()
152    }
153}