turul_mcp_client/
streaming.rs

1//! Streaming support for MCP client
2
3use serde_json::Value;
4use std::sync::Arc;
5use tokio::sync::mpsc;
6use tracing::{debug, info, warn};
7
8use crate::error::{McpClientError, McpClientResult};
9use crate::transport::ServerEvent;
10
11/// Stream handler for processing server events
12#[derive(Debug)]
13pub struct StreamHandler {
14    /// Event receiver from transport
15    event_receiver: Option<mpsc::UnboundedReceiver<ServerEvent>>,
16    /// Event callbacks
17    callbacks: Arc<parking_lot::Mutex<StreamCallbacks>>,
18}
19
20/// Type alias for request handler callback
21type RequestHandler = Box<dyn Fn(Value) -> Result<Value, String> + Send + Sync>;
22
23/// Callbacks for different types of server events
24#[derive(Default)]
25pub struct StreamCallbacks {
26    /// Notification callback
27    pub notification: Option<Box<dyn Fn(Value) + Send + Sync>>,
28    /// Request callback (server asking client)
29    pub request: Option<RequestHandler>,
30    /// Connection lost callback
31    pub connection_lost: Option<Box<dyn Fn() + Send + Sync>>,
32    /// Error callback
33    pub error: Option<Box<dyn Fn(String) + Send + Sync>>,
34    /// Heartbeat callback
35    pub heartbeat: Option<Box<dyn Fn() + Send + Sync>>,
36}
37
38impl std::fmt::Debug for StreamCallbacks {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("StreamCallbacks")
41            .field(
42                "notification",
43                &self.notification.as_ref().map(|_| "function"),
44            )
45            .field("request", &self.request.as_ref().map(|_| "function"))
46            .field(
47                "connection_lost",
48                &self.connection_lost.as_ref().map(|_| "function"),
49            )
50            .field("error", &self.error.as_ref().map(|_| "function"))
51            .field("heartbeat", &self.heartbeat.as_ref().map(|_| "function"))
52            .finish()
53    }
54}
55
56impl StreamHandler {
57    /// Create a new stream handler
58    pub fn new() -> Self {
59        Self {
60            event_receiver: None,
61            callbacks: Arc::new(parking_lot::Mutex::new(StreamCallbacks::default())),
62        }
63    }
64
65    /// Set event receiver from transport
66    pub fn set_receiver(&mut self, receiver: mpsc::UnboundedReceiver<ServerEvent>) {
67        self.event_receiver = Some(receiver);
68    }
69
70    /// Set notification callback
71    pub fn on_notification<F>(&self, callback: F)
72    where
73        F: Fn(Value) + Send + Sync + 'static,
74    {
75        self.callbacks.lock().notification = Some(Box::new(callback));
76    }
77
78    /// Set request callback
79    pub fn on_request<F>(&self, callback: F)
80    where
81        F: Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
82    {
83        self.callbacks.lock().request = Some(Box::new(callback));
84    }
85
86    /// Set connection lost callback
87    pub fn on_connection_lost<F>(&self, callback: F)
88    where
89        F: Fn() + Send + Sync + 'static,
90    {
91        self.callbacks.lock().connection_lost = Some(Box::new(callback));
92    }
93
94    /// Set error callback
95    pub fn on_error<F>(&self, callback: F)
96    where
97        F: Fn(String) + Send + Sync + 'static,
98    {
99        self.callbacks.lock().error = Some(Box::new(callback));
100    }
101
102    /// Set heartbeat callback
103    pub fn on_heartbeat<F>(&self, callback: F)
104    where
105        F: Fn() + Send + Sync + 'static,
106    {
107        self.callbacks.lock().heartbeat = Some(Box::new(callback));
108    }
109
110    /// Start processing events
111    pub async fn start(&mut self) -> McpClientResult<()> {
112        let mut receiver = self
113            .event_receiver
114            .take()
115            .ok_or_else(|| McpClientError::generic("No event receiver configured"))?;
116
117        let callbacks = Arc::clone(&self.callbacks);
118
119        tokio::spawn(async move {
120            info!("Stream handler started");
121
122            while let Some(event) = receiver.recv().await {
123                debug!(event = ?event, "Received server event");
124
125                let callbacks = callbacks.lock();
126
127                match event {
128                    ServerEvent::Notification(notification) => {
129                        if let Some(ref callback) = callbacks.notification {
130                            callback(notification);
131                        }
132                    }
133                    ServerEvent::Request(request) => {
134                        if let Some(ref callback) = callbacks.request {
135                            match callback(request) {
136                                Ok(_response) => {
137                                    debug!("Request handled successfully");
138                                    // TODO: Send response back to server
139                                }
140                                Err(error) => {
141                                    warn!(error = %error, "Request handler returned error");
142                                    // TODO: Send error response back to server
143                                }
144                            }
145                        } else {
146                            warn!("Received server request but no request handler configured");
147                        }
148                    }
149                    ServerEvent::ConnectionLost => {
150                        warn!("Connection lost");
151                        if let Some(ref callback) = callbacks.connection_lost {
152                            callback();
153                        }
154                    }
155                    ServerEvent::Error(error) => {
156                        warn!(error = %error, "Server error");
157                        if let Some(ref callback) = callbacks.error {
158                            callback(error);
159                        }
160                    }
161                    ServerEvent::Heartbeat => {
162                        debug!("Heartbeat received");
163                        if let Some(ref callback) = callbacks.heartbeat {
164                            callback();
165                        }
166                    }
167                }
168            }
169
170            info!("Stream handler stopped");
171        });
172
173        Ok(())
174    }
175
176    /// Check if handler is active
177    pub fn is_active(&self) -> bool {
178        self.event_receiver.is_some()
179    }
180}
181
182impl Default for StreamHandler {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188/// Progress tracker for long-running operations
189#[derive(Debug, Clone)]
190pub struct ProgressTracker {
191    /// Operation ID
192    pub operation_id: String,
193    /// Total steps (if known)
194    pub total: Option<u64>,
195    /// Completed steps
196    pub completed: u64,
197    /// Progress message
198    pub message: Option<String>,
199    /// Progress metadata
200    pub metadata: Value,
201}
202
203impl ProgressTracker {
204    /// Create a new progress tracker
205    pub fn new(operation_id: String) -> Self {
206        Self {
207            operation_id,
208            total: None,
209            completed: 0,
210            message: None,
211            metadata: Value::Null,
212        }
213    }
214
215    /// Update progress
216    pub fn update(&mut self, completed: u64, message: Option<String>) {
217        self.completed = completed;
218        self.message = message;
219    }
220
221    /// Set total steps
222    pub fn set_total(&mut self, total: u64) {
223        self.total = Some(total);
224    }
225
226    /// Get progress percentage (0.0 to 1.0)
227    pub fn percentage(&self) -> Option<f64> {
228        self.total.map(|total| {
229            if total == 0 {
230                1.0
231            } else {
232                (self.completed as f64) / (total as f64)
233            }
234        })
235    }
236
237    /// Check if operation is complete
238    pub fn is_complete(&self) -> bool {
239        if let Some(total) = self.total {
240            self.completed >= total
241        } else {
242            false
243        }
244    }
245
246    /// Get status summary
247    pub fn status(&self) -> String {
248        match (self.total, &self.message) {
249            (Some(total), Some(msg)) => {
250                format!(
251                    "{}/{} ({}%) - {}",
252                    self.completed,
253                    total,
254                    (self.percentage().unwrap_or(0.0) * 100.0) as u32,
255                    msg
256                )
257            }
258            (Some(total), None) => {
259                format!(
260                    "{}/{} ({}%)",
261                    self.completed,
262                    total,
263                    (self.percentage().unwrap_or(0.0) * 100.0) as u32
264                )
265            }
266            (None, Some(msg)) => {
267                format!("{} steps - {}", self.completed, msg)
268            }
269            (None, None) => {
270                format!("{} steps", self.completed)
271            }
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_tracker() {
        let mut tracker = ProgressTracker::new("test-op".to_string());

        assert_eq!(tracker.completed, 0);
        assert_eq!(tracker.percentage(), None);

        tracker.set_total(100);
        assert_eq!(tracker.percentage(), Some(0.0));

        tracker.update(50, Some("halfway".to_string()));
        assert_eq!(tracker.percentage(), Some(0.5));
        assert_eq!(tracker.message, Some("halfway".to_string()));

        tracker.update(100, Some("complete".to_string()));
        assert_eq!(tracker.percentage(), Some(1.0));
        assert!(tracker.is_complete());
    }

    #[tokio::test]
    async fn test_stream_handler_callbacks() {
        let mut handler = StreamHandler::new();

        let notification_received = Arc::new(parking_lot::Mutex::new(false));
        let notification_received_clone = Arc::clone(&notification_received);

        handler.on_notification(move |_| {
            *notification_received_clone.lock() = true;
        });

        // Registration populates the callback slot immediately.
        assert!(handler.callbacks.lock().notification.is_some());

        // Push a notification through the real event loop and confirm the callback fires.
        let (tx, rx) = mpsc::unbounded_channel();
        handler.set_receiver(rx);
        assert!(handler.start().await.is_ok());
        assert!(tx.send(ServerEvent::Notification(Value::Null)).is_ok());

        // The processing task shares this test's runtime, so yield until it has run.
        // The bound keeps a regression from hanging the test.
        for _ in 0..100 {
            if *notification_received.lock() {
                break;
            }
            tokio::task::yield_now().await;
        }
        assert!(*notification_received.lock());
    }
}