Skip to main content

synth_ai_core/streaming/
streamer.rs

1//! Job streamer for polling and streaming job events.
2
3use super::{
4    config::StreamConfig,
5    endpoints::StreamEndpoints,
6    handler::StreamHandler,
7    types::{StreamMessage, StreamType},
8};
9use crate::errors::CoreError;
10use crate::http::HttpClient;
11use serde_json::Value;
12use std::collections::HashSet;
13use std::sync::Arc;
14
/// Timeout, in seconds, passed to `HttpClient::new` when building the streaming client.
const DEFAULT_TIMEOUT_SECS: u64 = 60;

/// Job `status` values after which the streamer stops polling.
///
/// Both spellings of "cancelled" are accepted, and `paused` also ends the stream.
const TERMINAL_STATUSES: &[&str] = &[
    // Successful completion.
    "succeeded",
    "completed",
    // Failure.
    "failed",
    "error",
    // Cancellation (British and American spellings).
    "cancelled",
    "canceled",
    // Suspended jobs also stop the stream.
    "paused",
];
28
29/// Job streamer that polls endpoints and dispatches to handlers.
30pub struct JobStreamer {
31    base_url: String,
32    api_key: String,
33    job_id: String,
34    endpoints: StreamEndpoints,
35    config: StreamConfig,
36    handlers: Vec<Arc<dyn StreamHandler>>,
37    seen_messages: HashSet<String>,
38    last_event_seq: Option<i64>,
39}
40
41impl JobStreamer {
42    /// Create a new job streamer.
43    pub fn new(
44        base_url: impl Into<String>,
45        api_key: impl Into<String>,
46        job_id: impl Into<String>,
47    ) -> Self {
48        let job_id = job_id.into();
49        Self {
50            base_url: base_url.into().trim_end_matches('/').to_string(),
51            api_key: api_key.into(),
52            job_id: job_id.clone(),
53            endpoints: StreamEndpoints::learning(&job_id),
54            config: StreamConfig::default(),
55            handlers: vec![],
56            seen_messages: HashSet::new(),
57            last_event_seq: None,
58        }
59    }
60
61    /// Set the stream endpoints.
62    pub fn with_endpoints(mut self, endpoints: StreamEndpoints) -> Self {
63        self.endpoints = endpoints;
64        self
65    }
66
67    /// Set the stream config.
68    pub fn with_config(mut self, config: StreamConfig) -> Self {
69        self.config = config;
70        self
71    }
72
73    /// Add a handler.
74    pub fn with_handler(mut self, handler: Arc<dyn StreamHandler>) -> Self {
75        self.handlers.push(handler);
76        self
77    }
78
79    /// Add a handler (convenience method).
80    pub fn add_handler<H: StreamHandler + 'static>(&mut self, handler: H) {
81        self.handlers.push(Arc::new(handler));
82    }
83
84    /// Poll status once.
85    pub async fn poll_status(&mut self) -> Result<Option<Value>, CoreError> {
86        let client = self.create_client()?;
87
88        for endpoint in self.endpoints.all_status_endpoints() {
89            match client.get::<Value>(endpoint, None).await {
90                Ok(status) => {
91                    self.dispatch_status(&status);
92                    return Ok(Some(status));
93                }
94                Err(e) => {
95                    // Check if it's a 404 - try next fallback
96                    if let Some(404) = e.status() {
97                        continue;
98                    }
99                    return Err(e.into());
100                }
101            }
102        }
103
104        Ok(None)
105    }
106
107    /// Poll events once.
108    pub async fn poll_events(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
109        if !self.config.is_stream_enabled(StreamType::Events) {
110            return Ok(vec![]);
111        }
112
113        let client = self.create_client()?;
114        let mut all_messages = vec![];
115        let mut total_events: usize = 0;
116
117        for endpoint in self.endpoints.all_event_endpoints() {
118            // Add since_seq parameter if we have one
119            let url = if let Some(seq) = self.last_event_seq {
120                format!("{}?since_seq={}", endpoint, seq)
121            } else {
122                endpoint.to_string()
123            };
124
125            match client.get::<Value>(&url, None).await {
126                Ok(response) => {
127                    let events_list = response
128                        .get("events")
129                        .and_then(|v| v.as_array())
130                        .or_else(|| response.as_array());
131                    if let Some(events) = events_list {
132                        for event in events {
133                            if self.config.should_include_event(event) {
134                                let seq = event.get("seq").and_then(|v| v.as_i64());
135                                let msg = StreamMessage::event(
136                                    &self.job_id,
137                                    event.clone(),
138                                    seq.unwrap_or(0),
139                                );
140
141                                // Update last seen seq
142                                if let Some(s) = seq {
143                                    self.last_event_seq =
144                                        Some(self.last_event_seq.map(|l| l.max(s)).unwrap_or(s));
145                                }
146
147                                self.dispatch_message(&msg);
148                                all_messages.push(msg);
149                                total_events += 1;
150                                if let Some(max_events) = self.config.max_events_per_poll {
151                                    if total_events >= max_events {
152                                        return Ok(all_messages);
153                                    }
154                                }
155                            }
156                        }
157                    }
158                    break; // Success, don't try fallbacks
159                }
160                Err(e) => {
161                    if let Some(404) = e.status() {
162                        continue; // Try next fallback
163                    }
164                    return Err(e.into());
165                }
166            }
167        }
168
169        Ok(all_messages)
170    }
171
172    /// Poll metrics once.
173    pub async fn poll_metrics(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
174        if !self.config.is_stream_enabled(StreamType::Metrics) {
175            return Ok(vec![]);
176        }
177
178        let client = self.create_client()?;
179        let mut all_messages = vec![];
180
181        let metric_endpoints = self.endpoints.all_metric_endpoints();
182        if metric_endpoints.is_empty() {
183            return Ok(all_messages);
184        }
185
186        for endpoint in metric_endpoints {
187            match client.get::<Value>(endpoint, None).await {
188                Ok(response) => {
189                    let mut metrics: Option<&Vec<Value>> = None;
190                    if let Some(items) = response.get("points").and_then(|v| v.as_array()) {
191                        metrics = Some(items);
192                    } else if let Some(items) = response.get("metrics").and_then(|v| v.as_array()) {
193                        metrics = Some(items);
194                    } else if let Some(items) = response.as_array() {
195                        metrics = Some(items);
196                    }
197
198                    if let Some(metrics) = metrics {
199                        for metric in metrics {
200                            if self.config.should_include_metric(metric) {
201                                let step = metric.get("step").and_then(|v| v.as_i64()).unwrap_or(0);
202                                let msg =
203                                    StreamMessage::metrics(&self.job_id, metric.clone(), step);
204                                self.dispatch_message(&msg);
205                                all_messages.push(msg);
206                            }
207                        }
208                    }
209                    break; // Success, don't try fallbacks
210                }
211                Err(e) => {
212                    if let Some(404) = e.status() {
213                        continue;
214                    }
215                    return Err(e.into());
216                }
217            }
218        }
219
220        Ok(all_messages)
221    }
222
223    /// Poll timeline once.
224    pub async fn poll_timeline(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
225        if !self.config.is_stream_enabled(StreamType::Timeline) {
226            return Ok(vec![]);
227        }
228
229        let client = self.create_client()?;
230        let mut all_messages = vec![];
231        let timeline_endpoints = self.endpoints.all_timeline_endpoints();
232        if timeline_endpoints.is_empty() {
233            return Ok(all_messages);
234        }
235
236        for endpoint in timeline_endpoints {
237            match client.get::<Value>(endpoint, None).await {
238                Ok(response) => {
239                    let mut entries: Option<&Vec<Value>> = None;
240                    if let Some(items) = response.get("events").and_then(|v| v.as_array()) {
241                        entries = Some(items);
242                    } else if let Some(items) = response.get("timeline").and_then(|v| v.as_array())
243                    {
244                        entries = Some(items);
245                    } else if let Some(items) = response.as_array() {
246                        entries = Some(items);
247                    }
248
249                    if let Some(entries) = entries {
250                        for entry in entries {
251                            if !self.config.should_include_timeline(entry) {
252                                continue;
253                            }
254                            let phase = entry.get("phase").and_then(|v| v.as_str()).unwrap_or("");
255                            let job_id = entry
256                                .get("job_id")
257                                .and_then(|v| v.as_str())
258                                .unwrap_or(&self.job_id);
259                            let msg = StreamMessage::timeline(job_id, phase, entry.clone());
260                            self.dispatch_message(&msg);
261                            all_messages.push(msg);
262                        }
263                    }
264                    break; // Success, don't try fallbacks
265                }
266                Err(e) => {
267                    if let Some(404) = e.status() {
268                        continue;
269                    }
270                    return Err(e.into());
271                }
272            }
273        }
274
275        Ok(all_messages)
276    }
277
278    /// Stream until the job reaches a terminal state.
279    pub async fn stream_until_terminal(&mut self) -> Result<Value, CoreError> {
280        // Notify handlers of start
281        for handler in &self.handlers {
282            handler.on_start(&self.job_id);
283        }
284
285        loop {
286            // Poll status
287            if let Some(status) = self.poll_status().await? {
288                if Self::is_terminal(&status) {
289                    let final_status = status.get("status").and_then(|v| v.as_str());
290
291                    // Notify handlers of end
292                    for handler in &self.handlers {
293                        handler.on_end(&self.job_id, final_status);
294                        handler.flush();
295                    }
296
297                    return Ok(status);
298                }
299            }
300
301            // Poll events
302            let _ = self.poll_events().await?;
303
304            // Poll metrics (less frequently)
305            let _ = self.poll_metrics().await?;
306            let _ = self.poll_timeline().await?;
307
308            // Wait before next poll
309            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
310                self.config.poll_interval_seconds,
311            ))
312            .await;
313        }
314    }
315
316    /// Stream for a maximum duration, returning early if terminal.
317    pub async fn stream_for_duration(
318        &mut self,
319        max_seconds: f64,
320    ) -> Result<Option<Value>, CoreError> {
321        let start = std::time::Instant::now();
322        let max_duration = std::time::Duration::from_secs_f64(max_seconds);
323
324        for handler in &self.handlers {
325            handler.on_start(&self.job_id);
326        }
327
328        loop {
329            if start.elapsed() >= max_duration {
330                for handler in &self.handlers {
331                    handler.on_end(&self.job_id, Some("timeout"));
332                    handler.flush();
333                }
334                return Ok(None);
335            }
336
337            if let Some(status) = self.poll_status().await? {
338                if Self::is_terminal(&status) {
339                    let final_status = status.get("status").and_then(|v| v.as_str());
340                    for handler in &self.handlers {
341                        handler.on_end(&self.job_id, final_status);
342                        handler.flush();
343                    }
344                    return Ok(Some(status));
345                }
346            }
347
348            let _ = self.poll_events().await?;
349            let _ = self.poll_metrics().await?;
350            let _ = self.poll_timeline().await?;
351
352            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
353                self.config.poll_interval_seconds,
354            ))
355            .await;
356        }
357    }
358
359    fn create_client(&self) -> Result<HttpClient, CoreError> {
360        HttpClient::new(&self.base_url, &self.api_key, DEFAULT_TIMEOUT_SECS)
361            .map_err(|e| CoreError::Internal(format!("Failed to create HTTP client: {}", e)))
362    }
363
364    fn dispatch_status(&mut self, status: &Value) {
365        let msg = StreamMessage::status(&self.job_id, status.clone());
366        self.dispatch_message(&msg);
367    }
368
369    fn dispatch_message(&mut self, message: &StreamMessage) {
370        // Deduplication
371        if self.config.deduplicate {
372            let key = message.key();
373            if self.seen_messages.contains(&key) {
374                return;
375            }
376            self.seen_messages.insert(key);
377        }
378
379        // Dispatch to handlers
380        for handler in &self.handlers {
381            if handler.should_handle(message) {
382                handler.handle(message);
383            }
384        }
385    }
386
387    fn is_terminal(status: &Value) -> bool {
388        status
389            .get("status")
390            .and_then(|v| v.as_str())
391            .map(|s| TERMINAL_STATUSES.contains(&s))
392            .unwrap_or(false)
393    }
394
395    /// Get the job ID.
396    pub fn job_id(&self) -> &str {
397        &self.job_id
398    }
399
400    /// Get the last event sequence number.
401    pub fn last_event_seq(&self) -> Option<i64> {
402        self.last_event_seq
403    }
404
405    /// Clear seen messages (for re-streaming).
406    pub fn clear_seen(&mut self) {
407        self.seen_messages.clear();
408        self.last_event_seq = None;
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Every listed terminal status ends the stream; active statuses do not.
    #[test]
    fn test_terminal_detection() {
        for terminal in TERMINAL_STATUSES {
            assert!(
                JobStreamer::is_terminal(&serde_json::json!({ "status": terminal })),
                "{} should be terminal",
                terminal
            );
        }
        for active in ["running", "pending", "queued", ""] {
            assert!(
                !JobStreamer::is_terminal(&serde_json::json!({ "status": active })),
                "{} should not be terminal",
                active
            );
        }
    }

    /// Only a string `status` field is inspected.
    #[test]
    fn test_terminal_detection_requires_string_status() {
        assert!(!JobStreamer::is_terminal(&serde_json::json!({})));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({ "status": null })));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({ "status": 1 })));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({ "state": "succeeded" })));
        assert!(!JobStreamer::is_terminal(&serde_json::json!("succeeded")));
    }

    #[test]
    fn test_streamer_creation() {
        let streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123")
            .with_config(StreamConfig::minimal())
            .with_endpoints(StreamEndpoints::prompt_learning("job-123"));

        assert_eq!(streamer.job_id(), "job-123");
        assert!(streamer.last_event_seq().is_none());
    }

    #[test]
    fn test_new_trims_trailing_slashes() {
        let streamer = JobStreamer::new("https://api.example.com//", "sk-test", "job-123");

        assert_eq!(streamer.base_url, "https://api.example.com");
        assert!(streamer.handlers.is_empty());
        assert!(streamer.seen_messages.is_empty());
    }

    #[test]
    fn test_clear_seen() {
        let mut streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123");

        streamer.seen_messages.insert("test".to_string());
        streamer.last_event_seq = Some(42);

        streamer.clear_seen();

        assert!(streamer.seen_messages.is_empty());
        assert!(streamer.last_event_seq.is_none());
    }
}