turul_mcp_client/
streaming.rs1use serde_json::Value;
4use std::sync::Arc;
5use tokio::sync::mpsc;
6use tracing::{debug, info, warn};
7
8use crate::error::{McpClientError, McpClientResult};
9use crate::transport::ServerEvent;
10
11#[derive(Debug)]
13pub struct StreamHandler {
14 event_receiver: Option<mpsc::UnboundedReceiver<ServerEvent>>,
16 callbacks: Arc<parking_lot::Mutex<StreamCallbacks>>,
18}
19
20type RequestHandler = Box<dyn Fn(Value) -> Result<Value, String> + Send + Sync>;
22
23#[derive(Default)]
25pub struct StreamCallbacks {
26 pub notification: Option<Box<dyn Fn(Value) + Send + Sync>>,
28 pub request: Option<RequestHandler>,
30 pub connection_lost: Option<Box<dyn Fn() + Send + Sync>>,
32 pub error: Option<Box<dyn Fn(String) + Send + Sync>>,
34 pub heartbeat: Option<Box<dyn Fn() + Send + Sync>>,
36}
37
38impl std::fmt::Debug for StreamCallbacks {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("StreamCallbacks")
41 .field(
42 "notification",
43 &self.notification.as_ref().map(|_| "function"),
44 )
45 .field("request", &self.request.as_ref().map(|_| "function"))
46 .field(
47 "connection_lost",
48 &self.connection_lost.as_ref().map(|_| "function"),
49 )
50 .field("error", &self.error.as_ref().map(|_| "function"))
51 .field("heartbeat", &self.heartbeat.as_ref().map(|_| "function"))
52 .finish()
53 }
54}
55
56impl StreamHandler {
57 pub fn new() -> Self {
59 Self {
60 event_receiver: None,
61 callbacks: Arc::new(parking_lot::Mutex::new(StreamCallbacks::default())),
62 }
63 }
64
65 pub fn set_receiver(&mut self, receiver: mpsc::UnboundedReceiver<ServerEvent>) {
67 self.event_receiver = Some(receiver);
68 }
69
70 pub fn on_notification<F>(&self, callback: F)
72 where
73 F: Fn(Value) + Send + Sync + 'static,
74 {
75 self.callbacks.lock().notification = Some(Box::new(callback));
76 }
77
78 pub fn on_request<F>(&self, callback: F)
80 where
81 F: Fn(Value) -> Result<Value, String> + Send + Sync + 'static,
82 {
83 self.callbacks.lock().request = Some(Box::new(callback));
84 }
85
86 pub fn on_connection_lost<F>(&self, callback: F)
88 where
89 F: Fn() + Send + Sync + 'static,
90 {
91 self.callbacks.lock().connection_lost = Some(Box::new(callback));
92 }
93
94 pub fn on_error<F>(&self, callback: F)
96 where
97 F: Fn(String) + Send + Sync + 'static,
98 {
99 self.callbacks.lock().error = Some(Box::new(callback));
100 }
101
102 pub fn on_heartbeat<F>(&self, callback: F)
104 where
105 F: Fn() + Send + Sync + 'static,
106 {
107 self.callbacks.lock().heartbeat = Some(Box::new(callback));
108 }
109
110 pub async fn start(&mut self) -> McpClientResult<()> {
112 let mut receiver = self
113 .event_receiver
114 .take()
115 .ok_or_else(|| McpClientError::generic("No event receiver configured"))?;
116
117 let callbacks = Arc::clone(&self.callbacks);
118
119 tokio::spawn(async move {
120 info!("Stream handler started");
121
122 while let Some(event) = receiver.recv().await {
123 debug!(event = ?event, "Received server event");
124
125 let callbacks = callbacks.lock();
126
127 match event {
128 ServerEvent::Notification(notification) => {
129 if let Some(ref callback) = callbacks.notification {
130 callback(notification);
131 }
132 }
133 ServerEvent::Request(request) => {
134 if let Some(ref callback) = callbacks.request {
135 match callback(request) {
136 Ok(_response) => {
137 debug!("Request handled successfully");
138 }
140 Err(error) => {
141 warn!(error = %error, "Request handler returned error");
142 }
144 }
145 } else {
146 warn!("Received server request but no request handler configured");
147 }
148 }
149 ServerEvent::ConnectionLost => {
150 warn!("Connection lost");
151 if let Some(ref callback) = callbacks.connection_lost {
152 callback();
153 }
154 }
155 ServerEvent::Error(error) => {
156 warn!(error = %error, "Server error");
157 if let Some(ref callback) = callbacks.error {
158 callback(error);
159 }
160 }
161 ServerEvent::Heartbeat => {
162 debug!("Heartbeat received");
163 if let Some(ref callback) = callbacks.heartbeat {
164 callback();
165 }
166 }
167 }
168 }
169
170 info!("Stream handler stopped");
171 });
172
173 Ok(())
174 }
175
176 pub fn is_active(&self) -> bool {
178 self.event_receiver.is_some()
179 }
180}
181
182impl Default for StreamHandler {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188#[derive(Debug, Clone)]
190pub struct ProgressTracker {
191 pub operation_id: String,
193 pub total: Option<u64>,
195 pub completed: u64,
197 pub message: Option<String>,
199 pub metadata: Value,
201}
202
203impl ProgressTracker {
204 pub fn new(operation_id: String) -> Self {
206 Self {
207 operation_id,
208 total: None,
209 completed: 0,
210 message: None,
211 metadata: Value::Null,
212 }
213 }
214
215 pub fn update(&mut self, completed: u64, message: Option<String>) {
217 self.completed = completed;
218 self.message = message;
219 }
220
221 pub fn set_total(&mut self, total: u64) {
223 self.total = Some(total);
224 }
225
226 pub fn percentage(&self) -> Option<f64> {
228 self.total.map(|total| {
229 if total == 0 {
230 1.0
231 } else {
232 (self.completed as f64) / (total as f64)
233 }
234 })
235 }
236
237 pub fn is_complete(&self) -> bool {
239 if let Some(total) = self.total {
240 self.completed >= total
241 } else {
242 false
243 }
244 }
245
246 pub fn status(&self) -> String {
248 match (self.total, &self.message) {
249 (Some(total), Some(msg)) => {
250 format!(
251 "{}/{} ({}%) - {}",
252 self.completed,
253 total,
254 (self.percentage().unwrap_or(0.0) * 100.0) as u32,
255 msg
256 )
257 }
258 (Some(total), None) => {
259 format!(
260 "{}/{} ({}%)",
261 self.completed,
262 total,
263 (self.percentage().unwrap_or(0.0) * 100.0) as u32
264 )
265 }
266 (None, Some(msg)) => {
267 format!("{} steps - {}", self.completed, msg)
268 }
269 (None, None) => {
270 format!("{} steps", self.completed)
271 }
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_tracker() {
        let mut tracker = ProgressTracker::new("test-op".to_string());

        assert_eq!(tracker.completed, 0);
        assert_eq!(tracker.percentage(), None);

        tracker.set_total(100);
        assert_eq!(tracker.percentage(), Some(0.0));

        tracker.update(50, Some("halfway".to_string()));
        assert_eq!(tracker.percentage(), Some(0.5));
        assert_eq!(tracker.message, Some("halfway".to_string()));

        tracker.update(100, Some("complete".to_string()));
        assert_eq!(tracker.percentage(), Some(1.0));
        assert!(tracker.is_complete());
    }

    #[test]
    fn test_progress_tracker_status_formats() {
        let mut tracker = ProgressTracker::new("status-op".to_string());
        assert_eq!(tracker.status(), "0 steps");
        assert!(!tracker.is_complete());

        tracker.update(3, Some("working".to_string()));
        assert_eq!(tracker.status(), "3 steps - working");

        tracker.set_total(4);
        assert_eq!(tracker.status(), "3/4 (75%) - working");

        // Passing `None` clears the previous message.
        tracker.update(2, None);
        assert_eq!(tracker.status(), "2/4 (50%)");
        assert!(!tracker.is_complete());
    }

    #[tokio::test]
    async fn test_stream_handler_callbacks() {
        let handler = StreamHandler::new();

        let notification_received = Arc::new(parking_lot::Mutex::new(false));
        let notification_received_clone = Arc::clone(&notification_received);

        handler.on_notification(move |_| {
            *notification_received_clone.lock() = true;
        });

        // Invoke the stored callback directly to prove it is the one we registered.
        let callbacks = handler.callbacks.lock();
        let callback = callbacks
            .notification
            .as_ref()
            .expect("notification callback should be registered");
        callback(Value::Null);
        assert!(*notification_received.lock());
    }

    #[tokio::test]
    async fn test_stream_handler_dispatches_notifications() {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let (seen_tx, mut seen_rx) = mpsc::unbounded_channel();

        let mut handler = StreamHandler::new();
        handler.set_receiver(event_rx);
        handler.on_notification(move |payload| {
            // `seen_rx` lives until the end of the test, so this send cannot fail
            // while we are still waiting for it.
            let _ = seen_tx.send(payload);
        });

        assert!(handler.start().await.is_ok());
        assert!(
            handler.start().await.is_err(),
            "start consumes the receiver, so a second call must fail"
        );

        let payload = serde_json::json!({ "method": "notifications/message" });
        event_tx
            .send(ServerEvent::Notification(payload.clone()))
            .expect("dispatch task should still be receiving");
        assert_eq!(seen_rx.recv().await, Some(payload));
    }
}