Skip to main content

synth_ai_core/streaming/
handler.rs

1//! Stream handler trait and built-in handlers.
2//!
3//! Handlers process stream messages and can filter, transform, or output them.
4
5use super::types::StreamMessage;
6use std::io::Write;
7use std::path::PathBuf;
8use std::sync::Mutex;
9
10/// Trait for handling stream messages.
11pub trait StreamHandler: Send + Sync {
12    /// Process a stream message.
13    fn handle(&self, message: &StreamMessage);
14
15    /// Filter predicate - return false to skip handling this message.
16    fn should_handle(&self, _message: &StreamMessage) -> bool {
17        true
18    }
19
20    /// Flush any buffered output.
21    fn flush(&self) {}
22
23    /// Called when streaming starts.
24    fn on_start(&self, _job_id: &str) {}
25
26    /// Called when streaming ends.
27    fn on_end(&self, _job_id: &str, _final_status: Option<&str>) {}
28}
29
30/// A handler that calls a callback function.
31pub struct CallbackHandler<F>
32where
33    F: Fn(&StreamMessage) + Send + Sync,
34{
35    callback: F,
36}
37
38impl<F> CallbackHandler<F>
39where
40    F: Fn(&StreamMessage) + Send + Sync,
41{
42    /// Create a new callback handler.
43    pub fn new(callback: F) -> Self {
44        Self { callback }
45    }
46}
47
48impl<F> StreamHandler for CallbackHandler<F>
49where
50    F: Fn(&StreamMessage) + Send + Sync,
51{
52    fn handle(&self, message: &StreamMessage) {
53        (self.callback)(message);
54    }
55}
56
/// Stream handler that serializes each message to JSON, one record per line.
///
/// Records go to stdout unless an output file was configured; that file is
/// opened lazily (create + append) the first time it is needed.
pub struct JsonHandler {
    /// Destination file, or `None` to write to stdout.
    output_path: Option<PathBuf>,
    /// Handle for `output_path`, populated on first use.
    file: Mutex<Option<std::fs::File>>,
    /// Whether records are pretty-printed instead of compact.
    pretty: bool,
}

impl JsonHandler {
    /// Build a handler for `output_path` with compact output and no file opened yet.
    fn with_target(output_path: Option<PathBuf>) -> Self {
        JsonHandler {
            output_path,
            file: Mutex::new(None),
            pretty: false,
        }
    }

    /// Create a handler that writes records to stdout.
    pub fn stdout() -> Self {
        Self::with_target(None)
    }

    /// Create a handler that appends records to the file at `path`.
    pub fn file(path: impl Into<PathBuf>) -> Self {
        Self::with_target(Some(path.into()))
    }

    /// Switch this handler to pretty-printed JSON output.
    pub fn pretty(self) -> Self {
        JsonHandler {
            pretty: true,
            ..self
        }
    }

    /// Open the output file if a path is configured and it is not open yet.
    ///
    /// Returns `Some(Err(_))` only when opening the file fails; returns `None`
    /// when there is no path, the file was already open, or it opened successfully.
    fn ensure_file(&self) -> Option<std::io::Result<()>> {
        let path = self.output_path.as_ref()?;
        let mut slot = self.file.lock().unwrap();
        if slot.is_some() {
            return None;
        }
        match std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(path)
        {
            Ok(handle) => {
                *slot = Some(handle);
                None
            }
            Err(err) => Some(Err(err)),
        }
    }
}
106
107impl StreamHandler for JsonHandler {
108    fn handle(&self, message: &StreamMessage) {
109        let json = if self.pretty {
110            serde_json::to_string_pretty(message).unwrap_or_default()
111        } else {
112            serde_json::to_string(message).unwrap_or_default()
113        };
114
115        if let Some(ref _path) = self.output_path {
116            self.ensure_file();
117            let mut guard = self.file.lock().unwrap();
118            if let Some(ref mut file) = *guard {
119                let _ = writeln!(file, "{}", json);
120            }
121        } else {
122            println!("{}", json);
123        }
124    }
125
126    fn flush(&self) {
127        if let Some(ref _path) = self.output_path {
128            let mut guard = self.file.lock().unwrap();
129            if let Some(ref mut file) = *guard {
130                let _ = file.flush();
131            }
132        }
133    }
134}
135
136/// A handler that buffers messages in memory.
137pub struct BufferedHandler {
138    messages: Mutex<Vec<StreamMessage>>,
139    max_size: Option<usize>,
140}
141
142impl BufferedHandler {
143    /// Create a new buffered handler.
144    pub fn new() -> Self {
145        Self {
146            messages: Mutex::new(Vec::new()),
147            max_size: None,
148        }
149    }
150
151    /// Create a handler with a maximum buffer size.
152    pub fn with_max_size(max_size: usize) -> Self {
153        Self {
154            messages: Mutex::new(Vec::with_capacity(max_size.min(1000))),
155            max_size: Some(max_size),
156        }
157    }
158
159    /// Get all buffered messages.
160    pub fn messages(&self) -> Vec<StreamMessage> {
161        self.messages.lock().unwrap().clone()
162    }
163
164    /// Clear the buffer.
165    pub fn clear(&self) {
166        self.messages.lock().unwrap().clear();
167    }
168
169    /// Get the number of buffered messages.
170    pub fn len(&self) -> usize {
171        self.messages.lock().unwrap().len()
172    }
173
174    /// Check if the buffer is empty.
175    pub fn is_empty(&self) -> bool {
176        self.len() == 0
177    }
178}
179
180impl Default for BufferedHandler {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186impl StreamHandler for BufferedHandler {
187    fn handle(&self, message: &StreamMessage) {
188        let mut messages = self.messages.lock().unwrap();
189
190        // Drop oldest if at max size
191        if let Some(max) = self.max_size {
192            if messages.len() >= max {
193                messages.remove(0);
194            }
195        }
196
197        messages.push(message.clone());
198    }
199}
200
201/// A handler that filters messages before passing to another handler.
202pub struct FilteredHandler<H: StreamHandler, F: Fn(&StreamMessage) -> bool + Send + Sync> {
203    inner: H,
204    filter: F,
205}
206
207impl<H: StreamHandler, F: Fn(&StreamMessage) -> bool + Send + Sync> FilteredHandler<H, F> {
208    /// Create a new filtered handler.
209    pub fn new(inner: H, filter: F) -> Self {
210        Self { inner, filter }
211    }
212}
213
214impl<H: StreamHandler, F: Fn(&StreamMessage) -> bool + Send + Sync> StreamHandler
215    for FilteredHandler<H, F>
216{
217    fn handle(&self, message: &StreamMessage) {
218        if (self.filter)(message) {
219            self.inner.handle(message);
220        }
221    }
222
223    fn should_handle(&self, message: &StreamMessage) -> bool {
224        (self.filter)(message) && self.inner.should_handle(message)
225    }
226
227    fn flush(&self) {
228        self.inner.flush();
229    }
230
231    fn on_start(&self, job_id: &str) {
232        self.inner.on_start(job_id);
233    }
234
235    fn on_end(&self, job_id: &str, final_status: Option<&str>) {
236        self.inner.on_end(job_id, final_status);
237    }
238}
239
240/// A handler that dispatches to multiple handlers.
241pub struct MultiHandler {
242    handlers: Vec<Box<dyn StreamHandler>>,
243}
244
245impl MultiHandler {
246    /// Create a new multi-handler.
247    pub fn new() -> Self {
248        Self {
249            handlers: Vec::new(),
250        }
251    }
252
253    /// Add a handler.
254    pub fn add<H: StreamHandler + 'static>(mut self, handler: H) -> Self {
255        self.handlers.push(Box::new(handler));
256        self
257    }
258
259    /// Add a boxed handler.
260    pub fn add_boxed(mut self, handler: Box<dyn StreamHandler>) -> Self {
261        self.handlers.push(handler);
262        self
263    }
264}
265
266impl Default for MultiHandler {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl StreamHandler for MultiHandler {
273    fn handle(&self, message: &StreamMessage) {
274        for handler in &self.handlers {
275            if handler.should_handle(message) {
276                handler.handle(message);
277            }
278        }
279    }
280
281    fn flush(&self) {
282        for handler in &self.handlers {
283            handler.flush();
284        }
285    }
286
287    fn on_start(&self, job_id: &str) {
288        for handler in &self.handlers {
289            handler.on_start(job_id);
290        }
291    }
292
293    fn on_end(&self, job_id: &str, final_status: Option<&str>) {
294        for handler in &self.handlers {
295            handler.on_end(job_id, final_status);
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::types::StreamType;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Build an `Events` message for `job_id` carrying `payload`.
    fn event(job_id: &'static str, payload: serde_json::Value) -> StreamMessage {
        StreamMessage::new(StreamType::Events, job_id, payload)
    }

    /// Adapter that lets a test keep its own `Arc` to a `BufferedHandler` it
    /// hands to another handler.
    struct SharedBuffer(Arc<BufferedHandler>);

    impl StreamHandler for SharedBuffer {
        fn handle(&self, message: &StreamMessage) {
            self.0.handle(message);
        }
    }

    #[test]
    fn test_callback_handler() {
        let calls = Arc::new(AtomicUsize::new(0));
        let handler = {
            let calls = Arc::clone(&calls);
            CallbackHandler::new(move |_| {
                calls.fetch_add(1, Ordering::SeqCst);
            })
        };

        let msg = event("job-1", serde_json::json!({}));
        for _ in 0..2 {
            handler.handle(&msg);
        }

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_buffered_handler() {
        let handler = BufferedHandler::new();

        for seq in 1..=2 {
            handler.handle(&event("job-1", serde_json::json!({ "seq": seq })));
        }

        assert_eq!(handler.len(), 2);
        assert_eq!(handler.messages().len(), 2);

        handler.clear();
        assert!(handler.is_empty());
    }

    #[test]
    fn test_buffered_handler_max_size() {
        let handler = BufferedHandler::with_max_size(2);

        for i in 0..5 {
            handler.handle(&event("job-1", serde_json::json!({ "seq": i })));
        }

        // Only the two most recent messages survive.
        assert_eq!(handler.len(), 2);
        let seqs: Vec<_> = handler
            .messages()
            .iter()
            .map(|m| m.get_i64("seq"))
            .collect();
        assert_eq!(seqs, vec![Some(3), Some(4)]);
    }

    #[test]
    fn test_filtered_handler() {
        let buffer = Arc::new(BufferedHandler::new());
        let filtered = FilteredHandler::new(SharedBuffer(Arc::clone(&buffer)), |msg| {
            msg.get_i64("value").unwrap_or(0) > 5
        });

        for value in [3, 10] {
            filtered.handle(&event("job", serde_json::json!({ "value": value })));
        }

        assert_eq!(buffer.len(), 1);
    }

    #[test]
    fn test_multi_handler() {
        let first = Arc::new(BufferedHandler::new());
        let second = Arc::new(BufferedHandler::new());

        let multi = MultiHandler::new()
            .add(SharedBuffer(Arc::clone(&first)))
            .add(SharedBuffer(Arc::clone(&second)));

        multi.handle(&event("job", serde_json::json!({})));

        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);
    }
}